cross_val_predict() is a useful function in scikit-learn for generating and analyzing cross-validation predictions. It returns a prediction for every data point, each one made by a model that was trained on folds that did not contain that point, which helps you better understand model performance and spot potential overfitting.
The key arguments include the estimator (the model to be used), X and y (the features and target variable), and cv (the number of cross-validation splits).
This function is appropriate for both classification and regression problems.
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_predict
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
# generate binary classification dataset
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)
# create model
model = RandomForestClassifier()
# use cross_val_predict
y_pred = cross_val_predict(model, X, y, cv=5)
# evaluate predictions
acc = accuracy_score(y, y_pred)
print('Accuracy: %.3f' % acc)
Running the example gives an output like:
Accuracy: 0.950
The steps are as follows:
1. Generate a synthetic binary classification dataset using make_classification(). This creates a dataset with a specified number of samples (n_samples), features (n_features), and a fixed random seed (random_state) for reproducibility.
2. Instantiate a RandomForestClassifier model with default hyperparameters.
3. Use cross_val_predict() to generate cross-validated predictions. This function performs cross-validation on the model and returns the prediction for each data point.
4. Evaluate the performance of the predictions using the accuracy score metric. The accuracy is calculated by comparing the cross-validated predictions (y_pred) to the actual values (y).
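Because cross_val_predict() also works for regression, here is a minimal sketch of the regression case. It assumes a synthetic dataset from make_regression() and a RandomForestRegressor, with mean absolute error used to score the out-of-fold predictions; adapt the estimator and metric to your own problem.
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_predict
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error
# generate a synthetic regression dataset
X, y = make_regression(n_samples=100, n_features=5, noise=0.1, random_state=1)
# create the model
model = RandomForestRegressor()
# generate out-of-fold predictions with 5-fold cross-validation
y_pred = cross_val_predict(model, X, y, cv=5)
# evaluate the cross-validated predictions
mae = mean_absolute_error(y, y_pred)
print('MAE: %.3f' % mae)
As with the classification example, each value in y_pred comes from a model that never saw that row during training, so the reported error reflects out-of-sample performance.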