
Scikit-Learn StratifiedKFold Data Splitting

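StratifiedKFold is a variation of k-fold cross-validation that preserves the class distribution in each fold: every fold contains approximately the same percentage of samples of each class as the full dataset. This makes it well suited to classification problems.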
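To see why this matters, it helps to compare StratifiedKFold with plain KFold on labels that are sorted by class. The minimal sketch below assumes hypothetical toy labels (six samples of each of three classes) and a zero-filled feature matrix, and counts how many samples of each class land in every test fold.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Hypothetical toy labels, sorted by class: 6 samples each of classes 0, 1 and 2
y = np.repeat([0, 1, 2], 6)
X = np.zeros((len(y), 1))  # feature values play no role in how these folds are formed


def fold_class_counts(cv):
    """Return the per-class sample counts of each test fold produced by cv."""
    return [np.bincount(y[test], minlength=3).tolist() for _, test in cv.split(X, y)]


# Neither splitter shuffles here, so both follow the original sample order
kfold_counts = fold_class_counts(KFold(n_splits=3))
skfold_counts = fold_class_counts(StratifiedKFold(n_splits=3))

print("KFold test folds:          ", kfold_counts)
print("StratifiedKFold test folds:", skfold_counts)
```

Because the labels are sorted and no shuffling is applied, KFold puts each class into its own test fold, so a model would always be tested on a class it never saw during training. StratifiedKFold instead gives every test fold two samples of each class.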

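The key parameter is n_splits, the number of folds to create (the default is 5 and the minimum is 2). The shuffle and random_state parameters control whether each class's samples are shuffled before being assigned to folds, and whether that shuffle is reproducible.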
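A rough sketch of what n_splits does, assuming hypothetical balanced labels (10 samples of each of 3 classes): with n_splits=5, each test fold should hold about one fifth of every class, and every sample should appear in exactly one test fold. Note that scikit-learn warns when the least populated class has fewer members than n_splits.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical balanced labels: 10 samples of each of 3 classes (30 in total)
y = np.repeat([0, 1, 2], 10)
X = np.zeros((len(y), 1))

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)
print("Number of folds:", cv.get_n_splits(X, y))

test_sizes = []
seen = []
for train_index, test_index in cv.split(X, y):
    test_sizes.append(len(test_index))
    seen.extend(test_index.tolist())
    # train size, test size, and per-class counts in the test fold
    print(len(train_index), len(test_index), np.bincount(y[test_index], minlength=3))

# Every sample should land in exactly one test fold
print("Each sample tested once:", sorted(seen) == list(range(len(y))))
```

Here each of the five test folds holds 6 samples (two per class) and each training set holds the remaining 24.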

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold

# generate multi-class classification dataset
# (n_informative=3 is required: make_classification needs n_classes * n_clusters_per_class <= 2**n_informative)
X, y = make_classification(n_samples=100, n_classes=3, n_informative=3, random_state=1)

# create StratifiedKFold object (random_state only takes effect because shuffle=True)
cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=1)

# enumerate the splits
for train_index, test_index in cv.split(X, y):
    print("Train: %s | test: %s" % (train_index, test_index))
    print("Train size: %d | test size: %d" % (len(train_index), len(test_index)))
    train_y, test_y = y[train_index], y[test_index]
    # count the samples of each class in the train and test sets
    train_dist = {int(k): int(np.sum(train_y == k)) for k in np.unique(train_y)}
    test_dist = {int(k): int(np.sum(test_y == k)) for k in np.unique(test_y)}
    print("Train distribution: %s" % train_dist)
    print("Test distribution: %s" % test_dist)

Running the example gives output like the following (the exact indices depend on the generated data and your library versions):

Train: [ 0  1  3  4  5  7  8  9 10 11 12 13 15 18 19 20 22 23 24 25 26 27 28 29
 30 31 32 33 34 35 36 37 39 40 41 45 46 47 49 50 52 53 54 55 58 62 63 66
 67 68 69 72 74 76 77 78 80 81 87 88 89 91 94 96 97 98] | test: [ 2  6 14 16 17 21 38 42 43 44 48 51 56 57 59 60 61 64 65 70 71 73 75 79
 82 83 84 85 86 90 92 93 95 99]
Train size: 66 | test size: 34
Train distribution: {0: 22, 1: 22, 2: 22}
Test distribution: {0: 12, 1: 12, 2: 10}
Train: [ 2  4  6  7  8  9 10 12 14 15 16 17 19 21 22 24 27 28 29 30 32 34 38 39
 42 43 44 46 47 48 50 51 52 54 56 57 58 59 60 61 62 64 65 67 70 71 73 74
 75 76 79 81 82 83 84 85 86 89 90 91 92 93 94 95 97 98 99] | test: [ 0  1  3  5 11 13 18 20 23 25 26 31 33 35 36 37 40 41 45 49 53 55 63 66
 68 69 72 77 78 80 87 88 96]
Train size: 67 | test size: 33
Train distribution: {0: 23, 1: 23, 2: 21}
Test distribution: {0: 11, 1: 11, 2: 11}
Train: [ 0  1  2  3  5  6 11 13 14 16 17 18 20 21 23 25 26 31 33 35 36 37 38 40
 41 42 43 44 45 48 49 51 53 55 56 57 59 60 61 63 64 65 66 68 69 70 71 72
 73 75 77 78 79 80 82 83 84 85 86 87 88 90 92 93 95 96 99] | test: [ 4  7  8  9 10 12 15 19 22 24 27 28 29 30 32 34 39 46 47 50 52 54 58 62
 67 74 76 81 89 91 94 97 98]
Train size: 67 | test size: 33
Train distribution: {0: 23, 1: 23, 2: 21}
Test distribution: {0: 11, 1: 11, 2: 11}
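The balance visible in the output can also be checked programmatically. The sketch below, assuming the same dataset and splitter settings as the example, compares each test fold's class proportions with the overall proportions. With the current scikit-learn implementation, the number of test samples from any one class should differ by at most one between folds, which is what the final check looks at.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold

X, y = make_classification(n_samples=100, n_classes=3, n_informative=3, random_state=1)
cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=1)

overall = np.bincount(y, minlength=3) / len(y)  # class proportions in the full dataset
fold_counts = []
for _, test_index in cv.split(X, y):
    counts = np.bincount(y[test_index], minlength=3)
    fold_counts.append(counts)
    print("test proportions:", np.round(counts / counts.sum(), 3),
          "| overall:", np.round(overall, 3))

# Per class: largest minus smallest test-fold count across the folds
fold_counts = np.array(fold_counts)
spread = fold_counts.max(axis=0) - fold_counts.min(axis=0)
print("Max spread per class:", spread)
```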

The steps are as follows:

  1. First, a synthetic multi-class classification dataset is generated using make_classification(), specifying the number of samples (n_samples), classes (n_classes), and a fixed random seed (random_state) for reproducibility.

  2. A StratifiedKFold object is created with 3 splits (n_splits), shuffling enabled (shuffle), and a fixed random seed (random_state). The seed only has an effect when shuffle=True.

  3. The split() method of the StratifiedKFold object is called with the dataset (X) and labels (y). It returns a generator that yields a pair of NumPy arrays (train indices, test indices) for each fold. The labels in y determine how the folds are stratified.

  4. For each fold, the train and test indices are printed, along with their sizes. The number of samples of each class in the train and test sets is also counted and printed.
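Since the folds depend only on y (together with n_splits, shuffle and random_state), the scikit-learn documentation notes that np.zeros(n_samples) can be passed as a placeholder for X. The sketch below checks this, and also checks that reusing the same integer random_state reproduces the shuffled folds. The helper fold_indices is hypothetical and exists only for this illustration. The final part relies on recent scikit-learn versions rejecting a random_state when shuffle is False.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold

X, y = make_classification(n_samples=100, n_classes=3, n_informative=3, random_state=1)


def fold_indices(cv, X, y):
    """Hypothetical helper: collect each fold's test indices as plain lists."""
    return [test.tolist() for _, test in cv.split(X, y)]


# A zero-filled placeholder for X yields the same folds as the real features
real = fold_indices(StratifiedKFold(n_splits=3, shuffle=True, random_state=1), X, y)
placeholder = fold_indices(StratifiedKFold(n_splits=3, shuffle=True, random_state=1), np.zeros(len(y)), y)
print("Same folds with placeholder X:", real == placeholder)

# The same integer random_state reproduces the shuffled folds exactly
again = fold_indices(StratifiedKFold(n_splits=3, shuffle=True, random_state=1), X, y)
print("Reproducible with random_state=1:", real == again)

# Setting random_state while shuffle is False raises a ValueError
raised = False
try:
    StratifiedKFold(n_splits=3, random_state=1)
except ValueError as err:
    raised = True
    print("ValueError:", err)
```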

This example demonstrates how to use StratifiedKFold to split a dataset into train and test sets while preserving the class distribution in each fold. This is particularly useful for classification problems, especially imbalanced ones, where an unstratified split can leave a minority class under-represented or even absent in some folds and so distort the per-fold performance estimates.
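In practice, a StratifiedKFold object is usually passed as the cv argument of a cross-validation helper rather than iterated by hand. The sketch below uses cross_val_score, and the choice of LogisticRegression is an illustrative assumption. When cv is an integer and the estimator is a classifier, cross_val_score already uses StratifiedKFold, but without shuffling. Passing an explicit object, as here, lets you enable shuffle and fix random_state.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

X, y = make_classification(n_samples=100, n_classes=3, n_informative=3, random_state=1)

cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=1)
model = LogisticRegression(max_iter=1000)  # illustrative model choice

# Fit and score the model once per stratified fold
scores = cross_val_score(model, X, y, scoring="accuracy", cv=cv)
print("Accuracy per fold:", scores.round(3))
print("Mean accuracy: %.3f (std %.3f)" % (scores.mean(), scores.std()))
```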


