Scikit-Learn RepeatedStratifiedKFold Data Splitting

RepeatedStratifiedKFold is a cross-validation technique that improves the reliability of model performance estimates by repeating stratified k-fold cross-validation several times, with a different randomization of the data in each repetition. Within each repetition, every fold keeps approximately the same proportion of class labels as the full dataset, which is particularly useful for imbalanced datasets.

Key parameters include n_splits (number of folds, default 5), n_repeats (number of times to repeat the process, default 10), and random_state (seed for the shuffling, for reproducibility).
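As a small sketch of how these parameters interact, the snippet below checks the total number of train/test splits and the effect of fixing random_state. The toy arrays and the variable names (total_splits, first, second) are invented for illustration only.

```python
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold

# toy data, invented for illustration: 20 samples, two balanced classes
X = np.zeros((20, 1))
y = np.array([0] * 10 + [1] * 10)

cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=3, random_state=1)

# total number of train/test splits is n_splits * n_repeats
total_splits = cv.get_n_splits(X, y)
print(total_splits)

# with an integer random_state, repeated calls to split() give the same folds
first = [test.tolist() for _, test in cv.split(X, y)]
second = [test.tolist() for _, test in cv.split(X, y)]
print(first == second)
```

With 5 folds and 3 repeats this should report 15 splits, and the two passes over split() should match because the seed is fixed.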

This technique is appropriate for classification problems, especially those with imbalanced class distributions.
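To illustrate the stratification on an imbalanced problem, here is a minimal sketch that counts minority-class samples in every test fold. The 90/10 label vector is made up for this illustration, and the feature matrix is a placeholder because only the labels drive the stratification.

```python
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold

# toy imbalanced labels, invented for illustration: 90 negatives, 10 positives
y = np.array([0] * 90 + [1] * 10)
X = np.zeros((len(y), 1))  # placeholder features; the splitter stratifies on y

cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=2, random_state=1)

# count minority-class samples in each of the 5 * 2 = 10 test folds
minority_counts = [int(y[test].sum()) for _, test in cv.split(X, y)]
print(minority_counts)
```

Because 10 positives divide evenly across 5 folds, each test fold should contain exactly 2 of them, matching the 10% minority rate of the full dataset. When the class counts do not divide evenly, the per-fold counts differ by at most one.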

from sklearn.datasets import make_classification
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

# generate binary classification dataset
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)

# define model
model = LogisticRegression()

# define evaluation method
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)

# evaluate model
scores = cross_val_score(model, X, y, scoring='accuracy', cv=cv, n_jobs=-1)

# summarize performance
print('Accuracy: %.3f (%.3f)' % (scores.mean(), scores.std()))

Running the example gives an output like:

Accuracy: 0.950 (0.076)

The steps are as follows:

  1. First, a synthetic binary classification dataset is generated using the make_classification() function. This creates a dataset with a specified number of samples (n_samples), features (n_features), and classes (n_classes), using a fixed random seed (random_state) for reproducibility.

  2. Next, a LogisticRegression model is instantiated with default hyperparameters.

  3. The RepeatedStratifiedKFold cross-validation method is set up with 10 splits and 3 repeats, giving 30 train/test evaluations in total. Each fold keeps approximately the same proportion of class labels as the full dataset, and the data is reshuffled before each repeat so that the folds differ between repeats.

  4. The performance of the model is evaluated using cross_val_score(), which fits and scores the model on every split and returns one accuracy score per split. The mean and standard deviation of these scores are then printed to summarize the model’s performance; the value in parentheses is the standard deviation.
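The steps above can also be written as an explicit loop, which makes it clearer what cross_val_score() is doing for each split. This is a rough sketch rather than the library's actual implementation; the names fold_model and per_repeat_means are introduced here for illustration.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RepeatedStratifiedKFold

# same data, model and splitter as in the example above
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)
model = LogisticRegression()
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)

scores = []
for train_idx, test_idx in cv.split(X, y):
    fold_model = clone(model)  # fresh, unfitted copy for every split
    fold_model.fit(X[train_idx], y[train_idx])
    # score() returns mean accuracy for classifiers
    scores.append(fold_model.score(X[test_idx], y[test_idx]))

scores = np.array(scores)
print(len(scores))  # one score per split: 10 folds * 3 repeats

# splits are generated repeat by repeat, so each row holds one repeat's 10 folds
per_repeat_means = scores.reshape(3, 10).mean(axis=1)
print(per_repeat_means)
```

Comparing the per-repeat means gives a rough sense of how much the estimate depends on the particular shuffle, which is the variability that repeating the procedure is meant to average out.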

This example demonstrates how to use RepeatedStratifiedKFold to obtain a robust estimate of model accuracy by averaging results over multiple stratified splits and repeats.


