
Scikit-Learn StratifiedShuffleSplit Data Splitting

StratifiedShuffleSplit is a useful cross-validation splitter in scikit-learn for handling imbalanced classification datasets.

It ensures that the proportion of samples for each class is preserved in each train and test fold.

This data splitting strategy is a variation of ShuffleSplit that returns stratified randomized folds: each fold is drawn by random sampling, but the per-class proportions are held fixed, making it well suited to imbalanced datasets.
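
To see what stratification adds, the short sketch below (an addition, not part of the original example) draws a single 80/20 split of the same imbalanced dataset with both ShuffleSplit and StratifiedShuffleSplit and prints the class counts in each test set; the stratified test set keeps the minority class at roughly 10%, while plain ShuffleSplit can drift by a few samples.

from collections import Counter

from sklearn.datasets import make_classification
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit

# generate the same imbalanced dataset used in the main example
X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.9, 0.1], random_state=1)

# compare a plain shuffle split with a stratified shuffle split
for splitter in (ShuffleSplit(n_splits=1, test_size=0.2, random_state=1),
                 StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=1)):
    train_index, test_index = next(splitter.split(X, y))
    print(type(splitter).__name__, Counter(y[test_index]))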

from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedShuffleSplit

# generate imbalanced classification dataset
X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.9, 0.1], random_state=1)

# configure the split
split = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=1)

# traverse the splits
for train_index, test_index in split.split(X, y):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    train_0, train_1 = len(y_train[y_train==0]), len(y_train[y_train==1])
    test_0, test_1 = len(y_test[y_test==0]), len(y_test[y_test==1])
    print('>Train: 0=%d, 1=%d, Test: 0=%d, 1=%d' % (train_0, train_1, test_0, test_1))

Running the example produces output like the following:

>Train: 0=716, 1=84, Test: 0=179, 1=21
>Train: 0=716, 1=84, Test: 0=179, 1=21
>Train: 0=716, 1=84, Test: 0=179, 1=21
>Train: 0=716, 1=84, Test: 0=179, 1=21
>Train: 0=716, 1=84, Test: 0=179, 1=21
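
Every line reports the same counts because stratification fixes the number of samples per class in each fold; what changes from split to split is which samples are drawn. Reusing the split, X, and y objects from the example above, a quick check (an addition to the original listing) confirms that the five test sets are distinct:

# collect the test-index sets from each split and count the distinct ones
test_sets = {frozenset(test_index) for _, test_index in split.split(X, y)}
print(len(test_sets))  # typically prints 5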

The steps in this example are:

  1. An imbalanced binary classification dataset is generated using make_classification() with 1000 samples and a class distribution of 90% to 10% by setting the weights argument.

  2. StratifiedShuffleSplit is configured to generate 5 splits (n_splits) with 20% of the data (test_size) reserved for the test set. The random_state is set to ensure reproducible results.

  3. The split() method is called to generate the indices for each train and test fold. The loop iterates over each fold, using these indices to partition X and y into X_train, X_test, y_train, y_test.

  4. For each fold, the number of samples belonging to each class is calculated for both the train and test sets. These counts are printed, confirming that the class proportions are preserved across the splits (a more compact way to compute the counts is sketched after this list).
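
The manual counting in step 4 can also be written with np.bincount. The snippet below is a small alternative sketch, not part of the original example; it assumes the split, X, and y objects defined above are still in scope.

import numpy as np

for train_index, test_index in split.split(X, y):
    # bincount returns the number of samples with label 0 and label 1
    train_counts = np.bincount(y[train_index])
    test_counts = np.bincount(y[test_index])
    print('>Train: 0=%d, 1=%d, Test: 0=%d, 1=%d' % (train_counts[0], train_counts[1], test_counts[0], test_counts[1]))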

This example showcases how StratifiedShuffleSplit can be used to create stratified train and test sets that preserve the original class proportions of an imbalanced dataset, leading to more reliable evaluation of model performance.
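
In practice, the splitter is most often passed straight to an evaluation helper rather than looped over by hand. The sketch below is an illustrative addition, not part of the original example: it reuses X, y, and split from above and plugs the splitter into cross_val_score as the cv argument, with a LogisticRegression model and F1 scoring chosen purely for demonstration.

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

# evaluate a model on the 5 stratified 80/20 splits defined by the splitter
model = LogisticRegression(max_iter=1000)
scores = cross_val_score(model, X, y, scoring='f1', cv=split)
print('Mean F1: %.3f' % scores.mean())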


