The GridSearchCV class in scikit-learn is commonly used for hyperparameter tuning and model selection. It allows you to define a grid of hyperparameter values and performs cross-validation to evaluate the model’s performance for each combination of those values.
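The n_splits_ attribute of a fitted GridSearchCV object records the number of cross-validation splits (folds) that were used to evaluate each parameter combination. The attribute does not set anything itself; the number of splits comes from the cv parameter. With the default cv=None, GridSearchCV uses 5-fold cross-validation (3-fold in scikit-learn versions before 0.22). Passing an integer sets the number of folds (stratified folds when the estimator is a classifier and the target is binary or multiclass). Passing a cross-validation splitter such as KFold, or an iterable of (train, test) index pairs, gives you full control over how the data is split.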
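As a minimal sketch of how the cv argument maps to n_splits_, the snippet below fits the same small search with four different cv values. The dataset, the DecisionTreeClassifier estimator, the max_depth grid, and the hand-built manual_splits list are illustrative choices, not part of the original example. The default of 5 folds assumes scikit-learn 0.22 or later.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.tree import DecisionTreeClassifier

# Small synthetic dataset and a tiny grid keep every search fast
X, y = make_classification(n_samples=200, random_state=0)
param_grid = {"max_depth": [2, 4]}

# Hypothetical hand-built iterable of (train_indices, test_indices) pairs: two splits
indices = np.arange(len(X))
manual_splits = [
    (indices[:100], indices[100:]),
    (indices[100:], indices[:100]),
]

cv_options = {
    "default (cv=None)": None,
    "integer (cv=3)": 3,
    "splitter (KFold)": KFold(n_splits=4, shuffle=True, random_state=0),
    "iterable of splits": manual_splits,
}

n_splits_by_option = {}
for label, cv in cv_options.items():
    search = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid, cv=cv)
    search.fit(X, y)
    # n_splits_ is only available after fit
    n_splits_by_option[label] = search.n_splits_
    print(f"{label}: n_splits_ = {search.n_splits_}")
```

In each case n_splits_ simply reflects what cv resolved to: 5 for the default, the integer itself, the splitter's n_splits, or the length of the explicit list.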
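Accessing the n_splits_ attribute after fitting lets you verify the number of splits that was actually used. This is most useful when cv is a splitter or an iterable rather than a plain integer, because the number of splits may then depend on the data (LeaveOneOut, for example, creates one split per sample). It also tells you how much work the search did, since every parameter combination is fitted once per split.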
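The sketch below illustrates a case where n_splits_ depends on the data: with LeaveOneOut on a 30-sample dataset, the search runs one split per sample. It also checks that count against the per-split score columns stored in cv_results_. The KNeighborsClassifier estimator, the n_neighbors grid, and the dataset size are assumptions chosen only to keep the example small and fast.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, LeaveOneOut
from sklearn.neighbors import KNeighborsClassifier

# Tiny dataset: LeaveOneOut will hold out each of the 30 samples once
X, y = make_classification(n_samples=30, random_state=0)

search = GridSearchCV(
    KNeighborsClassifier(),
    param_grid={"n_neighbors": [1, 3, 5]},
    cv=LeaveOneOut(),
)
search.fit(X, y)

# The number of splits was never written down explicitly; n_splits_ reveals it
print("n_splits_:", search.n_splits_)

# cv_results_ holds one "split<i>_test_score" entry per split
split_keys = [
    key for key in search.cv_results_
    if key.startswith("split") and key.endswith("_test_score")
]
n_candidates = len(search.cv_results_["params"])

print("per-split score columns:", len(split_keys))
print("cross-validation fits:", n_candidates * search.n_splits_)
```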
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
# Generate a synthetic classification dataset
X, y = make_classification(n_samples=1000, n_classes=2, random_state=42)
# Create a RandomForestClassifier estimator
rf = RandomForestClassifier(random_state=42)
# Define the parameter grid
param_grid = {
    'n_estimators': [50, 100, 200],
    'max_depth': [None, 5, 10],
    'min_samples_split': [2, 5, 10]
}
# Create a GridSearchCV object with a specified cv value
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid, cv=5)
# Fit the GridSearchCV object
grid_search.fit(X, y)
# Access the n_splits_ attribute
n_splits = grid_search.n_splits_
# Print the number of splits used
print("Number of splits used in GridSearchCV:", n_splits)
Running the example gives an output like:
Number of splits used in GridSearchCV: 5
The key steps in this example are:
- Preparing a synthetic classification dataset using make_classification.
- Creating a RandomForestClassifier estimator and defining the parameter grid.
- Creating a GridSearchCV object with a specified cv value of 5.
- Fitting the GridSearchCV object on the synthetic dataset.
- Accessing the n_splits_ attribute from the fitted GridSearchCV object to verify the number of splits used in the grid search.
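Because the number of splits is fixed by cv, you can also anticipate n_splits_ and the total number of fits before calling fit. The following is a sketch, assuming the same dataset and parameter grid as the example above. It uses check_cv to resolve cv=5 for a classifier (GridSearchCV resolves its cv argument in the same way internally) and ParameterGrid to count the candidate combinations, so no models are trained.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import ParameterGrid, check_cv

X, y = make_classification(n_samples=1000, n_classes=2, random_state=42)

# Same grid as in the RandomForestClassifier example
param_grid = {
    "n_estimators": [50, 100, 200],
    "max_depth": [None, 5, 10],
    "min_samples_split": [2, 5, 10],
}

# For a binary target with classifier=True, check_cv turns cv=5 into StratifiedKFold
cv = check_cv(5, y, classifier=True)
n_splits = cv.get_n_splits(X, y)

# ParameterGrid enumerates every combination in the grid
n_candidates = len(ParameterGrid(param_grid))

print("splitter:", type(cv).__name__)
print("splits:", n_splits)
print("candidates:", n_candidates)
print("cross-validation fits:", n_candidates * n_splits)
```

For this grid that is 27 candidates × 5 splits = 135 cross-validation fits, plus one final refit on the full dataset when refit=True (the default).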