The `GridSearchCV` class in scikit-learn tunes hyperparameters by fitting a specified model with every combination of hyperparameters in a grid. It scores each combination with cross-validation. With the default `refit=True`, the best-performing combination is then used to refit the model on the entire dataset passed to `fit`.
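The sketch below makes the search-then-refit procedure concrete. It writes out, conceptually, what `GridSearchCV` does as an explicit loop over `ParameterGrid` with `cross_val_score`. The `Ridge` model and the toy grid are assumptions chosen only to keep the run fast. This illustrates the idea and is not scikit-learn's internal implementation, which also handles parallelism, multiple metrics, and error scores.

```python
from sklearn.base import clone
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, ParameterGrid, cross_val_score

X, y = make_regression(n_samples=200, n_features=5, noise=0.1, random_state=0)

# Ridge and this grid are illustrative assumptions, not part of the original example
estimator = Ridge()
param_grid = {"alpha": [0.1, 1.0, 10.0], "fit_intercept": [True, False]}

# Search step: score every combination with 5-fold CV (R^2 is the default for regressors)
best_score, best_params = float("-inf"), None
for params in ParameterGrid(param_grid):
    score = cross_val_score(clone(estimator).set_params(**params), X, y, cv=5).mean()
    if score > best_score:  # strict ">" keeps the first candidate on ties
        best_score, best_params = score, params

# Refit step: train a fresh estimator with the winning parameters on all of X, y
best_model = clone(estimator).set_params(**best_params).fit(X, y)

# The same search and refit performed by GridSearchCV
grid_search = GridSearchCV(estimator, param_grid, cv=5).fit(X, y)
print(best_params)
print(grid_search.best_params_)
```

Both approaches use the same unshuffled 5-fold split for a regressor. They should therefore select the same parameters and produce the same refitted coefficients.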
The `refit_time_` attribute of a fitted `GridSearchCV` object stores the time, in seconds, taken to refit the best model on the full dataset after cross-validation has finished. It is only set when `refit` is not `False`. Reading this attribute shows the cost of training a single model with the optimal hyperparameters, separately from the cost of the search itself.
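You can confirm that `refit_time_` belongs to the refit step by running the same search with `refit=False`. The sketch below shows that the attribute exists after a default search but is absent when refitting is disabled. It again assumes a small `Ridge` grid, used only for speed.

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

X, y = make_regression(n_samples=200, n_features=5, noise=0.1, random_state=0)
param_grid = {"alpha": [0.1, 1.0, 10.0]}  # illustrative grid (assumption)

# Default refit=True: the best candidate is refit on all of X, y and that fit is timed
with_refit = GridSearchCV(Ridge(), param_grid, cv=5).fit(X, y)
print(f"refit_time_: {with_refit.refit_time_:.4f} s")

# refit=False: no final fit happens, so refit_time_ (and best_estimator_) are never set
without_refit = GridSearchCV(Ridge(), param_grid, cv=5, refit=False).fit(X, y)
print(hasattr(without_refit, "refit_time_"))
```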
By examining `refit_time_`, you can gauge how expensive the selected configuration is to train. That lets you make informed decisions about the trade-off between model complexity and training time. This is particularly useful when deploying models to production environments where training time is a critical factor.
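One way to act on this trade-off is to compare the per-candidate fit times in `cv_results_` (`mean_fit_time`) with the mean cross-validation scores. The sketch below picks the cheapest candidate whose score is within a tolerance of the best one. Several choices here are illustrative assumptions, not a scikit-learn feature: the 0.01 R² tolerance, the reduced grid, and the "cheapest acceptable candidate" rule. Note that `mean_fit_time` is measured on the training folds, so it is only a proxy for the cost of a refit on the full data.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV

X, y = make_regression(n_samples=300, n_features=10, noise=0.1, random_state=42)

# Reduced grid (assumption) so the search finishes quickly
param_grid = {"n_estimators": [10, 50], "max_depth": [3, None]}
grid_search = GridSearchCV(RandomForestRegressor(random_state=42), param_grid, cv=3).fit(X, y)

results = grid_search.cv_results_
scores = results["mean_test_score"]
fit_times = results["mean_fit_time"]  # average seconds per CV fit, one entry per candidate

# Hypothetical rule: accept any candidate within `tolerance` of the best mean score
tolerance = 0.01
acceptable = np.flatnonzero(scores >= scores.max() - tolerance)
cheapest = acceptable[np.argmin(fit_times[acceptable])]

print("Best params:", grid_search.best_params_)
print(f"Refit time of best model: {grid_search.refit_time_:.3f} s")
print("Cheapest acceptable params:", results["params"][cheapest])
print(f"Its mean CV fit time: {fit_times[cheapest]:.3f} s")
```

If the cheapest acceptable candidate differs from `best_params_`, you could refit that configuration yourself and accept a slightly lower score in exchange for faster training.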
```python
import time

from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV

# Generate a synthetic regression dataset
X, y = make_regression(n_samples=1000, n_features=10, noise=0.1, random_state=42)

# Create a RandomForestRegressor estimator
rf = RandomForestRegressor(random_state=42)

# Define the parameter grid (3 x 3 x 3 = 27 candidate combinations)
param_grid = {
    'n_estimators': [50, 100, 200],
    'max_depth': [None, 5, 10],
    'min_samples_split': [2, 5, 10],
}

# Create a GridSearchCV object with 5-fold cross-validation (refit=True by default)
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid, cv=5)

# Fit the GridSearchCV object and measure the total wall-clock time
start_time = time.perf_counter()
grid_search.fit(X, y)
end_time = time.perf_counter()

# Access the refit_time_ attribute (time spent refitting the best model)
refit_time = grid_search.refit_time_

# Print the refit time and the total search time
print(f"Refit time: {refit_time:.2f} seconds")
print(f"Total time: {end_time - start_time:.2f} seconds")
```
Running the example gives an output like:

```
Refit time: 1.29 seconds
Total time: 70.34 seconds
```
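The gap between these two numbers is expected. The parameter grid has 3 × 3 × 3 = 27 combinations, so 5-fold cross-validation performs 135 fits before the single refit. You can recover a rough breakdown from `cv_results_`, which records the mean fit and score time for each candidate. The sketch below compares the summed cross-validation time and `refit_time_` with the measured wall-clock total. It uses a small `Ridge` grid, an assumption made for speed. The remainder of the total is overhead such as cloning estimators and splitting the data.

```python
import time

from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

X, y = make_regression(n_samples=500, n_features=10, noise=0.1, random_state=42)
param_grid = {"alpha": [0.01, 0.1, 1.0, 10.0]}  # illustrative grid (assumption)
n_splits = 5

grid_search = GridSearchCV(Ridge(), param_grid, cv=n_splits)
start = time.perf_counter()
grid_search.fit(X, y)
total = time.perf_counter() - start

res = grid_search.cv_results_
n_fits = len(res["params"]) * n_splits
# mean_* times are averages over folds, so multiply by n_splits to approximate the sum
cv_time = float(((res["mean_fit_time"] + res["mean_score_time"]) * n_splits).sum())

print(f"CV fits: {len(res['params'])} candidates x {n_splits} folds = {n_fits}")
print(f"Cross-validation fit + score time: {cv_time:.4f} s")
print(f"Refit time: {grid_search.refit_time_:.4f} s")
print(f"Total wall-clock time: {total:.4f} s")
```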
The key steps in this example are:

- Preparing a synthetic regression dataset using `make_regression`.
- Defining the `RandomForestRegressor` estimator and the parameter grid for the grid search.
- Creating a `GridSearchCV` object with the estimator, parameter grid, and cross-validation strategy.
- Recording the start time before fitting the `GridSearchCV` object.
- Fitting the `GridSearchCV` object on the synthetic dataset.
- Recording the end time after the fitting process completes.
- Accessing the `refit_time_` attribute from the fitted `GridSearchCV` object.
- Printing the refit time and the total time taken by the grid search.