Getting a reliable estimate of a machine learning model's performance is difficult when it rests on a single train/test split. The cross_validate() function in scikit-learn addresses this by fitting and scoring the model on several splits of the data, which gives a more robust evaluation.
The key parameters of cross_validate() are estimator (the model to train), cv (the cross-validation splitting strategy), and scoring (the metric or metrics used to evaluate the model). This function is suitable for classification, regression, and other supervised learning tasks.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate
from sklearn.metrics import make_scorer, accuracy_score, f1_score
# generate binary classification dataset
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)
# create model
model = LogisticRegression()
# evaluate model using cross-validation
scoring = {'accuracy': make_scorer(accuracy_score), 'f1': make_scorer(f1_score)}
scores = cross_validate(model, X, y, scoring=scoring, cv=5, return_train_score=True)
# print scores
print('Train Accuracy: %.3f' % scores['train_accuracy'].mean())
print('Test Accuracy: %.3f' % scores['test_accuracy'].mean())
print('Train F1 Score: %.3f' % scores['train_f1'].mean())
print('Test F1 Score: %.3f' % scores['test_f1'].mean())
Running the example gives an output like:
Train Accuracy: 0.950
Test Accuracy: 0.950
Train F1 Score: 0.947
Test F1 Score: 0.948
The steps are as follows:
- A synthetic binary classification dataset is generated using make_classification() with 100 samples and 5 features. The fixed random_state makes the dataset reproducible.
- A LogisticRegression model is instantiated with default hyperparameters.
- The cross_validate() function performs 5-fold cross-validation on the model. The scoring parameter requests both accuracy and F1 score, and return_train_score=True adds the training scores to the results.
- The average training and test scores for accuracy and F1 score are printed.
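The steps above print only the fold averages. As a supplementary sketch (not part of the original example), the dictionary returned by cross_validate() can also be inspected fold by fold. This variant passes the built-in scorer names 'accuracy' and 'f1' as a list in place of the make_scorer() objects, which should produce the same result keys. No output is shown because the exact per-fold values depend on the installed scikit-learn version.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

# same synthetic data and model as in the main example
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)
model = LogisticRegression()

# built-in scorer names stand in for the make_scorer() objects
scores = cross_validate(model, X, y, scoring=['accuracy', 'f1'], cv=5, return_train_score=True)

# each entry is a NumPy array holding one value per fold
for name in sorted(scores):
    print(name, scores[name].shape)

# per-fold test accuracy and its spread across the folds
print(scores['test_accuracy'])
print('Std of test accuracy: %.3f' % scores['test_accuracy'].std())
```

Besides the train and test scores, the dictionary also holds fit_time and score_time arrays. Looking at the spread across folds, and not only the mean, helps show how stable the estimate is on a dataset this small.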
This example demonstrates how to use cross_validate() to evaluate a machine learning model’s performance with cross-validation, reporting multiple metrics for a more reliable assessment.
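Since cross_validate() is also suitable for regression, the following minimal sketch applies the same pattern to a regression task. The dataset settings, the Ridge model, and the shuffled KFold splitter are assumptions chosen only for illustration. They are not part of the original example.

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold, cross_validate

# hypothetical regression data, chosen only for illustration
X, y = make_regression(n_samples=100, n_features=5, noise=10.0, random_state=1)
model = Ridge()

# an explicit splitter object instead of an integer: shuffled 5-fold
cv = KFold(n_splits=5, shuffle=True, random_state=1)

# error metrics are exposed in negated form so that higher is always better
reg_scores = cross_validate(model, X, y, scoring=['r2', 'neg_mean_absolute_error'], cv=cv)

print('Test R^2: %.3f' % reg_scores['test_r2'].mean())
# flip the sign back to report a conventional (positive) mean absolute error
print('Test MAE: %.3f' % -reg_scores['test_neg_mean_absolute_error'].mean())
```

Because return_train_score is left at its default here, only test scores (plus the timing arrays) are returned.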