
Scikit-Learn Get GridSearchCV "best_estimator_" Attribute

The GridSearchCV class in scikit-learn is a tool for hyperparameter tuning and model selection. You define a grid of hyperparameter values, and GridSearchCV evaluates the specified model on every combination of those values using cross-validation.

After the search completes, GridSearchCV refits the estimator on the full dataset passed to fit, using the hyperparameter combination that achieved the highest mean cross-validation score. With the default refit=True, this refitted model is accessible through the best_estimator_ attribute of the GridSearchCV object.
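The relationship between best_estimator_, best_params_, and the refit setting can be checked directly. The following is a minimal sketch on a small synthetic dataset; the dataset size, the single-parameter grid, and the variable names are assumptions chosen to keep the run short, not part of the main example.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# Small illustrative dataset and grid (assumed values, chosen for speed)
X, y = make_classification(n_samples=200, n_classes=2, random_state=0)
param_grid = {"C": [0.1, 1, 10]}

# Default refit=True: the best configuration is refit on all of X, y
search = GridSearchCV(SVC(kernel="linear"), param_grid, cv=3)
search.fit(X, y)
best = search.best_estimator_

# The refitted estimator carries the hyperparameters reported in best_params_
params_match = all(
    best.get_params()[name] == value
    for name, value in search.best_params_.items()
)

# Calling predict on the search object delegates to best_estimator_
same_predictions = np.array_equal(search.predict(X), best.predict(X))

# With refit=False no final model is fitted, so the attribute is never set
no_refit = GridSearchCV(SVC(kernel="linear"), param_grid, cv=3, refit=False)
no_refit.fit(X, y)
has_best_estimator = hasattr(no_refit, "best_estimator_")

print(params_match, same_predictions, has_best_estimator)
```

If refit is disabled, best_params_ is still available, and you can construct and fit the final model yourself.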

Accessing the best_estimator_ attribute gives you the model configured with the best hyperparameters found in the grid. You can then use this estimator for predictions, further evaluation, or as the final model for deployment.
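One possible way to carry the best estimator toward deployment is to persist it to disk. The sketch below uses joblib for this, which is an assumption about the workflow rather than something the example here requires; the file name best_svc.joblib and the small dataset are hypothetical.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# Small illustrative dataset and grid (assumed values, chosen for speed)
X, y = make_classification(n_samples=200, n_classes=2, random_state=0)
search = GridSearchCV(SVC(), {"C": [0.1, 1, 10]}, cv=3)
search.fit(X, y)

with tempfile.TemporaryDirectory() as tmp_dir:
    # Hypothetical file name; any writable path works
    model_path = os.path.join(tmp_dir, "best_svc.joblib")

    # Save only the refitted best model, not the whole search object
    joblib.dump(search.best_estimator_, model_path)

    # Load it back, as a separate serving process might
    restored = joblib.load(model_path)

# The restored model should reproduce the original predictions
predictions_match = np.array_equal(
    restored.predict(X), search.best_estimator_.predict(X)
)
print(predictions_match)
```

Saving best_estimator_ rather than the full GridSearchCV object keeps the artifact smaller, at the cost of dropping the cross-validation results.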

from sklearn.datasets import make_classification
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV

# Generate a synthetic classification dataset
X, y = make_classification(n_samples=1000, n_classes=2, random_state=42)

# Create an SVC estimator
svc = SVC(random_state=42)

# Define the parameter grid
param_grid = {
    'C': [0.1, 1, 10],
    'kernel': ['linear', 'rbf'],
    'gamma': ['scale', 'auto']
}

# Create a GridSearchCV object
grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5)

# Fit the GridSearchCV object
grid_search.fit(X, y)

# Access the best estimator
best_estimator = grid_search.best_estimator_

# Print the best estimator
print("Best Estimator:")
print(best_estimator)

# Use the best estimator for predictions
y_pred = best_estimator.predict(X)

# Print the accuracy of the best estimator
# Note: X, y is the same data the estimator was fitted on, so this is training accuracy
accuracy = best_estimator.score(X, y)
print(f"Accuracy: {accuracy:.2f}")

Running the example gives an output like:

Best Estimator:
SVC(C=10, kernel='linear', random_state=42)
Accuracy: 0.88

The key steps in this example are:

  1. Generating a synthetic classification dataset with make_classification for the grid search.
  2. Defining the SVC estimator and the parameter grid of hyperparameters to tune.
  3. Creating a GridSearchCV object with the estimator, the parameter grid, and 5-fold cross-validation.
  4. Fitting the GridSearchCV object on the synthetic dataset.
  5. Accessing the best_estimator_ attribute of the fitted GridSearchCV object.
  6. Using the best estimator for predictions and computing its accuracy on the data it was fitted on.
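The accuracy in the last step is measured on the same data used for fitting, so it is likely to be optimistic. A minimal sketch of one common alternative holds out a test set before the search; the 80/20 split, the stratification, and the reduced grid are assumptions for illustration, and the resulting scores will differ from the output shown above.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.svm import SVC

# Same synthetic dataset as the main example
X, y = make_classification(n_samples=1000, n_classes=2, random_state=42)

# Hold out 20% of the samples; the search never sees them (assumed split)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Reduced grid to keep the sketch quick
param_grid = {"C": [0.1, 1, 10], "kernel": ["linear", "rbf"]}
grid_search = GridSearchCV(SVC(random_state=42), param_grid, cv=5)
grid_search.fit(X_train, y_train)

best_estimator = grid_search.best_estimator_

# Mean cross-validated accuracy of the best configuration on the training split
cv_score = grid_search.best_score_

# Accuracy on data that played no part in tuning or fitting
test_score = best_estimator.score(X_test, y_test)

print(f"Best CV accuracy: {cv_score:.2f}")
print(f"Held-out test accuracy: {test_score:.2f}")
```

Comparing best_score_ with the held-out score gives a rough check on whether the selected configuration generalizes beyond the data used for tuning.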

