
Scikit-Learn Configure RandomizedSearchCV "refit" Parameter

RandomizedSearchCV performs random search, a hyperparameter optimization method that samples random combinations of parameter values and evaluates each one with cross-validation. Its refit parameter determines whether, once the search has finished, an estimator with the best parameters found is fit again on the whole dataset.

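The refit parameter can be set to True, False, a string naming a scorer (needed to refit when several metrics are evaluated), or a callable that receives cv_results_ and returns the index of the chosen candidate.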
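The string form matters most when more than one metric is evaluated. The following is a minimal sketch, assuming a smaller dataset and a reduced search (300 samples, 5 candidates, 3 folds) chosen only to keep it fast; the choice of 'accuracy' and 'f1' as metrics is also an assumption made for illustration.

```python
from scipy.stats import randint
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

# Small synthetic binary problem (sizes are illustrative, not from the article)
X, y = make_classification(n_samples=300, n_features=10, n_informative=5,
                           n_redundant=5, random_state=42)

search = RandomizedSearchCV(
    RandomForestClassifier(random_state=42),
    {"n_estimators": randint(10, 50), "max_depth": [None, 5, 10]},
    n_iter=5,
    cv=3,
    scoring=["accuracy", "f1"],  # evaluate two metrics per candidate
    refit="f1",                  # select and refit the best model by mean F1
    random_state=42,
)
search.fit(X, y)

print(search.best_params_)
print(f"Best mean F1: {search.best_score_:.3f}")
# Both metrics are recorded in cv_results_, but only 'f1' drives the selection
print(sorted(k for k in search.cv_results_ if k.startswith("mean_test_")))
```

With several metrics, scikit-learn cannot tell which one defines "best", so refit=True is rejected and a scorer name (or False, or a callable) has to be given instead.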

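The default value is True, which refits an estimator with the best parameters on the whole dataset and exposes it as best_estimator_, so the fitted search object can be used directly for prediction.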
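As a rough illustration of what the default provides, the sketch below leaves refit unset and then uses the attributes that only exist after refitting. The dataset size and search settings are assumptions chosen to keep the run short.

```python
from scipy.stats import randint
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

# Small synthetic binary problem (sizes are illustrative)
X, y = make_classification(n_samples=300, n_features=10, n_informative=5,
                           n_redundant=5, random_state=42)

# refit is not passed, so the default (True) applies
search = RandomizedSearchCV(
    RandomForestClassifier(random_state=42),
    {"n_estimators": randint(10, 50), "max_depth": [None, 5, 10]},
    n_iter=5,
    cv=3,
    random_state=42,
)
search.fit(X, y)

# The refit model is stored on the search object
print(search.best_estimator_)

# predict() is delegated to best_estimator_
predictions = search.predict(X[:5])
print(predictions)

# refit_time_ reports how long the final fit on the whole dataset took
print(f"Refit time: {search.refit_time_:.3f} seconds")
```

Reading refit_time_ is a more direct way to see the cost of refitting than timing the whole search, because it isolates the final fit from the cross-validation fits.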

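Setting refit to False skips the refitting step, which saves the cost of one extra model fit when only the search results are needed; best_estimator_ and methods such as predict are then unavailable on the search object. Passing a scorer name (e.g., 'accuracy') selects and refits the best model according to that metric; when only a single metric is evaluated, as in the example below, this behaves the same as True.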
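Besides True, False, and a scorer name, scikit-learn also accepts a callable for refit: it receives cv_results_ and must return the integer index of the candidate to refit. The sketch below is one possible selection rule, not a recommended one: the helper name smallest_forest_within_one_std and the "within one standard deviation of the best score" criterion are assumptions introduced for illustration.

```python
import numpy as np
from scipy.stats import randint
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV


def smallest_forest_within_one_std(cv_results):
    """Hypothetical rule: among candidates scoring within one standard
    deviation of the best mean score, pick the one with the fewest trees."""
    mean = np.asarray(cv_results["mean_test_score"])
    std = np.asarray(cv_results["std_test_score"])
    best = int(np.argmax(mean))
    threshold = mean[best] - std[best]
    candidates = np.flatnonzero(mean >= threshold)
    n_trees = np.array([p["n_estimators"] for p in cv_results["params"]])
    return int(candidates[np.argmin(n_trees[candidates])])


# Small synthetic binary problem (sizes are illustrative)
X, y = make_classification(n_samples=300, n_features=10, n_informative=5,
                           n_redundant=5, random_state=42)

search = RandomizedSearchCV(
    RandomForestClassifier(random_state=42),
    {"n_estimators": randint(10, 50), "max_depth": [None, 5, 10]},
    n_iter=5,
    cv=3,
    refit=smallest_forest_within_one_std,  # callable returns best_index_
    random_state=42,
)
search.fit(X, y)

print(search.best_index_, search.best_params_)
```

When refit is a callable, best_score_ is not set, so the score of the chosen candidate has to be read from cv_results_ using best_index_.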

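As a heuristic, set refit to False if the refit model itself is not required (for example, when only best_params_ or cv_results_ will be inspected), and pass a scorer name when several metrics are evaluated and one of them should decide the final model.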
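To make the trade-off concrete, here is a minimal sketch of a search with refit=False followed by a manual fit using the best parameters. The dataset size, search settings, and the name final_model are assumptions made for illustration.

```python
from scipy.stats import randint
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

# Small synthetic binary problem (sizes are illustrative)
X, y = make_classification(n_samples=300, n_features=10, n_informative=5,
                           n_redundant=5, random_state=42)

search = RandomizedSearchCV(
    RandomForestClassifier(random_state=42),
    {"n_estimators": randint(10, 50), "max_depth": [None, 5, 10]},
    n_iter=5,
    cv=3,
    refit=False,  # skip the final fit on the whole dataset
    random_state=42,
)
search.fit(X, y)

# With a single metric, the search results are still available
print(search.best_params_)

# No refit model is stored, so prediction through the search object fails
print(hasattr(search, "best_estimator_"))
try:
    search.predict(X[:5])
except AttributeError as exc:
    print(f"predict unavailable: {exc}")

# If a model is needed later, fit one manually with the best parameters
final_model = RandomForestClassifier(random_state=42, **search.best_params_)
final_model.fit(X, y)
print(final_model.predict(X[:5]))
```

This pattern can be useful when the final model should be trained on different data than the search used, or when the search is run only to compare configurations.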

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import randint
import time

# Generate a synthetic binary classification dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=5, n_redundant=5, random_state=42)

# Define a parameter distribution for RandomForestClassifier hyperparameters
param_dist = {'n_estimators': randint(10, 100),
              'max_depth': [None, 5, 10],
              'min_samples_split': randint(2, 10)}

# Create a base RandomForestClassifier model
rf = RandomForestClassifier(random_state=42)

# List of refit values to test
refit_values = [True, False, 'accuracy']

for refit in refit_values:
    start_time = time.perf_counter()

    # Run RandomizedSearchCV with the current refit value
    search = RandomizedSearchCV(rf, param_dist, n_iter=10, cv=5, refit=refit, random_state=42)
    search.fit(X, y)

    end_time = time.perf_counter()
    execution_time = end_time - start_time

    if refit:
        print(f"Best score for refit={refit}: {search.best_score_:.3f}")
    print(f"Execution time for refit={refit}: {execution_time:.2f} seconds")
    print()

Running the example gives an output like:

Best score for refit=True: 0.939
Execution time for refit=True: 5.94 seconds

Execution time for refit=False: 5.83 seconds

Best score for refit=accuracy: 0.939
Execution time for refit=accuracy: 5.85 seconds

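Refitting adds only one model fit on top of the 50 cross-validation fits (10 candidates × 5 folds), so the timing differences here are small and within normal run-to-run variation.

The steps are as follows: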

  1. Generate a synthetic binary classification dataset using make_classification() from scikit-learn.
  2. Define a parameter distribution dictionary param_dist for RandomForestClassifier hyperparameters using randint for random integer sampling.
  3. Create a base RandomForestClassifier model rf.
  4. Iterate over different refit values (True, False, 'accuracy').
  5. For each refit value:
    • Record the start time using time.perf_counter().
    • Run RandomizedSearchCV with 10 iterations and 5-fold cross-validation.
    • Record the end time and calculate the execution time.
    • Print the best score (if refit is not False) and execution time for the current refit value.

