
Scikit-Learn Get GridSearchCV "refit_time_" Attribute

The GridSearchCV class in scikit-learn performs hyperparameter tuning by fitting a specified model with every combination of hyperparameters in a parameter grid and evaluating each combination with cross-validation. By default (refit=True), the best-performing configuration is then refitted on the entire dataset passed to fit, and the resulting model is stored in best_estimator_.
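To make the refit step concrete, the following minimal sketch checks that best_estimator_ is the same model you would get by training the winning configuration on all of the data yourself. It assumes a small Ridge grid on a synthetic dataset purely so that it runs quickly; the dataset size and alpha values are illustrative, not taken from the example below.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

# Small synthetic dataset (illustrative sizes) so the search is fast
X, y = make_regression(n_samples=200, n_features=5, noise=0.1, random_state=0)

# refit=True is the default; it is spelled out here for clarity
search = GridSearchCV(Ridge(), {"alpha": [0.1, 1.0, 10.0]}, cv=3, refit=True)
search.fit(X, y)

# Train the winning configuration by hand on the full dataset
manual = Ridge(alpha=search.best_params_["alpha"]).fit(X, y)

print(search.best_params_)
# best_estimator_ should match a model trained on all 200 rows
print(np.allclose(search.best_estimator_.coef_, manual.coef_))
```

Ridge is used here because its default solver is deterministic on dense data, so the coefficients can be compared exactly.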

The refit_time_ attribute of a fitted GridSearchCV object stores the time, in seconds, spent refitting the best model on the whole dataset after cross-validation has finished. It is only set when refit is not False; with refit=False no final model is trained, so the attribute does not exist. Reading it tells you what the single final fit with the optimal hyperparameters cost, separately from the cost of the search itself.
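A quick way to see that refit_time_ depends on the refit setting is to run the same search twice, once with the default and once with refit=False. This is a minimal sketch on an illustrative Ridge grid; only the presence or absence of the attribute matters here.

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

X, y = make_regression(n_samples=200, n_features=5, noise=0.1, random_state=0)
param_grid = {"alpha": [0.1, 1.0, 10.0]}

# Default refit=True: the best model is retrained, and timed, on all data
with_refit = GridSearchCV(Ridge(), param_grid, cv=3).fit(X, y)

# refit=False: only the cross-validation runs, so no refit is timed
without_refit = GridSearchCV(Ridge(), param_grid, cv=3, refit=False).fit(X, y)

print(hasattr(with_refit, "refit_time_"))     # True
print(hasattr(without_refit, "refit_time_"))  # False
```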

Comparing refit_time_ with the total search time shows how much of the overall cost comes from the final fit and how much from the cross-validated search. Because the refit trains a single model with the chosen hyperparameters on all of the data, refit_time_ also gives a rough estimate of how long retraining that configuration will take, which is useful when a model must be retrained regularly in production and training time is constrained.
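One way to put refit_time_ in context is to compare it with the per-fold fit times that GridSearchCV records in cv_results_. The sketch below assumes a deliberately small RandomForestRegressor grid so it runs in seconds; the "total CV fitting time" it prints is an approximation built from mean_fit_time multiplied by n_splits_, not an exact measurement, and it excludes scoring time.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV

# Illustrative, small problem and grid so the example finishes quickly
X, y = make_regression(n_samples=300, n_features=10, noise=0.1, random_state=42)
param_grid = {"n_estimators": [10, 30], "max_depth": [None, 5]}
search = GridSearchCV(RandomForestRegressor(random_state=42), param_grid, cv=3)
search.fit(X, y)

results = search.cv_results_
# Average time to fit the winning configuration on one CV training split
best_fold_fit = results["mean_fit_time"][search.best_index_]
# Approximate total CV fitting time: mean fit time per candidate x number of folds
total_cv_fit = np.sum(results["mean_fit_time"]) * search.n_splits_
# Fraction of all fitting time spent on the final refit
share = search.refit_time_ / (total_cv_fit + search.refit_time_)

print(f"Refit time (full data):        {search.refit_time_:.3f} s")
print(f"Best candidate, per CV fold:   {best_fold_fit:.3f} s")
print(f"Approx. total CV fitting time: {total_cv_fit:.3f} s")
print(f"Refit share of fitting time:   {share:.1%}")
```

Since each CV fold trains on only part of the data, the refit on the full dataset will often take somewhat longer than the per-fold fit of the same configuration, though the exact ratio depends on the model and hardware.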

from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV
import time

# Generate a synthetic regression dataset
X, y = make_regression(n_samples=1000, n_features=10, noise=0.1, random_state=42)

# Create a RandomForestRegressor estimator
rf = RandomForestRegressor(random_state=42)

# Define the parameter grid
param_grid = {
    'n_estimators': [50, 100, 200],
    'max_depth': [None, 5, 10],
    'min_samples_split': [2, 5, 10]
}

# Create a GridSearchCV object: 27 candidates x 5 folds = 135 fits, plus one refit
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid, cv=5)

# Fit the GridSearchCV object
start_time = time.time()
grid_search.fit(X, y)
end_time = time.time()

# Access the refit_time_ attribute
refit_time = grid_search.refit_time_

# Print the refit time
print(f"Refit time: {refit_time:.2f} seconds")
print(f"Total time: {end_time - start_time:.2f} seconds")

Running the example gives output similar to the following (exact timings depend on your hardware):

Refit time: 1.29 seconds
Total time: 70.34 seconds

The key steps in this example are:

  1. Preparing a synthetic regression dataset using make_regression.
  2. Defining the RandomForestRegressor estimator and the parameter grid for grid search.
  3. Creating a GridSearchCV object with the estimator, parameter grid, and cross-validation strategy.
  4. Measuring the start time before fitting the GridSearchCV object.
  5. Fitting the GridSearchCV object on the synthetic dataset.
  6. Measuring the end time after the fitting process is complete.
  7. Accessing the refit_time_ attribute from the fitted GridSearchCV object.
  8. Printing the refit time and the total time taken for the grid search process.
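The timing steps above can also be used to separate the search phase from the refit. The following sketch is one possible variation, not the example's own method: it swaps time.time() for time.perf_counter(), a monotonic clock intended for measuring intervals, and subtracts refit_time_ from the wall-clock total. The smaller grid and dataset are assumptions chosen so it runs quickly, and the subtraction is approximate because the two values come from different clocks.

```python
import time
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV

# Illustrative, reduced problem and grid
X, y = make_regression(n_samples=300, n_features=10, noise=0.1, random_state=42)
param_grid = {"n_estimators": [10, 30], "max_depth": [None, 5]}
search = GridSearchCV(RandomForestRegressor(random_state=42), param_grid, cv=3)

# perf_counter is monotonic and high-resolution, suited to timing intervals
start = time.perf_counter()
search.fit(X, y)
total = time.perf_counter() - start

# Everything except the final refit: candidate fits, scoring, and bookkeeping
search_only = total - search.refit_time_

print(f"Total:        {total:.2f} s")
print(f"Refit:        {search.refit_time_:.2f} s")
print(f"Search phase: {search_only:.2f} s ({search_only / total:.0%} of total)")
```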

