Scikit-Learn PredefinedSplit Data Splitting

PredefinedSplit lets you supply your own train/test partition to scikit-learn's cross-validation tools. This is useful when evaluation must use specific data partitions, such as a fixed validation set.

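The key parameter is test_fold, an array with one entry per sample. A value of -1 excludes the sample from every test set, so it is always used for training. A value i >= 0 places the sample in test fold i, and each distinct non-negative value produces one train/test split. This approach applies to any machine learning problem that needs custom cross-validation splits, including classification and regression.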
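Here is a small sketch of the test_fold convention. It uses a hypothetical four-sample array with two test folds. These values are illustrative and do not come from the main example. split() yields one (train, test) index pair per distinct non-negative fold value, and the sample marked -1 appears only in training sets.

```python
import numpy as np
from sklearn.model_selection import PredefinedSplit

# Hypothetical toy data: four samples, two features each
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])

# Sample 0 -> test fold 0, samples 1 and 3 -> test fold 1,
# sample 2 (-1) is never tested and always used for training
test_fold = [0, 1, -1, 1]
ps = PredefinedSplit(test_fold)

print(ps.get_n_splits())  # 2: one split per distinct non-negative fold value
splits = list(ps.split(X))
for train_index, test_index in splits:
    print("TRAIN:", train_index, "TEST:", test_index)
# TRAIN: [1 2 3] TEST: [0]
# TRAIN: [0 2] TEST: [1 3]
```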

from sklearn.model_selection import PredefinedSplit, cross_val_score
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
import numpy as np

# generate binary classification dataset
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)

# create custom test fold splits: -1 excludes a sample from every test set (training only)
test_fold = np.zeros(100, dtype=int)
test_fold[:20] = -1  # first 20 samples are used only for training
test_fold[20:] = 0  # remaining 80 samples form test fold 0

# create PredefinedSplit
ps = PredefinedSplit(test_fold)

# create model
model = LogisticRegression()

# evaluate model
scores = cross_val_score(model, X, y, cv=ps)
print('Accuracy: %.3f' % np.mean(scores))

# fit on the training samples (the first 20) and predict the first test sample
model.fit(X[:20], y[:20])
row = X[20].reshape(1, -1)
yhat = model.predict(row)
print('Predicted: %d' % yhat[0])

Running the example gives an output like:

Accuracy: 0.950
Predicted: 0

  1. First, a synthetic binary classification dataset is generated with make_classification(). The call specifies the number of samples (n_samples), features (n_features), and classes (n_classes), and sets a fixed random seed (random_state) for reproducibility.
  2. A custom test_fold array is created. The value -1 marks the first 20 samples as training-only, and the value 0 assigns the remaining 80 samples to a single test fold.
  3. The PredefinedSplit class is instantiated with the test_fold array, which defines one train/test split.
  4. A LogisticRegression model is created and evaluated with cross_val_score() using the PredefinedSplit cross-validator. With a single test fold, the reported mean is the accuracy on the 80 test samples.
  5. Finally, the model is fit on the 20 training samples and used to predict the class of the first test sample (X[20]).
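The walkthrough can be checked with the sketch below. It rebuilds the same test_fold array, inspects the single split that PredefinedSplit produces, and compares the cross_val_score result with a manual fit and score on the same indices. It assumes accuracy as the metric. For a classifier, cross_val_score uses this metric by default through the estimator's score() method.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import PredefinedSplit, cross_val_score

# Same data and fold assignment as the example above
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)
test_fold = np.zeros(100, dtype=int)
test_fold[:20] = -1
ps = PredefinedSplit(test_fold)

# A single non-negative fold value yields exactly one split
train_idx, test_idx = next(ps.split(X))
print(ps.get_n_splits(), len(train_idx), len(test_idx))  # 1 20 80

# cross_val_score fits a fresh clone on the training indices and scores it on the
# test indices, so its single score should match a manual fit/score on that split
scores = cross_val_score(LogisticRegression(), X, y, cv=ps)
manual = LogisticRegression().fit(X[train_idx], y[train_idx]).score(X[test_idx], y[test_idx])
print(scores[0], manual)
```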

This example demonstrates how to set up and use PredefinedSplit for custom cross-validation in scikit-learn, enabling precise control over data partitioning in model evaluation.
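A common use of this control is hyperparameter tuning against a fixed validation set. The sketch below passes a PredefinedSplit to GridSearchCV. The 70/30 train/validation boundary and the grid of C values are illustrative assumptions and are not part of the original example.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, PredefinedSplit

X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)

# Assumed partition: first 70 rows train only (-1), last 30 rows are validation fold 0
test_fold = np.full(100, -1, dtype=int)
test_fold[70:] = 0
ps = PredefinedSplit(test_fold)

# Every candidate C is trained on rows 0-69 and scored (accuracy) on rows 70-99
param_grid = {"C": [0.01, 0.1, 1.0, 10.0]}
grid = GridSearchCV(LogisticRegression(), param_grid=param_grid, cv=ps)
grid.fit(X, y)

print(grid.best_params_)
print("Validation accuracy: %.3f" % grid.best_score_)
```

By default (refit=True), GridSearchCV refits the best model on all rows passed to fit(), including the validation rows. If the final model should be trained on the training portion only, set refit=False and fit the model yourself with the chosen parameters.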
