The GridSearchCV class in scikit-learn is a powerful tool for hyperparameter tuning and model selection. It lets you define a grid of hyperparameter values and fits the given estimator on every combination of those values using cross-validation. For example, the 3 x 3 x 3 grid used below yields 27 combinations, and with 5-fold cross-validation that amounts to 135 model fits.
After the grid search completes, the GridSearchCV object stores the detailed results of the cross-validation process in its cv_results_ attribute. This attribute is a dictionary containing metrics for each hyperparameter combination, such as mean test scores, per-split test scores, fit times, and more.
Accessing the cv_results_ attribute lets you retrieve these metrics and gain insight into the model's performance across different hyperparameter settings. You can use this information to identify the optimal hyperparameters and understand how different hyperparameter values impact the model's performance. The complete example below demonstrates the workflow:
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
# Generate a synthetic classification dataset
X, y = make_classification(n_samples=1000, n_classes=2, random_state=42)
# Create a RandomForestClassifier estimator
rf = RandomForestClassifier(random_state=42)
# Define the parameter grid
param_grid = {
    'n_estimators': [50, 100, 200],
    'max_depth': [None, 5, 10],
    'min_samples_split': [2, 5, 10]
}
# Create a GridSearchCV object
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid, cv=5)
# Fit the GridSearchCV object
grid_search.fit(X, y)
# Access the cv_results_ attribute
cv_results = grid_search.cv_results_
# View the keys in cv_results_
print(cv_results.keys())
# Access the mean test scores for each hyperparameter combination
mean_test_scores = cv_results['mean_test_score']
# View the mean test scores
print(mean_test_scores)
Running the example gives an output like:
dict_keys(['mean_fit_time', 'std_fit_time', 'mean_score_time', 'std_score_time', 'param_max_depth', 'param_min_samples_split', 'param_n_estimators', 'params', 'split0_test_score', 'split1_test_score', 'split2_test_score', 'split3_test_score', 'split4_test_score', 'mean_test_score', 'std_test_score', 'rank_test_score'])
[0.9 0.899 0.898 0.896 0.897 0.898 0.889 0.889 0.892 0.877 0.878 0.877
0.878 0.882 0.881 0.884 0.887 0.88 0.904 0.897 0.893 0.889 0.893 0.897
0.887 0.891 0.896]
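Beyond cv_results_, the fitted GridSearchCV object exposes convenience attributes such as best_params_ and best_score_ for the winning combination. The short sketch below continues from the example above and additionally assumes pandas is available; loading cv_results_ into a DataFrame is a common way to sort and filter the results:

import pandas as pd

# Convenience attributes on the fitted GridSearchCV object
print(grid_search.best_params_)  # hyperparameter combination with the best mean test score
print(grid_search.best_score_)   # mean cross-validated score of that combination

# A DataFrame view of cv_results_ makes sorting and filtering straightforward
results_df = pd.DataFrame(cv_results)
print(results_df[['params', 'mean_test_score', 'std_test_score', 'rank_test_score']]
      .sort_values('rank_test_score')
      .head())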
The key steps in this example are:
- Preparing a synthetic classification dataset using make_classification to use for grid search.
- Defining the RandomForestClassifier estimator and the parameter grid with hyperparameters to tune.
- Creating a GridSearchCV object with the estimator, parameter grid, and cross-validation strategy.
- Fitting the GridSearchCV object on the synthetic dataset.
- Accessing the cv_results_ attribute from the fitted GridSearchCV object.
- Inspecting the cross-validation results stored in cv_results_, such as viewing the keys and accessing specific scores like the mean test scores.
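Because every array in cv_results_ is index-aligned, you can pair each score with the hyperparameter combination that produced it. A minimal sketch, continuing from the fitted example above and assuming NumPy is available:

import numpy as np

# Entries in cv_results_ are index-aligned: params[i] produced mean_test_score[i]
best_index = np.argmax(cv_results['mean_test_score'])
print(cv_results['params'][best_index])           # best hyperparameter combination
print(cv_results['mean_test_score'][best_index])  # its mean test score
print(cv_results['std_test_score'][best_index])   # spread of that score across the 5 folds

The index recovered this way should correspond to the fitted object's best_index_ attribute, and the printed combination should match grid_search.best_params_.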