Scikit-learnでクラス数を考慮してデータを分割したい
以下のように、scikit-learnの関数train_test_splitを用いると、
dataset_train, dataset_test = train_test_split(dataset, train_size=0.8)
データセットを訓練データとテストデータに分割はしてくれるのですが、
クラス数が多い(例えば100クラス)場合だと、
訓練データとテストデータの各々のクラス数が異なる時があります。
例えば訓練データに含まれるクラス数は100である一方、
テストデータのそれは98となってしまうことがあります。
train_test_splitでは、ランダムシャッフルしてsplitしているだけなので、
クラスに含まれるデータ数がアンバランスな場合、このようなことが起きると思われます。
クラス数もちゃんと保つように、データを分割するには、
どのようにすれば良いでしょうか?
よろしくお願いします。