
Scikit-Learn GroupKFold Data Splitting

GroupKFold is a cross-validation iterator that keeps each group of samples together, so that no group appears in both the training set and the test set of the same split.

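This matters when data points are not independent, such as repeated measurements of the same subject or observations collected in clusters. If samples from one group land on both sides of a split, the model can partly memorize that group and the evaluation score can be overly optimistic.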
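A quick way to check this guarantee is to compare the group labels on each side of every split. The sketch below uses a small hypothetical dataset of 12 samples in 4 groups; the data and variable names are illustrative only.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Hypothetical toy data: 12 samples, 4 groups of 3 consecutive samples each
X = np.arange(24).reshape(12, 2)
y = np.array([0, 1] * 6)
groups = np.repeat([0, 1, 2, 3], 3)

cv = GroupKFold(n_splits=4)
overlaps = []
for fold, (train_ix, test_ix) in enumerate(cv.split(X, y, groups=groups)):
    train_groups = set(groups[train_ix].tolist())
    test_groups = set(groups[test_ix].tolist())
    # Record any group that appears on both sides (GroupKFold should produce none)
    overlaps.append(train_groups & test_groups)
    print(f"fold {fold}: test groups={sorted(test_groups)}, "
          f"train groups={sorted(train_groups)}")
```

With 4 groups and n_splits=4, each test fold holds exactly one whole group.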

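The main parameter of GroupKFold is n_splits, the number of folds (default 5). It must be at least 2 and cannot exceed the number of distinct groups. Recent scikit-learn versions also accept shuffle and random_state to randomize how groups are assigned to folds.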
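The sketch below, assuming a recent scikit-learn release, reads the configured fold count with get_n_splits() and shows that asking for more folds than there are distinct groups raises a ValueError once the splits are generated. The three-group array is hypothetical.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

X = np.zeros((6, 1))                   # features do not affect the split
groups = np.array([0, 0, 1, 1, 2, 2])  # hypothetical: only 3 distinct groups

cv = GroupKFold(n_splits=5)
print(cv.get_n_splits())  # the configured number of folds

# split() is a generator, so the error appears when the folds are materialized
error = None
try:
    list(cv.split(X, groups=groups))
except ValueError as exc:
    error = exc
    print("ValueError:", exc)
```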

GroupKFold is a data-splitting strategy rather than a learning algorithm, so it can be used to evaluate both classification and regression models whenever the data points are grouped.
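For a regression problem the splitter is used in the same way. A minimal sketch, assuming synthetic data from make_regression() and the same artificial 10-group layout as the main example, passes GroupKFold to cross_val_score(), which forwards the groups argument to the splitter:

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import GroupKFold, cross_val_score

# Synthetic regression data; the 10 groups of 10 consecutive rows are artificial
X, y = make_regression(n_samples=100, n_features=5, noise=10.0, random_state=1)
groups = np.repeat(np.arange(10), 10)

# cross_val_score hands `groups` to GroupKFold.split when building the folds
scores = cross_val_score(LinearRegression(), X, y, groups=groups,
                         cv=GroupKFold(n_splits=5), scoring="r2")
print("R^2 per fold:", np.round(scores, 3))
print("Mean R^2: %.3f" % scores.mean())
```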

from sklearn.datasets import make_classification
from sklearn.model_selection import GroupKFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import numpy as np

# generate binary classification dataset
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)
groups = np.array([i // 10 for i in range(100)])  # create 10 groups

# configure the cross-validation procedure
cv = GroupKFold(n_splits=5)

# prepare the cross-validation split
for train_ix, test_ix in cv.split(X, y, groups=groups):
    # select rows
    X_train, X_test = X[train_ix], X[test_ix]
    y_train, y_test = y[train_ix], y[test_ix]
    # create model
    model = LogisticRegression()
    # fit model
    model.fit(X_train, y_train)
    # evaluate model
    yhat = model.predict(X_test)
    acc = accuracy_score(y_test, yhat)
    print('Accuracy: %.3f' % acc)

Running the example gives an output like:

Accuracy: 0.900
Accuracy: 0.950
Accuracy: 0.950
Accuracy: 0.950
Accuracy: 1.000
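The per-fold scores are usually summarized as a mean and standard deviation. A minimal sketch, assuming the same data, groups, and model as the example above, uses cross_val_score(), which fits a fresh clone of the estimator on each fold, to collect the fold scores in one call:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GroupKFold, cross_val_score

X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)
groups = np.array([i // 10 for i in range(100)])  # 10 groups of 10 samples

# Same splitter and model as the explicit loop, scored with accuracy
scores = cross_val_score(LogisticRegression(), X, y, groups=groups,
                         cv=GroupKFold(n_splits=5), scoring="accuracy")
print("Accuracy per fold:", np.round(scores, 3))
print("Mean accuracy: %.3f (std %.3f)" % (scores.mean(), scores.std()))
```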

The steps are as follows:

  1. First, a synthetic binary classification dataset of 100 samples with 5 features is generated using the make_classification() function, with a fixed random seed (random_state) for reproducibility. A separate groups array then assigns each block of 10 consecutive samples to one of 10 groups (labels 0 to 9) to simulate repeated measures or grouped observations.

  2. The GroupKFold cross-validation procedure is configured with 5 splits. With 10 equally sized groups, each test fold contains 2 whole groups (20 samples), and no group appears in both the training and test sets of a split.

  3. The split() method of GroupKFold receives the groups array and yields the row indices of the training and test sets for each fold.

  4. For each split, a LogisticRegression model is instantiated and fit on the training data using the fit() method.

  5. The performance of the model is evaluated by comparing the predictions (yhat) to the actual values (y_test) using the accuracy score metric.

  6. The accuracy for each fold is printed. Because GroupKFold never divides a group between the training and test sets, each score reflects performance on groups the model did not see during training.
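The fold composition from steps 2 and 3 can be inspected directly. This sketch reuses the 10-group layout from the example; the features are placeholders because GroupKFold only looks at the group labels when building folds.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

groups = np.array([i // 10 for i in range(100)])  # 10 groups of 10 samples
X = np.zeros((100, 1))  # placeholder features; the split depends only on groups

cv = GroupKFold(n_splits=5)
test_sizes = []
test_group_counts = []
for fold, (train_ix, test_ix) in enumerate(cv.split(X, groups=groups)):
    test_groups = np.unique(groups[test_ix])
    test_sizes.append(len(test_ix))
    test_group_counts.append(len(test_groups))
    print(f"fold {fold}: train={len(train_ix)}, test={len(test_ix)}, "
          f"test groups={test_groups.tolist()}")
```

With equally sized groups, GroupKFold balances the folds so that each test set holds 2 groups and 20 samples.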

This example shows how to use GroupKFold to cross-validate a model on grouped data, so that the evaluation measures how well the model generalizes to new groups rather than to new samples from groups it has already seen.
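To see what GroupKFold prevents, the sketch below counts, summed over all folds, how many groups appear in both the training and test sets, first for a shuffled KFold (which ignores groups) and then for GroupKFold. Shuffling is used because an unshuffled KFold with 5 folds happens to produce contiguous blocks of 20 rows that line up exactly with these contiguous groups; the helper name shared_groups is hypothetical.

```python
import numpy as np
from sklearn.model_selection import GroupKFold, KFold

groups = np.array([i // 10 for i in range(100)])  # 10 groups of 10 samples
X = np.zeros((100, 1))

def shared_groups(splits):
    """Hypothetical helper: sum over folds of groups present in both train and test."""
    total = 0
    for train_ix, test_ix in splits:
        total += len(set(groups[train_ix].tolist()) & set(groups[test_ix].tolist()))
    return total

# Shuffled KFold assigns rows at random, so most groups end up on both sides
kfold_shared = shared_groups(KFold(n_splits=5, shuffle=True, random_state=1).split(X))
group_shared = shared_groups(GroupKFold(n_splits=5).split(X, groups=groups))
print("KFold shared groups:", kfold_shared)
print("GroupKFold shared groups:", group_shared)
```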



See Also