
Scikit-Learn StratifiedGroupKFold Data Splitting

StratifiedGroupKFold is a cross-validation splitter that keeps all samples from the same group in a single fold while preserving the overall class distribution in each fold as closely as the group constraint allows. It is particularly useful for datasets with a group structure and an imbalanced class distribution.

This cross-validation splitter is appropriate for classification problems with grouped data.

from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedGroupKFold

# generate synthetic dataset with groups and imbalanced classes
X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.8, 0.2],
                           n_informative=5, n_redundant=0,
                           n_clusters_per_class=1, random_state=42)
groups = [0] * 500 + [1] * 500

# create StratifiedGroupKFold object
cv = StratifiedGroupKFold(n_splits=5)

# generate indices for each fold
for i, (train_idx, test_idx) in enumerate(cv.split(X, y, groups)):
    print(f"Fold {i}:")
    print(f"  Train: {len(train_idx)} samples")
    print(f"  Test:  {len(test_idx)} samples")
    print(f"  Train class distribution: {[sum(y[train_idx] == c) for c in range(2)]}")
    print(f"  Test class distribution:  {[sum(y[test_idx] == c) for c in range(2)]}")

Running the example gives an output like:

Fold 0:
  Train: 500 samples
  Test:  500 samples
  Train class distribution: [390, 110]
  Test class distribution:  [408, 92]
Fold 1:
  Train: 500 samples
  Test:  500 samples
  Train class distribution: [408, 92]
  Test class distribution:  [390, 110]
Fold 2:
  Train: 1000 samples
  Test:  0 samples
  Train class distribution: [798, 202]
  Test class distribution:  [0, 0]
Fold 3:
  Train: 1000 samples
  Test:  0 samples
  Train class distribution: [798, 202]
  Test class distribution:  [0, 0]
Fold 4:
  Train: 1000 samples
  Test:  0 samples
  Train class distribution: [798, 202]
  Test class distribution:  [0, 0]

The steps are as follows:

  1. First, a synthetic imbalanced dataset is generated using make_classification(): 1000 samples and 2 classes with an 80/20 class distribution. Group labels are then assigned manually, with the first 500 samples in group 0 and the last 500 in group 1, giving only 2 groups in total.

  2. Next, a StratifiedGroupKFold object is created with 5 splits specified.

  3. The split() method is used to generate indices for each fold. The method takes the input features (X), target variable (y), and group labels (groups) as arguments.

  4. The code iterates over the folds and prints the number of samples and the class distribution for the training and test sets of each fold. A quick check that no group ever appears in both the train and test set of the same fold is sketched just after this list.
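
As a sanity check, the following minimal sketch simply re-runs the same setup and verifies the group-integrity property: for every fold, the set of groups in the training indices and the set of groups in the test indices are disjoint.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedGroupKFold

# same data and splitter as above
X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.8, 0.2],
                           n_informative=5, n_redundant=0,
                           n_clusters_per_class=1, random_state=42)
groups = np.array([0] * 500 + [1] * 500)

cv = StratifiedGroupKFold(n_splits=5)

# for each fold, the groups seen in training and testing should never overlap
for i, (train_idx, test_idx) in enumerate(cv.split(X, y, groups)):
    overlap = set(groups[train_idx]) & set(groups[test_idx])
    print(f"Fold {i}: groups shared between train and test: {overlap}")

Every fold should report an empty set, confirming that the splitter never breaks a group across the train/test boundary.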

This example demonstrates how to use StratifiedGroupKFold for grouped stratified cross-validation: samples from the same group always stay together, and the class distribution is preserved as well as the groups allow. Note, however, that the splitter can only assign whole groups to folds. Because this dataset has just 2 groups and n_splits=5, only 2 of the 5 folds receive any test samples, which is why folds 2-4 are empty in the output above. In practice, the number of distinct groups should be at least n_splits; a sketch of such a setup follows.
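
Here is a minimal sketch of that better-behaved setup. The choice of 10 equally sized groups is an assumption made purely for illustration (it is not part of the original example); with 10 groups and 5 splits, each fold's test set is non-empty and its class counts stay close to the overall 80/20 split.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedGroupKFold

# same data as before, but with 10 groups of 100 samples each
# (an illustrative grouping, chosen so the number of groups >= n_splits)
X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.8, 0.2],
                           n_informative=5, n_redundant=0,
                           n_clusters_per_class=1, random_state=42)
groups = np.repeat(np.arange(10), 100)

cv = StratifiedGroupKFold(n_splits=5)

for i, (train_idx, test_idx) in enumerate(cv.split(X, y, groups)):
    print(f"Fold {i}: train={len(train_idx)}, test={len(test_idx)}, "
          f"test class counts={np.bincount(y[test_idx], minlength=2)}")

The splitter also accepts shuffle=True with a random_state to randomize how groups are assigned to folds, and it can be passed as cv= to helpers such as cross_val_score together with groups=groups.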



See Also