
Scikit-Learn "RepeatedStratifiedKFold" versus "StratifiedKFold"

When validating machine learning models, cross-validation techniques provide more reliable performance estimates than a single train/test split. StratifiedKFold and RepeatedStratifiedKFold are two such methods in scikit-learn.

StratifiedKFold ensures that each fold of the cross-validation process maintains the same class distribution as the original dataset. Key hyperparameters include n_splits (number of folds) and shuffle (whether to shuffle the data before splitting).
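
To make the stratification concrete, here is a minimal sketch (on a small imbalanced dataset generated here just for illustration) that prints the class counts in each validation fold; every fold mirrors the roughly 80/20 split of the full dataset:

from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold
import numpy as np

# Small imbalanced dataset: roughly 80% class 0, 20% class 1
X, y = make_classification(n_samples=100, n_classes=2, weights=[0.8, 0.2], random_state=42)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for i, (train_index, val_index) in enumerate(skf.split(X, y)):
    # Each validation fold preserves the overall class distribution
    print(f"Fold {i}: validation class counts = {np.bincount(y[val_index])}")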

RepeatedStratifiedKFold extends StratifiedKFold by repeating the stratified k-fold process multiple times with different data shuffles, enhancing the robustness of the validation. Its key hyperparameters include n_splits (number of folds), n_repeats (number of repetitions), and random_state (seed for random number generation).
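
As a quick illustration of the mechanics, 5 folds repeated 10 times yields 5 x 10 = 50 distinct train/validation splits, with each repetition shuffling the data differently:

from sklearn.model_selection import RepeatedStratifiedKFold

rskf = RepeatedStratifiedKFold(n_splits=5, n_repeats=10, random_state=42)
# n_splits * n_repeats = 50 train/validation splits in total
print(rskf.get_n_splits())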

The main difference is that RepeatedStratifiedKFold repeats the k-fold process several times, so the averaged score is less sensitive to any single partition of the data and yields a better estimate of the variance of model performance. The trade-off is computational cost: it fits roughly n_repeats times as many models as StratifiedKFold.
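
The sketch below (using cross_val_score as a shortcut, rather than the manual loop in the full example later in this post) illustrates both points: the repeated strategy returns 50 scores instead of 5, so its standard deviation is a more trustworthy spread estimate, but it also fits 50 models instead of 5:

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, StratifiedKFold, RepeatedStratifiedKFold

X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.8, 0.2], random_state=42)
model = LogisticRegression(random_state=42)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
rskf = RepeatedStratifiedKFold(n_splits=5, n_repeats=10, random_state=42)

# 5 model fits vs. 50 model fits
scores_skf = cross_val_score(model, X, y, cv=skf, scoring="accuracy")
scores_rskf = cross_val_score(model, X, y, cv=rskf, scoring="accuracy")

print(f"StratifiedKFold: {scores_skf.mean():.3f} +/- {scores_skf.std():.3f} ({len(scores_skf)} scores)")
print(f"RepeatedStratifiedKFold: {scores_rskf.mean():.3f} +/- {scores_rskf.std():.3f} ({len(scores_rskf)} scores)")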

StratifiedKFold is ideal for quick cross-validation when computational resources are limited. RepeatedStratifiedKFold is preferred when a thorough validation is needed, despite its higher computational cost.

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split, StratifiedKFold, RepeatedStratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score
import numpy as np

# Generate synthetic binary classification dataset
X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.8, 0.2], random_state=42)

# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Evaluate with StratifiedKFold
skf = StratifiedKFold(n_splits=5)
accuracy_scores_skf = []
f1_scores_skf = []

for train_index, val_index in skf.split(X_train, y_train):
    X_train_fold, X_val_fold = X_train[train_index], X_train[val_index]
    y_train_fold, y_val_fold = y_train[train_index], y_train[val_index]

    model = LogisticRegression(random_state=42)
    model.fit(X_train_fold, y_train_fold)
    y_pred = model.predict(X_val_fold)

    accuracy_scores_skf.append(accuracy_score(y_val_fold, y_pred))
    f1_scores_skf.append(f1_score(y_val_fold, y_pred))

print(f"StratifiedKFold mean accuracy: {np.mean(accuracy_scores_skf):.3f}")
print(f"StratifiedKFold mean F1 score: {np.mean(f1_scores_skf):.3f}")

# Evaluate with RepeatedStratifiedKFold
rskf = RepeatedStratifiedKFold(n_splits=5, n_repeats=10, random_state=42)
accuracy_scores_rskf = []
f1_scores_rskf = []

for train_index, val_index in rskf.split(X_train, y_train):
    X_train_fold, X_val_fold = X_train[train_index], X_train[val_index]
    y_train_fold, y_val_fold = y_train[train_index], y_train[val_index]

    model = LogisticRegression(random_state=42)
    model.fit(X_train_fold, y_train_fold)
    y_pred = model.predict(X_val_fold)

    accuracy_scores_rskf.append(accuracy_score(y_val_fold, y_pred))
    f1_scores_rskf.append(f1_score(y_val_fold, y_pred))

print(f"RepeatedStratifiedKFold mean accuracy: {np.mean(accuracy_scores_rskf):.3f}")
print(f"RepeatedStratifiedKFold mean F1 score: {np.mean(f1_scores_rskf):.3f}")

Running the example gives an output like:

StratifiedKFold mean accuracy: 0.922
StratifiedKFold mean F1 score: 0.793
RepeatedStratifiedKFold mean accuracy: 0.916
RepeatedStratifiedKFold mean F1 score: 0.775
The procedure is as follows:

  1. Generate a synthetic binary classification dataset using make_classification.
  2. Split the data into training and test sets using train_test_split.
  3. Instantiate StratifiedKFold with a specified number of splits, fit the classifier on each training fold, and evaluate it on the corresponding validation fold.
  4. Instantiate RepeatedStratifiedKFold with a specified number of splits and repeats, and evaluate the classifier the same way.
  5. Compare the cross-validation performance (mean accuracy and F1 score) of both strategies and print the results; a short continuation after this list also compares the spread of the fold scores.
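
Since the point of repeating the folds is a steadier estimate, a natural follow-up (continuing the script above and reusing its accuracy_scores_skf and accuracy_scores_rskf lists) is to compare the spread of the fold scores as well as their means:

# Continuing the example above: compare the spread of the fold scores
print(f"StratifiedKFold accuracy std: {np.std(accuracy_scores_skf):.3f}")
print(f"RepeatedStratifiedKFold accuracy std: {np.std(accuracy_scores_rskf):.3f}")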


See Also