
Scikit-Learn KFold Data Splitting

KFold is a cross-validation splitter in scikit-learn that divides a dataset into K folds, which are consecutive blocks of samples unless shuffling is enabled. Each fold is used exactly once as the validation set while the remaining K-1 folds form the training set. Averaging the results over all K rounds gives a more robust estimate of model performance than a single train-test split.
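
To make the "each fold is used once" idea concrete, here is a small sketch on a toy array of 10 samples (the array and the test_folds list are illustrative choices, not part of the original example). With the default shuffle=False, each test fold is a consecutive block of indices, and together the test folds cover every sample exactly once.

```python
import numpy as np
from sklearn.model_selection import KFold

# Ten toy samples; KFold only looks at the number of rows, not their values
X = np.arange(10).reshape(-1, 1)

# shuffle defaults to False, so each test fold is a consecutive block of indices
kf = KFold(n_splits=5)

test_folds = []
for i, (train_index, test_index) in enumerate(kf.split(X)):
    print(f"Fold {i}: train={train_index.tolist()} test={test_index.tolist()}")
    test_folds.append(test_index.tolist())

# The first line printed is:
# Fold 0: train=[2, 3, 4, 5, 6, 7, 8, 9] test=[0, 1]
```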

The main parameters of KFold are n_splits (the number of folds, default 5), shuffle (whether to shuffle the sample indices before splitting, default False), and random_state (the seed that controls the shuffling; it only has an effect when shuffle=True).
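
The sketch below illustrates two consequences of these parameters on a toy array of 7 samples (an arbitrary size chosen so it does not divide evenly). When n_samples is not a multiple of n_splits, the first n_samples % n_splits folds get one extra sample. With shuffle=True, a fixed integer random_state reproduces the same folds. The helper shuffled_folds is hypothetical, written only for this illustration.

```python
import numpy as np
from sklearn.model_selection import KFold

# 7 samples do not divide evenly into 3 folds
X = np.arange(7).reshape(-1, 1)

# No shuffling: the first 7 % 3 = 1 fold receives one extra sample
sizes = [len(test) for _, test in KFold(n_splits=3).split(X)]
print(sizes)  # [3, 2, 2]

# Hypothetical helper: test folds from a shuffled splitter with the given seed
def shuffled_folds(seed):
    kf = KFold(n_splits=3, shuffle=True, random_state=seed)
    return [test.tolist() for _, test in kf.split(X)]

# Two separate splitters with the same integer seed yield identical folds
print(shuffled_folds(42) == shuffled_folds(42))  # True
```

In recent scikit-learn releases, passing a random_state while leaving shuffle=False raises a ValueError rather than being silently ignored, so it is worth setting the two together.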

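KFold is suitable for both regression and classification problems because it splits on sample indices alone and ignores the target values. For classification with imbalanced classes, StratifiedKFold, which preserves the class proportions in each fold, is often the better choice.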
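
As a sketch of the regression case, the same splitter can be passed to cross_val_score together with a regressor. The synthetic dataset from make_regression, the LinearRegression model, and the R^2 scoring are illustrative assumptions, not part of the original example.

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_score

# Synthetic regression data: a linear target with a little Gaussian noise
X, y = make_regression(n_samples=100, n_features=5, noise=0.1, random_state=1)

# Same KFold configuration as the classification example below
kf = KFold(n_splits=5, shuffle=True, random_state=1)

# One R^2 score per fold, each computed on that fold's held-out samples
scores = cross_val_score(LinearRegression(), X, y, cv=kf, scoring="r2")
print("Per-fold R^2:", scores.round(3))
print("Mean R^2: %.3f" % scores.mean())
```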

from sklearn.datasets import make_classification
from sklearn.model_selection import KFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# generate binary classification dataset
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)

# prepare KFold cross-validation
kf = KFold(n_splits=5, shuffle=True, random_state=1)

# create model
model = LogisticRegression()

accuracies = []

# perform cross-validation
for train_index, test_index in kf.split(X):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

    # fit model
    model.fit(X_train, y_train)

    # evaluate model
    yhat = model.predict(X_test)
    acc = accuracy_score(y_test, yhat)
    accuracies.append(acc)
    print('Fold Accuracy: %.3f' % acc)

# summarize performance
print('Mean Accuracy: %.3f' % (sum(accuracies) / len(accuracies)))

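Running the example prints the accuracy for each fold followed by the mean. The output will look something like this (exact values may vary):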

Fold Accuracy: 0.950
Fold Accuracy: 0.950
Fold Accuracy: 0.950
Fold Accuracy: 0.950
Fold Accuracy: 0.950
Mean Accuracy: 0.950

The steps are as follows:

  1. First, a synthetic binary classification dataset is generated using the make_classification() function, with a specified number of samples (n_samples), features (n_features), and classes (n_classes), and a fixed random seed (random_state) for reproducibility. No separate train-test split is made up front; KFold produces the splits in step 4.

  2. The KFold class is instantiated with n_splits=5, shuffle=True, and a fixed random_state. This configuration shuffles the sample indices once and then divides them into 5 folds; the fixed seed makes the shuffle, and therefore the folds, reproducible.

  3. A LogisticRegression model is instantiated with default hyperparameters. It is not fit at this point; fitting happens inside the cross-validation loop.

  4. The cross-validation loop iterates over the (train, test) index pairs yielded by kf.split(X). For each fold, the arrays are indexed into training and test sets, the model is fit on the training portion, and its accuracy on the held-out portion is computed with accuracy_score() and printed. Each call to fit() retrains the model from scratch, so nothing carries over between folds.

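  5. The mean accuracy across all folds is calculated and printed. Because every sample is used for evaluation exactly once, this mean is a less noisy estimate of the model’s performance than the score from any single split.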
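
The loop in steps 4 and 5 can also be written more compactly with cross_val_score, which clones the estimator, fits it on each training fold, and scores it on the matching held-out fold. This is a sketch using the same data and splitter as above; because the splitter uses a fixed integer seed and the default solver is deterministic, the per-fold scores should match the manual loop. Reporting the standard deviation alongside the mean is an addition to the original example.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score

# Same dataset and splitter configuration as the manual example
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)
kf = KFold(n_splits=5, shuffle=True, random_state=1)

# One accuracy score per fold; the model is cloned and refit for each fold
scores = cross_val_score(LogisticRegression(), X, y, cv=kf, scoring="accuracy")
print("Fold accuracies:", np.round(scores, 3))

# The spread across folds shows how sensitive the estimate is to the split
print("Mean Accuracy: %.3f (std %.3f)" % (scores.mean(), scores.std()))
```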

This example demonstrates how to use KFold for cross-validation in scikit-learn so that the estimate of model performance does not hinge on a single, possibly unrepresentative, train-test split.



See Also