
Scikit-Learn "StratifiedKFold" versus "KFold"

In scikit-learn, KFold and StratifiedKFold are two commonly used splitters for cross-validation, which is essential for evaluating model performance. Both aim to provide a robust estimate of how well a model generalizes, but they differ in how they assign samples to folds.

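KFold is the simplest form of cross-validation. It splits the dataset into k consecutive folds of approximately equal size, and each fold is used once as the test set while the remaining k - 1 folds form the training set. Its key parameters are n_splits (the number of folds), shuffle (whether to shuffle the samples before splitting, False by default) and random_state (which seeds the shuffle when shuffle=True). KFold ignores the class labels entirely, so on an imbalanced dataset some folds can end up with far fewer minority-class samples than others.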
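To make "approximately equal size" concrete, the sketch below splits ten placeholder samples into three folds with KFold's default shuffle=False. The array X is an illustrative assumption. Because 10 is not divisible by 3, the first fold receives one extra sample, and the test folds are consecutive blocks of indices.

```python
import numpy as np
from sklearn.model_selection import KFold

# Ten placeholder samples that cannot be split evenly into three folds
X = np.arange(10).reshape(-1, 1)

kf = KFold(n_splits=3)  # shuffle=False by default, so folds are consecutive blocks
test_sizes = []
for fold, (train_idx, test_idx) in enumerate(kf.split(X)):
    test_sizes.append(len(test_idx))
    print(f"Fold {fold}: test indices {test_idx.tolist()}, train size {len(train_idx)}")

# The first n_samples % n_splits folds get one extra sample
print(test_sizes)  # [4, 3, 3]
```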

StratifiedKFold, on the other hand, assigns samples so that each fold preserves the percentage of samples of each class in the original dataset, as closely as the class counts allow. It takes the same n_splits, shuffle and random_state parameters as KFold, but its split method requires the labels y. This makes it particularly useful for imbalanced datasets, where a minority class could otherwise be under-represented, or even missing, in some folds.
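As a small illustration, the following sketch applies StratifiedKFold to a hypothetical three-class label array with a 60/30/10 split. The labels and the placeholder feature matrix are assumptions made for the example. Every test fold ends up with the same class proportions as the full array.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical three-class labels with a 60/30/10 split
y = np.array([0] * 60 + [1] * 30 + [2] * 10)
X = np.zeros((len(y), 1))  # placeholder features; the stratification uses only y

skf = StratifiedKFold(n_splits=5)
fold_counts = []
for fold, (train_idx, test_idx) in enumerate(skf.split(X, y)):
    counts = np.bincount(y[test_idx], minlength=3)
    fold_counts.append(counts.tolist())
    # Show the class counts and class proportions of each test fold
    proportions = (counts / counts.sum()).round(2).tolist()
    print(f"Fold {fold}: counts {counts.tolist()}, proportions {proportions}")
```

Each fold prints counts [12, 6, 2], matching the 0.6 / 0.3 / 0.1 proportions of the whole array.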

The main difference between KFold and StratifiedKFold lies in how they handle class distributions. KFold assigns samples to folds without looking at the labels (in their original order by default, or at random when shuffle=True), so the class balance can vary from fold to fold. StratifiedKFold keeps the class balance of every fold close to that of the full dataset, which gives more reliable, and typically less variable, performance estimates on imbalanced datasets.
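The effect of ignoring labels is easiest to see in an exaggerated case. The following sketch assumes labels that are sorted by class, as can happen when data is loaded one class at a time, and counts the minority-class samples in each test fold. With KFold's default shuffle=False, all ten minority samples land in the last test fold, so the model trained for that fold never sees class 1. StratifiedKFold spreads them evenly.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Hypothetical labels sorted by class: 90 samples of class 0, then 10 of class 1
y = np.array([0] * 90 + [1] * 10)
X = np.zeros((len(y), 1))  # placeholder features


def minority_per_test_fold(cv):
    """Count class-1 samples in each test fold produced by the splitter."""
    # KFold accepts y in split() but ignores it; StratifiedKFold uses it
    return [int(y[test_idx].sum()) for _, test_idx in cv.split(X, y)]


kf_counts = minority_per_test_fold(KFold(n_splits=5))
skf_counts = minority_per_test_fold(StratifiedKFold(n_splits=5))
print("KFold:", kf_counts)             # [0, 0, 0, 0, 10]
print("StratifiedKFold:", skf_counts)  # [2, 2, 2, 2, 2]
```

Setting shuffle=True on KFold avoids this worst case, but the minority count per fold then varies randomly rather than being held fixed.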

KFold is suitable for balanced datasets, for regression targets, or when the class distribution is not a concern. StratifiedKFold is preferred for classification on imbalanced datasets, because representative folds lead to more reliable model evaluation.
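scikit-learn already applies a version of this rule when an integer is passed as cv to functions such as cross_val_score: for a classifier with a binary or multiclass target it uses StratifiedKFold, and otherwise KFold, both without shuffling. The sketch below uses sklearn.model_selection.check_cv to show which splitter an integer cv resolves to; the two target arrays are illustrative assumptions.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold, check_cv

y_class = np.array([0] * 90 + [1] * 10)  # hypothetical binary classification target
y_reg = np.linspace(0.0, 1.0, 100)       # hypothetical continuous regression target

# check_cv turns an integer cv into a splitter object
cv_classifier = check_cv(5, y_class, classifier=True)
cv_regressor = check_cv(5, y_reg, classifier=False)

print(type(cv_classifier).__name__, cv_classifier.shuffle)  # StratifiedKFold False
print(type(cv_regressor).__name__, cv_regressor.shuffle)    # KFold False
```

If shuffled folds are wanted, pass an explicit splitter such as StratifiedKFold(n_splits=5, shuffle=True, random_state=42) instead of an integer, as the example below does.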

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import KFold, StratifiedKFold, cross_val_score
from sklearn.linear_model import LogisticRegression

# Generate a synthetic imbalanced binary classification dataset (about 90% class 0, 10% class 1)
X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.9, 0.1], random_state=42)

# Initialize logistic regression model
model = LogisticRegression()

# KFold cross-validation: folds ignore the class labels
kf = KFold(n_splits=5, shuffle=True, random_state=42)
kf_scores = cross_val_score(model, X, y, cv=kf, scoring='accuracy')
print(f"KFold accuracy scores: {kf_scores}")
print(f"Mean KFold accuracy: {kf_scores.mean():.3f}")

# StratifiedKFold cross-validation: each fold keeps the overall class balance
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
skf_scores = cross_val_score(model, X, y, cv=skf, scoring='accuracy')
print(f"StratifiedKFold accuracy scores: {skf_scores}")
print(f"Mean StratifiedKFold accuracy: {skf_scores.mean():.3f}")
```

Running the example gives an output like:

```
KFold accuracy scores: [0.895 0.95  0.935 0.925 0.945]
Mean KFold accuracy: 0.930
StratifiedKFold accuracy scores: [0.905 0.925 0.93  0.935 0.925]
Mean StratifiedKFold accuracy: 0.924
```

The steps are as follows:

  1. Generate a synthetic imbalanced binary classification dataset using make_classification.
  2. Initialize a LogisticRegression model.
  3. Use KFold with 5 splits, shuffle the data, and evaluate model performance using cross_val_score.
  4. Use StratifiedKFold with 5 splits, shuffle the data, and evaluate model performance using cross_val_score.
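  5. Compare the accuracy scores from both cross-validation methods. In the sample output, the mean accuracies are close (0.930 versus 0.924), but the StratifiedKFold scores vary less from fold to fold (0.905 to 0.935) than the KFold scores (0.895 to 0.95), which is consistent with every stratified fold having the same class balance.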
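One way to check where the difference in fold-to-fold spread could come from is to count the minority-class samples in each test fold for the same dataset and splitters used in the example. The sketch below does that; the helper minority_counts is a hypothetical name introduced here. StratifiedKFold's counts differ by at most one between folds, while KFold's depend on how the shuffle happens to fall. The exact KFold counts depend on random_state, so they are not listed here.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import KFold, StratifiedKFold

# Same synthetic dataset and splitters as in the example above
X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.9, 0.1], random_state=42)
kf = KFold(n_splits=5, shuffle=True, random_state=42)
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)


def minority_counts(cv):
    """Hypothetical helper: number of class-1 samples in each test fold."""
    return [int(np.sum(y[test_idx] == 1)) for _, test_idx in cv.split(X, y)]


kf_counts = minority_counts(kf)
skf_counts = minority_counts(skf)
print(f"Total class-1 samples: {int(np.sum(y == 1))}")
print(f"KFold class-1 per test fold: {kf_counts}")
print(f"StratifiedKFold class-1 per test fold: {skf_counts}")
```

Because each sample appears in exactly one test fold, both lists sum to the total number of class-1 samples; only their distribution across folds differs.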

