The GridSearchCV class in scikit-learn is commonly used for hyperparameter tuning and model selection. It allows you to define a grid of hyperparameter values and performs cross-validation to evaluate the model’s performance for each combination of those values.
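The n_splits_ attribute of a fitted GridSearchCV object reports the number of cross-validation splits (folds or iterations) used during the search. It does not control the splitting itself; that is the job of the cv parameter. When cv is left at its default of None, GridSearchCV uses 5-fold cross-validation (in scikit-learn 0.22 and later). You can pass an integer or a cross-validation splitter object to cv to control the number of splits, and n_splits_ reflects the result once fit has been called.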
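As a minimal sketch of how n_splits_ follows the cv parameter, the snippet below fits three small searches: one with the default cv, one with an integer, and one with a splitter object. The DecisionTreeClassifier, the tiny grid, and the helper fitted_n_splits are assumptions chosen only to keep the run fast; they are not part of the original example.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.tree import DecisionTreeClassifier

# Small synthetic dataset and grid so each search finishes quickly
X, y = make_classification(n_samples=200, n_classes=2, random_state=42)
param_grid = {"max_depth": [2, 4]}


def fitted_n_splits(cv):
    """Hypothetical helper: fit a small grid search and return n_splits_."""
    search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=cv)
    search.fit(X, y)
    return search.n_splits_


# cv=None falls back to 5-fold cross-validation (scikit-learn 0.22 and later)
default_splits = fitted_n_splits(None)
# An integer sets the number of folds directly
int_splits = fitted_n_splits(3)
# A splitter object carries its own number of splits
splitter_splits = fitted_n_splits(StratifiedKFold(n_splits=4, shuffle=True, random_state=0))

print(default_splits, int_splits, splitter_splits)
```

With a recent scikit-learn release, the three values should be 5, 3, and 4, matching the cv setting in each case.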
Accessing the n_splits_ attribute lets you verify the number of splits actually used in the grid search. Because it is set during fitting, it is only available after fit has been called. This is useful for understanding the cross-validation strategy and confirming that the desired number of splits is being used.
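One way to check this, sketched here under the assumption of a small DecisionTreeClassifier search (not the estimator used in this article), is to confirm that n_splits_ appears only after fitting and that it matches the number of per-split score columns in cv_results_.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=200, n_classes=2, random_state=42)
search = GridSearchCV(DecisionTreeClassifier(random_state=42), {"max_depth": [2, 4]}, cv=3)

# n_splits_ is a fitted attribute, so it does not exist before fit()
available_before_fit = hasattr(search, "n_splits_")
search.fit(X, y)
available_after_fit = hasattr(search, "n_splits_")

# cv_results_ holds one "split<i>_test_score" entry per cross-validation split
split_score_keys = sorted(
    key for key in search.cv_results_
    if key.startswith("split") and key.endswith("_test_score")
)

print(available_before_fit, available_after_fit)
print(split_score_keys, search.n_splits_)
```

The number of split score keys should equal n_splits_, which gives a second, independent confirmation of the cross-validation setup.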
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
# Generate a synthetic classification dataset
X, y = make_classification(n_samples=1000, n_classes=2, random_state=42)
# Create a RandomForestClassifier estimator
rf = RandomForestClassifier(random_state=42)
# Define the parameter grid
param_grid = {
    'n_estimators': [50, 100, 200],
    'max_depth': [None, 5, 10],
    'min_samples_split': [2, 5, 10]
}
# Create a GridSearchCV object with a specified cv value
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid, cv=5)
# Fit the GridSearchCV object
grid_search.fit(X, y)
# Access the n_splits_ attribute
n_splits = grid_search.n_splits_
# Print the number of splits used
print("Number of splits used in GridSearchCV:", n_splits)
Running the example gives an output like:
Number of splits used in GridSearchCV: 5
The key steps in this example are:
- Preparing a synthetic classification dataset using make_classification.
- Creating a RandomForestClassifier estimator and defining the parameter grid.
- Creating a GridSearchCV object with a specified cv value of 5.
- Fitting the GridSearchCV object on the synthetic dataset.
- Accessing the n_splits_ attribute from the fitted GridSearchCV object to verify the number of splits used in the grid search.
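The number of splits also determines how much work the search does: each parameter combination is fitted once per split. For the grid in the example above, that is 3 x 3 x 3 = 27 combinations times 5 splits, or 135 fits, plus one final refit on the full dataset when refit is left at its default of True. The sketch below shows the same arithmetic on a smaller, assumed grid with a DecisionTreeClassifier so that it runs quickly.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, ParameterGrid
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=200, n_classes=2, random_state=42)

# Illustrative grid: 3 x 2 = 6 parameter combinations
param_grid = {"max_depth": [2, 4, None], "min_samples_split": [2, 5]}

search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5)
search.fit(X, y)

# ParameterGrid enumerates every combination in the grid
n_candidates = len(ParameterGrid(param_grid))

# Cross-validation fits = candidates x splits (the final refit is not counted here)
total_cv_fits = n_candidates * search.n_splits_

print("Candidates:", n_candidates)
print("Splits:", search.n_splits_)
print("Cross-validation fits:", total_cv_fits)
```

Here the search evaluates 6 candidates over 5 splits, for 30 cross-validation fits. Reading n_splits_ alongside the grid size is a quick way to estimate the cost of a search before scaling up the grid.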