
Scikit-Learn GroupShuffleSplit Data Splitting

GroupShuffleSplit splits data into training and test sets so that all samples sharing a group label are assigned to either the training set or the test set, never both.

This is useful when data is naturally grouped (for example, several samples recorded from the same patient, subject, or session) and the model must be evaluated on groups it has not seen during training.
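The guarantee can be checked directly on a small dataset with groups of unequal size: for every split, the group labels found in the training indices and those found in the test indices are disjoint. The data and group labels below are hypothetical, made up only for illustration.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Hypothetical toy data: 12 samples from 4 groups of unequal size
X = np.arange(24).reshape(12, 2)
groups = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3])

gss = GroupShuffleSplit(n_splits=3, test_size=0.25, random_state=0)
for i, (train_idx, test_idx) in enumerate(gss.split(X, groups=groups)):
    train_groups = set(groups[train_idx].tolist())
    test_groups = set(groups[test_idx].tolist())
    # A group label never appears on both sides of the same split
    assert train_groups.isdisjoint(test_groups)
    print(f"split {i}: train groups {sorted(train_groups)}, test groups {sorted(test_groups)}")
```

Note that y is optional here: the split depends only on the number of samples and the group labels.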

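The key parameters of GroupShuffleSplit are n_splits (the number of re-shuffling and splitting iterations), test_size and train_size (the size of the test and train splits), and random_state (the seed that makes the shuffling reproducible). When given as floats, test_size and train_size are proportions of the groups, not of the samples; when given as integers, they are absolute numbers of groups.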
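Because test_size counts groups, the fraction of samples that ends up in the test set can differ a lot from test_size when groups have unequal sizes. The sketch below uses a hypothetical grouping with one large group (50 samples) and nine small ones (5 samples each): every split puts exactly 2 of the 10 groups in the test set, but the share of samples varies with which groups are drawn.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Hypothetical grouping: group 0 has 50 samples, groups 1-9 have 5 each (95 samples)
groups = np.array([0] * 50 + [g for g in range(1, 10) for _ in range(5)])
X = np.zeros((len(groups), 1))  # feature values play no role in how the split is made

gss = GroupShuffleSplit(n_splits=5, test_size=0.2, random_state=42)
for train_idx, test_idx in gss.split(X, groups=groups):
    n_test_groups = len(np.unique(groups[test_idx]))
    frac_samples = len(test_idx) / len(groups)
    # 20% of 10 groups is always 2 groups; the sample fraction depends on which 2
    print(f"test groups: {n_test_groups}, fraction of samples in test: {frac_samples:.2f}")
```

If a target fraction of samples matters more than a target number of groups, the groups themselves may need to be defined with comparable sizes.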

GroupShuffleSplit is a cross-validation splitter rather than a learning algorithm, so it can be used with classification, regression, or any other supervised learning problem where the samples are grouped.
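As a sketch of the regression case, a GroupShuffleSplit object can be passed as the cv argument of cross_val_score, together with the group labels. The Ridge model, the make_regression settings, and the grouping into 10 blocks of 20 samples are assumptions chosen only for illustration.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GroupShuffleSplit, cross_val_score

# Synthetic regression data; group labels are assigned artificially (10 blocks of 20 samples)
X, y = make_regression(n_samples=200, n_features=5, noise=10.0, random_state=1)
groups = np.arange(200) // 20

gss = GroupShuffleSplit(n_splits=5, test_size=0.2, random_state=1)
# cross_val_score passes `groups` on to gss.split(), so every split keeps groups whole
scores = cross_val_score(Ridge(), X, y, groups=groups, cv=gss, scoring="r2")
print("R^2 per split:", np.round(scores, 3))
print("Mean R^2: %.3f" % scores.mean())
```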

from sklearn.datasets import make_classification
from sklearn.model_selection import GroupShuffleSplit
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import numpy as np

# generate a synthetic classification dataset and assign group labels
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)
groups = np.array([i // 10 for i in range(100)])  # 10 groups, each with 10 samples

# define the group shuffle split
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=1)
train_idx, test_idx = next(gss.split(X, y, groups=groups))

# split data into train and test sets
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

# create model
model = LogisticRegression()

# fit model
model.fit(X_train, y_train)

# evaluate model
yhat = model.predict(X_test)
acc = accuracy_score(y_test, yhat)
print('Accuracy: %.3f' % acc)

# make a prediction
row = [[-1.10325445, -0.49821356, -0.05962247, -0.89224592, -0.70158632]]
yhat = model.predict(row)
print('Predicted: %d' % yhat[0])

Running the example gives an output like:

Accuracy: 1.000
Predicted: 0

The steps are as follows:

  1. First, a synthetic classification dataset is generated using the make_classification() function. Groups are created such that each group contains 10 samples.
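  2. The GroupShuffleSplit class is instantiated to perform a single split (n_splits=1) that places 20% of the groups (2 of the 10) in the test set. Because every group has the same size, this is also 20% of the samples.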
  3. The data is split into training and testing sets based on the indices provided by GroupShuffleSplit.
  4. A LogisticRegression model is created and fit on the training data.
  5. The performance of the model is evaluated by comparing the predictions (yhat) to the actual values (y_test) using the accuracy score metric.
  6. A single prediction can be made by passing a new data sample to the predict() method.
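With a single split, the accuracy depends heavily on which 2 groups happen to land in the test set. A common extension, sketched below under the same synthetic data and grouping as the example, is to set n_splits greater than 1 and average the scores. Each split is drawn independently, so unlike GroupKFold the same group may appear in the test set of more than one iteration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GroupShuffleSplit

# Same synthetic data and grouping as the example above (10 groups of 10 samples)
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)
groups = np.arange(100) // 10

# n_splits=5 draws five independent random partitions of the groups
gss = GroupShuffleSplit(n_splits=5, test_size=0.2, random_state=1)
accs = []
for train_idx, test_idx in gss.split(X, y, groups=groups):
    model = LogisticRegression().fit(X[train_idx], y[train_idx])
    accs.append(accuracy_score(y[test_idx], model.predict(X[test_idx])))
    print("test groups:", np.unique(groups[test_idx]).tolist(), "accuracy: %.3f" % accs[-1])
print("Mean accuracy: %.3f (std %.3f)" % (np.mean(accs), np.std(accs)))
```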

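This example demonstrates how to use GroupShuffleSplit to keep each group of samples entirely on one side of the train/test split. This prevents closely related samples from the same group from appearing in both sets, which would otherwise let the model be scored on near-duplicates of its training data and inflate the evaluation result.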
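To see the leakage that GroupShuffleSplit avoids, the sketch below splits the indices of the same 10 groups of 10 samples twice: once with a plain train_test_split, which ignores group labels, and once with GroupShuffleSplit. It reports which group labels appear on both sides of each split.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit, train_test_split

groups = np.arange(100) // 10  # same grouping as the example: 10 groups of 10
X = np.zeros((100, 1))         # placeholder features; only the indices matter here
idx = np.arange(100)

# Plain random split: ignores groups, so members of one group can land on both sides
tr_plain, te_plain = train_test_split(idx, test_size=0.2, random_state=1)
shared_plain = set(groups[tr_plain].tolist()) & set(groups[te_plain].tolist())

# Group-aware split: each group is kept whole
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=1)
tr_grp, te_grp = next(gss.split(X, groups=groups))
shared_grp = set(groups[tr_grp].tolist()) & set(groups[te_grp].tolist())

print("groups on both sides (train_test_split):", sorted(shared_plain))
print("groups on both sides (GroupShuffleSplit):", sorted(shared_grp))
```

With 20 test samples drawn at random from 100, it is very unlikely that they form exactly two complete groups, so the plain split almost always shares several groups between train and test.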



See Also