
Scikit-Learn Get GridSearchCV "n_splits_" Attribute

The GridSearchCV class in scikit-learn is commonly used for hyperparameter tuning and model selection. You define a grid of hyperparameter values, and GridSearchCV uses cross-validation to evaluate the model's performance for every combination of those values.

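The n_splits_ attribute records the number of cross-validation splits that GridSearchCV actually used. It is set during fit, so it only exists on a fitted search object. Its value comes from the cv parameter: with the default cv=None, scikit-learn uses 5-fold cross-validation; an integer sets the number of folds (stratified folds for a classifier with a binary or multiclass target, plain KFold otherwise); and a splitter object such as KFold, or an iterable of (train, test) index pairs, supplies its own number of splits.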
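To see how different cv values map to n_splits_, the sketch below fits the same small search with several cv settings and reads the attribute each time. The DecisionTreeClassifier, the one-parameter grid, and the cv_options/results names are illustrative choices made for speed, not part of the original example.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, KFold, RepeatedStratifiedKFold
from sklearn.tree import DecisionTreeClassifier

# Small synthetic dataset and a tiny grid (illustrative choices to keep fitting fast)
X, y = make_classification(n_samples=200, random_state=42)
param_grid = {"max_depth": [2, 4]}

# Different ways of specifying cv; the label is only used for printing
cv_options = {
    "default (cv=None)": None,
    "cv=3": 3,
    "KFold(n_splits=4)": KFold(n_splits=4, shuffle=True, random_state=0),
    # A repeated splitter yields n_splits * n_repeats splits in total
    "RepeatedStratifiedKFold(3 x 2)": RepeatedStratifiedKFold(
        n_splits=3, n_repeats=2, random_state=0
    ),
}

results = {}
for label, cv in cv_options.items():
    search = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid, cv=cv)
    search.fit(X, y)
    results[label] = search.n_splits_
    print(f"{label}: n_splits_ = {search.n_splits_}")
```

With a recent scikit-learn version this should report 5, 3, 4 and 6 splits respectively.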

Reading n_splits_ lets you confirm how many splits the search actually ran. This is most useful when the count is not written out explicitly, for example when cv is left at its default, is a repeated splitter, or is a splitter such as LeaveOneOut whose number of splits depends on the data.
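When the number of splits depends on the data, n_splits_ is a convenient check. The following is a minimal sketch, assuming a tiny 30-sample dataset so leave-one-out stays fast; the classifier, the grid, and the custom_splits index pairs are hypothetical choices for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, LeaveOneOut
from sklearn.tree import DecisionTreeClassifier

# Small dataset so that leave-one-out cross-validation stays quick
X, y = make_classification(n_samples=30, random_state=42)
param_grid = {"max_depth": [2, 4]}

# LeaveOneOut creates one split per sample, so the count equals len(X)
loo_search = GridSearchCV(
    DecisionTreeClassifier(random_state=0), param_grid, cv=LeaveOneOut()
)
loo_search.fit(X, y)
print("LeaveOneOut n_splits_:", loo_search.n_splits_)

# An explicit list of (train_indices, test_indices) pairs gives one split per pair
indices = np.arange(len(X))
custom_splits = [
    (indices[:20], indices[20:]),
    (indices[10:], indices[:10]),
]
custom_search = GridSearchCV(
    DecisionTreeClassifier(random_state=0), param_grid, cv=custom_splits
)
custom_search.fit(X, y)
print("Custom split list n_splits_:", custom_search.n_splits_)
```

Here n_splits_ should be 30 for leave-one-out (one split per sample) and 2 for the hand-built list of index pairs.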

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Generate a synthetic classification dataset
X, y = make_classification(n_samples=1000, n_classes=2, random_state=42)

# Create a RandomForestClassifier estimator
rf = RandomForestClassifier(random_state=42)

# Define the parameter grid
param_grid = {
    'n_estimators': [50, 100, 200],
    'max_depth': [None, 5, 10],
    'min_samples_split': [2, 5, 10]
}

# Create a GridSearchCV object with a specified cv value
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid, cv=5)

# Fit the GridSearchCV object
grid_search.fit(X, y)

# Access the n_splits_ attribute
n_splits = grid_search.n_splits_

# Print the number of splits used
print("Number of splits used in GridSearchCV:", n_splits)

Running the example gives an output like:

Number of splits used in GridSearchCV: 5

The key steps in this example are:

  1. Preparing a synthetic classification dataset using make_classification.
  2. Creating a RandomForestClassifier estimator and defining the parameter grid.
  3. Creating a GridSearchCV object with a specified cv value of 5.
  4. Fitting the GridSearchCV object on the synthetic dataset.
  5. Accessing the n_splits_ attribute from the fitted GridSearchCV object to verify the number of splits used in the grid search.
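n_splits_ also tells you how to read cv_results_ and how much work the search did: there is one splitK_test_score entry per split, and the number of cross-validation fits is the number of candidates multiplied by n_splits_. The sketch below checks both on a smaller, hypothetical grid with a DecisionTreeClassifier chosen for speed.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=300, random_state=42)

# 3 x 2 = 6 candidate parameter combinations (illustrative grid)
param_grid = {"max_depth": [2, 4, 8], "min_samples_split": [2, 10]}

search = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid, cv=4)
search.fit(X, y)

# One entry in cv_results_["params"] per candidate
n_candidates = len(search.cv_results_["params"])

# cv_results_ holds one "split{i}_test_score" array per cross-validation split
split_keys = [
    key for key in search.cv_results_
    if key.startswith("split") and key.endswith("_test_score")
]

# Cross-validation fits performed (excluding the final refit on all data)
total_cv_fits = n_candidates * search.n_splits_

print("n_splits_:", search.n_splits_)
print("Per-split score columns:", split_keys)
print("Cross-validation fits:", total_cv_fits)
```

Applied to the main example above, the 3 x 3 x 3 = 27 candidates with n_splits_ = 5 mean 135 cross-validation fits, plus one final refit of the best model because refit defaults to True.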


See Also