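`PredefinedSplit` lets you supply your own fixed train/test splits for cross-validation instead of having them generated automatically. This is useful when a specific data partition, such as a predefined validation set, must be used for evaluation.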
The key parameter is `test_fold`, an array that gives the test fold index for each sample. A sample whose entry is `-1` is excluded from every test set, so it is always used for training. This approach works for any machine learning problem that needs custom cross-validation splits, including classification and regression.
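To make the `test_fold` semantics concrete, the short sketch below uses a toy four-sample array (chosen purely for illustration) and prints the splits that `PredefinedSplit` generates. It yields one split per distinct non-negative fold index. In each split, the samples with that index form the test set and every other sample, including those marked `-1`, forms the training set.

```python
import numpy as np
from sklearn.model_selection import PredefinedSplit

# Toy assignment: sample 0 is in test fold 0, samples 1 and 3 are in
# test fold 1, and sample 2 (-1) never appears in any test set.
test_fold = np.array([0, 1, -1, 1])
ps = PredefinedSplit(test_fold)

# One split per distinct non-negative fold index
print(ps.get_n_splits())  # 2

# split() needs no data arguments; the folds come from test_fold alone
for train_idx, test_idx in ps.split():
    print("train:", train_idx, "test:", test_idx)
```

This prints `train: [1 2 3] test: [0]` followed by `train: [0 2] test: [1 3]`. Sample 2 appears in both training sets and in neither test set.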
```python
from sklearn.model_selection import PredefinedSplit, cross_val_score
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
import numpy as np
# generate binary classification dataset
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)
# create custom test fold splits: -1 means "never in a test set" (always training)
test_fold = np.zeros(100, dtype=int)
test_fold[:20] = -1  # first 20 samples are training only
test_fold[20:] = 0   # remaining samples form test fold 0
# create PredefinedSplit
ps = PredefinedSplit(test_fold)
# create model
model = LogisticRegression()
# evaluate model (a single split, so cross_val_score returns one score)
scores = cross_val_score(model, X, y, cv=ps)
print('Accuracy: %.3f' % np.mean(scores))
# fit on the training portion and predict the first test sample
model.fit(X[:20], y[:20])
row = X[20].reshape(1, -1)
yhat = model.predict(row)
print('Predicted: %d' % yhat[0])
```
Running the example gives an output like:
```
Accuracy: 0.950
Predicted: 0
```
- First, a synthetic binary classification dataset is generated with the `make_classification()` function, specifying the number of samples (`n_samples`), features (`n_features`), classes (`n_classes`), and a fixed random seed (`random_state`) for reproducibility.
- A custom `test_fold` array is created, where `-1` marks samples that are never placed in a test set (they are always used for training) and `0` assigns samples to test fold 0.
- The `PredefinedSplit` class is instantiated with the `test_fold` array, defining the custom train/test split.
- A `LogisticRegression` model is created and evaluated using `cross_val_score()` with the `PredefinedSplit` cross-validator. Because `test_fold` contains only one fold index (`0`), there is a single split, and `cross_val_score()` returns a single score.
- The model is then fit on the training portion of the data (the first 20 samples), and a prediction is made for the first test sample (`X[20]`).
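As a sanity check on the steps above, the sketch below reproduces the single split by hand. It fits on samples 0 to 19, then scores on samples 20 to 99 with `accuracy_score`. It assumes the default scoring for a classifier in `cross_val_score()`, which is the estimator's `score` method (accuracy for `LogisticRegression`). Because the default `lbfgs` solver is deterministic, the two numbers should agree.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import PredefinedSplit, cross_val_score

X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)

# Same layout as the example: first 20 samples train, the rest form test fold 0
test_fold = np.full(100, -1, dtype=int)
test_fold[20:] = 0
ps = PredefinedSplit(test_fold)

# Score reported by cross_val_score for its only split
cv_score = cross_val_score(LogisticRegression(), X, y, cv=ps)[0]

# Manual equivalent: fit on samples 0-19, score on samples 20-99
model = LogisticRegression().fit(X[:20], y[:20])
manual_score = accuracy_score(y[20:], model.predict(X[20:]))

print(cv_score, manual_score)
```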
This example demonstrates how to set up and use `PredefinedSplit` for custom cross-validation in scikit-learn, giving precise control over how data is partitioned during model evaluation.
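The example uses only one test fold. Assigning several non-negative fold indices produces one split, and therefore one score, per fold. The following sketch uses a hypothetical layout for illustration: the first 40 samples are always training (`-1`), and three consecutive blocks of 20 samples serve as test folds 0, 1 and 2. Note that, for each split, the training set consists of the `-1` samples plus the samples from the other two folds, not just the `-1` samples.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import PredefinedSplit, cross_val_score

X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)

# Hypothetical layout: samples 0-39 always train, then three test folds
test_fold = np.full(100, -1, dtype=int)
test_fold[40:60] = 0
test_fold[60:80] = 1
test_fold[80:100] = 2

ps = PredefinedSplit(test_fold)
scores = cross_val_score(LogisticRegression(), X, y, cv=ps)

# One accuracy value per predefined test fold
print(scores)
print('Mean accuracy: %.3f' % scores.mean())
```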