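cross_val_predict() is a function in scikit-learn's sklearn.model_selection module that returns, for every sample in the dataset, the prediction made by a model that was trained without that sample.
It splits the data into cross-validation folds, fits a fresh copy of the estimator on each training split, and predicts the held-out fold, so every data point receives exactly one out-of-fold prediction. These predictions are useful for inspecting errors on individual samples, building diagnostics such as a confusion matrix, and spotting overfitting when compared against the model's predictions on its own training data. Note that the scikit-learn documentation cautions that scoring the pooled predictions with a metric is not always a valid estimate of generalization performance and can differ from cross_val_score().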
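To make the out-of-fold idea concrete, here is a minimal sketch that rebuilds the predictions by hand with a StratifiedKFold loop and checks them against cross_val_predict(). It assumes LogisticRegression in place of the random forest used later, purely because it is deterministic, so both routes should give identical labels. This mirrors the documented behavior; it is not scikit-learn's internal code.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_predict

# Same synthetic data as the example in this section
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)
# Assumption: a deterministic estimator, so manual and library results can match exactly
model = LogisticRegression()

# For a classifier with a binary target, an integer cv=5 means StratifiedKFold(n_splits=5)
cv = StratifiedKFold(n_splits=5)
manual_pred = np.empty_like(y)
for train_idx, test_idx in cv.split(X, y):
    # Fit a fresh copy on the training folds, then predict only the held-out fold
    fold_model = clone(model).fit(X[train_idx], y[train_idx])
    manual_pred[test_idx] = fold_model.predict(X[test_idx])

auto_pred = cross_val_predict(model, X, y, cv=5)
print(np.array_equal(manual_pred, auto_pred))
```

Each sample's label is written exactly once, by the model whose training folds excluded it, which is what makes the predictions "out-of-fold".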
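The key parameters are the estimator (the model to fit), X and y (the features and the target variable), and cv, which sets the cross-validation strategy: an integer gives the number of folds (stratified folds for classifiers with a binary or multiclass target; the default is 5), or a splitter object such as KFold can be passed directly. The method argument ('predict' by default, or for example 'predict_proba') selects which estimator method produces the returned values.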
This function is appropriate for both classification and regression problems.
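As a sketch of the regression case, the snippet below assumes a LinearRegression model on data from make_regression (the noise level and seeds are arbitrary illustrative choices) and passes a shuffled KFold splitter as cv instead of an integer.

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import KFold, cross_val_predict

# Synthetic regression data; noise and random_state are illustrative choices
X, y = make_regression(n_samples=100, n_features=5, noise=10.0, random_state=1)
model = LinearRegression()
# cv also accepts a splitter object; shuffling with a fixed seed keeps folds reproducible
cv = KFold(n_splits=5, shuffle=True, random_state=1)
# One out-of-fold prediction per sample, in the original sample order
y_pred = cross_val_predict(model, X, y, cv=cv)
mae = mean_absolute_error(y, y_pred)
print('MAE: %.3f' % mae)
```

Because regressors get plain (non-stratified) folds by default, passing an explicit KFold with shuffle=True is one way to avoid folds that follow any ordering in the data.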
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_predict
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
# generate binary classification dataset
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)
# create model
model = RandomForestClassifier()
# use cross_val_predict
y_pred = cross_val_predict(model, X, y, cv=5)
# evaluate predictions
acc = accuracy_score(y, y_pred)
print('Accuracy: %.3f' % acc)
Running the example gives an output like:
Accuracy: 0.950
The steps are as follows:
1. Generate a synthetic binary classification dataset using make_classification(). This creates a dataset with a specified number of samples (n_samples), features (n_features), and a fixed random seed (random_state) for reproducibility.
2. Instantiate a RandomForestClassifier model with default hyperparameters.
3. Use cross_val_predict() to generate cross-validated predictions. The function runs 5-fold cross-validation and returns, for each data point, the prediction made by the model fitted on the folds that excluded it.
4. Evaluate the predictions with the accuracy score metric, which compares the cross-validated predictions (y_pred) to the actual values (y).
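The evaluation step above scores hard labels. If a probability-based metric is wanted instead, one option, sketched here under the assumption that ROC AUC is the metric of interest, is to set method='predict_proba' so that cross_val_predict() returns out-of-fold class probabilities. The random_state=1 on the forest is an added assumption so reruns give the same numbers.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import cross_val_predict

X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)
# Assumption: fix the forest's seed (the main example leaves it unset)
model = RandomForestClassifier(random_state=1)
# method='predict_proba' returns one row of class probabilities per sample
y_proba = cross_val_predict(model, X, y, cv=5, method='predict_proba')
# Columns follow the sorted class labels, so column 1 is the positive class (label 1)
auc = roc_auc_score(y, y_proba[:, 1])
print('ROC AUC: %.3f' % auc)
```

The returned array has shape (n_samples, n_classes), and each row sums to 1, so it can also feed calibration checks or threshold tuning.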