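StratifiedGroupKFold is a cross-validation splitter that keeps all samples from the same group in the same fold while trying to preserve the overall class proportions in every fold. Because groups are never split, each fold's class balance can only be as even as the group sizes and compositions allow. It is particularly useful for datasets that combine a group structure (for example, several samples from the same subject) with an imbalanced class distribution.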
This cross-validation splitter is appropriate for classification problems with grouped data.
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedGroupKFold
# generate synthetic dataset with groups and imbalanced classes
X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.8, 0.2], n_informative=5, n_redundant=0, n_clusters_per_class=1, random_state=42)
groups = [0] * 500 + [1] * 500
# create StratifiedGroupKFold object
cv = StratifiedGroupKFold(n_splits=5)
# generate indices for each fold
for i, (train_idx, test_idx) in enumerate(cv.split(X, y, groups)):
    print(f"Fold {i}:")
    print(f" Train: {len(train_idx)} samples")
    print(f" Test: {len(test_idx)} samples")
    # int() keeps the counts as plain integers (NumPy 2 would otherwise print np.int64(...))
    print(f" Train class distribution: {[int((y[train_idx] == c).sum()) for c in range(2)]}")
    print(f" Test class distribution: {[int((y[test_idx] == c).sum()) for c in range(2)]}")
Running the example gives an output like:
Fold 0:
Train: 500 samples
Test: 500 samples
Train class distribution: [390, 110]
Test class distribution: [408, 92]
Fold 1:
Train: 500 samples
Test: 500 samples
Train class distribution: [408, 92]
Test class distribution: [390, 110]
Fold 2:
Train: 1000 samples
Test: 0 samples
Train class distribution: [798, 202]
Test class distribution: [0, 0]
Fold 3:
Train: 1000 samples
Test: 0 samples
Train class distribution: [798, 202]
Test class distribution: [0, 0]
Fold 4:
Train: 1000 samples
Test: 0 samples
Train class distribution: [798, 202]
Test class distribution: [0, 0]
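The steps are as follows:
1. A synthetic dataset with imbalanced classes is generated using make_classification(). It has 1000 samples and 2 classes with roughly an 80/20 split (798 and 202 samples, as the output shows). A separate groups list then assigns the first 500 samples to group 0 and the remaining 500 to group 1.
2. A StratifiedGroupKFold object is created with n_splits=5.
3. The split() method generates the train and test indices for each fold. It takes the input features (X), the target variable (y), and the group labels (groups) as arguments.
4. The loop prints the number of samples and the class distribution of the training and test sets for each fold.

Because a group is never split across folds, two groups can fill at most two test folds. Folds 0 and 1 each use one entire group as their test set, while folds 2, 3, and 4 have empty test sets and are useless for evaluation. In practice, n_splits should not exceed the number of distinct groups, and there should be many more groups than folds so the splitter has room to balance the class proportions.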
This example demonstrates how to use StratifiedGroupKFold to perform stratified group cross-validation: every group stays entirely in either the training set or the test set of a fold, and each fold's class proportions are kept as close to the overall proportions as that constraint allows. This is particularly useful when a dataset has a natural group structure and imbalanced classes.