The max_samples parameter in scikit-learn's RandomForestClassifier controls the number of samples drawn from X to train each base estimator (decision tree).

By default, max_samples is set to None, which means each tree's bootstrap sample contains as many samples as the full training set. Because every tree then sees largely overlapping data, the trees can be highly correlated and the ensemble more prone to overfitting.

Setting max_samples to a float less than 1.0 (a fraction of the training set) or to an integer (an absolute number of samples) introduces additional randomness into the training process. This helps to create more diverse trees and can improve generalization performance.

Common values for max_samples are floats between 0.5 and 0.8, or integers representing a subset of the total number of samples.
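As a minimal sketch of how the two forms are interpreted, the arithmetic below mirrors scikit-learn's rounding of a float fraction into a per-tree sample count (an internal implementation detail that may vary slightly across versions; the numbers are illustrative):

```python
n_total = 1000  # size of the training set (illustrative)

# A float is treated as a fraction of the training set;
# scikit-learn rounds n_total * max_samples to get the per-tree count
per_tree_float = max(round(n_total * 0.5), 1)
print(per_tree_float)  # 500 samples per tree

# An integer is used directly as the number of samples per tree
per_tree_int = 200
print(per_tree_int)  # 200 samples per tree
```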
```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

# Generate synthetic dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
                           n_redundant=0, random_state=42)

# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
                                                    random_state=42)

# Train with different max_samples values
max_samples_values = [None, 0.5, 0.8, 200]
accuracies = []
for max_samples in max_samples_values:
    rf = RandomForestClassifier(n_estimators=100, max_samples=max_samples,
                                random_state=42)
    rf.fit(X_train, y_train)
    y_pred = rf.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    accuracies.append(accuracy)
    print(f"max_samples={max_samples}, Accuracy: {accuracy:.3f}")
```
Running the example gives an output like:
```
max_samples=None, Accuracy: 0.920
max_samples=0.5, Accuracy: 0.925
max_samples=0.8, Accuracy: 0.925
max_samples=200, Accuracy: 0.905
```
The key steps in this example are:
- Generate a synthetic binary classification dataset
- Split the data into train and test sets
- Train RandomForestClassifier models with different max_samples values
- Evaluate the accuracy of each model on the test set
Some tips and heuristics for setting max_samples:
- Smaller values (e.g., 0.5) introduce more randomness and can help reduce overfitting
- Larger values (e.g., 0.8) retain more information from the original dataset
- The optimal value depends on the size and complexity of the dataset
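Since the optimal value is dataset-dependent, one practical approach is to treat max_samples as a hyperparameter and search over it with cross-validation. A minimal sketch using GridSearchCV (the grid values and dataset are illustrative; tune both for your own problem):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Illustrative synthetic dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
                           n_redundant=0, random_state=42)

# Candidate max_samples values to search over (illustrative)
param_grid = {"max_samples": [None, 0.5, 0.7, 0.9]}

grid = GridSearchCV(
    RandomForestClassifier(n_estimators=50, random_state=42),
    param_grid, cv=5, scoring="accuracy",
)
grid.fit(X, y)
print(grid.best_params_, f"{grid.best_score_:.3f}")
```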
Issues to consider:
- Setting max_samples too low may degrade performance by introducing too much bias
- There is a computational overhead to drawing samples for each tree during training