The GridSearchCV class in scikit-learn is a powerful tool for hyperparameter tuning and model selection. It allows you to define a grid of hyperparameter values and fits a specified model for each combination of those values using cross-validation.
After the grid search is complete, the GridSearchCV object identifies the parameter combination with the highest mean cross-validation score. With the default setting refit=True, it then refits the estimator with those parameters on the full dataset passed to fit. This refit model is accessible through the best_estimator_ attribute of the GridSearchCV object.
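As a minimal sketch of that relationship, the snippet below checks that the refit estimator carries the winning parameter values. The dataset size and the small grid over C are illustrative assumptions chosen to keep the run short, and refit=True is written out only for clarity, since it is the default.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# Small synthetic dataset (sizes are illustrative assumptions)
X, y = make_classification(n_samples=200, n_features=10, random_state=0)

# refit=True is the default; it is spelled out here for clarity
search = GridSearchCV(
    estimator=SVC(),
    param_grid={"C": [0.1, 1, 10]},
    cv=5,
    refit=True,
)
search.fit(X, y)

# The refit estimator is configured with the winning parameter values
best = search.best_estimator_
print(search.best_params_)
print(best.get_params()["C"] == search.best_params_["C"])

# best_score_ is the mean cross-validated score of the winning combination
print(f"Mean CV score: {search.best_score_:.3f}")
```

If refit were set to False, the search would still report best_params_ and best_score_, but best_estimator_ would not be available.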
Accessing the best_estimator_ attribute gives you the best-performing model, configured with the best hyperparameter combination found in the grid. You can then use this estimator for further predictions, model evaluation, or as a final model for deployment.
from sklearn.datasets import make_classification
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
# Generate a synthetic classification dataset
X, y = make_classification(n_samples=1000, n_classes=2, random_state=42)
# Create an SVC estimator
svc = SVC(random_state=42)
# Define the parameter grid
param_grid = {
    'C': [0.1, 1, 10],
    'kernel': ['linear', 'rbf'],
    'gamma': ['scale', 'auto']
}
# Create a GridSearchCV object
grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5)
# Fit the GridSearchCV object
grid_search.fit(X, y)
# Access the best estimator
best_estimator = grid_search.best_estimator_
# Print the best estimator
print("Best Estimator:")
print(best_estimator)
# Use the best estimator for predictions
y_pred = best_estimator.predict(X)
# Print the accuracy of the best estimator on the data it was fit on (training accuracy)
accuracy = best_estimator.score(X, y)
print(f"Accuracy: {accuracy:.2f}")
Running the example gives an output like:
Best Estimator:
SVC(C=10, kernel='linear', random_state=42)
Accuracy: 0.88
The key steps in this example are:
- Preparing a synthetic classification dataset using make_classification to use for grid search.
- Defining the SVC estimator and the parameter grid with hyperparameters to tune.
- Creating a GridSearchCV object with the estimator, parameter grid, and cross-validation strategy.
- Fitting the GridSearchCV object on the synthetic dataset.
- Accessing the best_estimator_ attribute from the fitted GridSearchCV object.
- Using the best estimator for further predictions and evaluating its accuracy.
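The example scores the best estimator on the same data that was used for the search, so the reported accuracy is likely optimistic. One possible variation, sketched below, holds out part of the data before searching and scores the best estimator on it afterwards. The 25% split, the stratify option, and the variable names are assumptions for illustration and are not part of the original example.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.svm import SVC

# Same synthetic dataset as in the example
X, y = make_classification(n_samples=1000, n_classes=2, random_state=42)

# Hold out 25% of the data; the split ratio is an illustrative assumption
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)

param_grid = {
    "C": [0.1, 1, 10],
    "kernel": ["linear", "rbf"],
    "gamma": ["scale", "auto"],
}

# Run the search on the training portion only
grid_search = GridSearchCV(SVC(random_state=42), param_grid, cv=5)
grid_search.fit(X_train, y_train)

# Score the refit best estimator on data the search never saw
test_accuracy = grid_search.best_estimator_.score(X_test, y_test)
print(f"Mean CV accuracy: {grid_search.best_score_:.2f}")
print(f"Held-out accuracy: {test_accuracy:.2f}")
```

Comparing the mean cross-validation accuracy with the held-out accuracy gives a rough check on whether the selected configuration generalizes beyond the data used to choose it.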