
Configure RandomForestClassifier "max_samples" Parameter

The max_samples parameter in scikit-learn’s RandomForestClassifier controls the number of samples drawn from X to train each base estimator (decision tree). It only takes effect when bootstrap=True, which is the default.
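To make the parameter's three forms concrete, here is a small pure-Python sketch of how many rows each tree draws. The helper name n_samples_per_tree is hypothetical and not part of scikit-learn. The float rule max(round(n_samples * max_samples), 1) follows the current RandomForestClassifier documentation, while the range checks are this sketch's own and may differ from the library's validation.

```python
from numbers import Integral, Real


def n_samples_per_tree(n_samples, max_samples=None):
    """Hypothetical helper: rows drawn (with replacement) for each tree when bootstrap=True."""
    if max_samples is None:
        # Default: draw as many rows as the training set contains
        return n_samples
    if isinstance(max_samples, Integral):
        # Absolute count (range check is this sketch's own assumption)
        if not 1 <= max_samples <= n_samples:
            raise ValueError("integer max_samples must be between 1 and n_samples")
        return int(max_samples)
    if isinstance(max_samples, Real):
        # Fraction of the training set, rounded, but never fewer than one row
        if not 0.0 < max_samples <= 1.0:
            raise ValueError("float max_samples must be in (0.0, 1.0]")
        return max(round(n_samples * max_samples), 1)
    raise TypeError("max_samples must be None, an int or a float")


# With the 800-row training split used in the example that follows
for value in [None, 0.5, 0.8, 200]:
    print(f"max_samples={value}: {n_samples_per_tree(800, value)} rows per tree")
```

Note that the integer form is independent of the dataset size, while the float form scales with it.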

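By default, max_samples is None, which means each tree is trained on a bootstrap sample the same size as the training set: X.shape[0] rows drawn with replacement. Because every tree's sample is as large as the whole training set, the trees can end up fairly correlated with one another, which limits how much averaging them reduces overfitting.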
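With bootstrap=True, the default max_samples=None draws X.shape[0] rows with replacement for each tree, so a tree actually sees only about 63% of the distinct rows (1 - 1/e), and shrinking max_samples lowers that share further. The NumPy simulation below is an illustration of sampling with replacement, not scikit-learn's own code; the training-set size of 10,000 is an arbitrary choice.

```python
import numpy as np


def unique_fraction(n_rows, n_draws, rng):
    """Fraction of distinct rows that appear at least once in one sample drawn with replacement."""
    idx = rng.integers(0, n_rows, size=n_draws)
    return np.unique(idx).size / n_rows


rng = np.random.default_rng(0)
n = 10_000  # hypothetical training-set size

full = unique_fraction(n, n, rng)       # max_samples=None: n draws
half = unique_fraction(n, n // 2, rng)  # max_samples=0.5: n/2 draws
print(f"None: {full:.3f} of rows seen (theory: 1 - e^-1, about 0.632)")
print(f"0.5:  {half:.3f} of rows seen (theory: 1 - e^-0.5, about 0.393)")
```

The theoretical values come from the probability that a given row is drawn at least once in m draws from n rows, 1 - (1 - 1/n)^m, which approaches 1 - e^(-m/n) for large n.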

Setting max_samples to a float in (0.0, 1.0] (a fraction of the training set) or to an integer (an absolute number of rows) shrinks each tree's bootstrap sample. This adds randomness to training and yields more diverse trees, which can improve generalization, at the cost of each tree seeing less data.
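The claim that smaller samples make trees more diverse can be checked directly. One rough proxy for diversity, chosen here for illustration, is the mean pairwise disagreement: the average fraction of test rows on which two trees from the same forest predict different classes. The sketch below uses the same synthetic dataset and split as the main example, reads the fitted trees from the forest's estimators_ attribute, and uses 50 trees per forest to keep the pairwise comparison cheap.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
                           n_redundant=0, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


def mean_pairwise_disagreement(forest, X):
    """Average fraction of rows on which two different trees predict different labels."""
    preds = np.array([tree.predict(X) for tree in forest.estimators_])
    # Disagreement matrix of shape (n_trees, n_trees)
    diff = (preds[:, None, :] != preds[None, :, :]).mean(axis=2)
    upper = np.triu_indices(len(preds), k=1)  # count each unordered pair once
    return diff[upper].mean()


disagreement = {}
for max_samples in [None, 0.5, 0.1]:
    rf = RandomForestClassifier(n_estimators=50, max_samples=max_samples, random_state=42)
    rf.fit(X_train, y_train)
    disagreement[max_samples] = mean_pairwise_disagreement(rf, X_test)
    print(f"max_samples={max_samples}: mean pairwise disagreement {disagreement[max_samples]:.3f}")
```

Disagreement should rise as max_samples falls; whether that extra diversity turns into better test accuracy still has to be measured, as the main example does.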

Typical choices for max_samples are floats between 0.5 and 0.8, or integers representing a subset of the total number of samples; the best value depends on the dataset.
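Since the best value depends on the data, one way to choose among such candidates is a cross-validated search. The sketch below runs GridSearchCV on the same synthetic dataset as the main example; the candidate list and cv=5 are illustrative choices. An integer such as 200 is an absolute count, so it corresponds to a different fraction of rows whenever the training-set size changes (here, 200 of the 800 rows in each training fold).

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
                           n_redundant=0, random_state=42)

# Candidate values: the default, two fractions and one absolute count
param_grid = {"max_samples": [None, 0.5, 0.8, 200]}
search = GridSearchCV(
    RandomForestClassifier(n_estimators=100, random_state=42),
    param_grid,
    cv=5,
    scoring="accuracy",
)
search.fit(X, y)

print("best max_samples:", search.best_params_["max_samples"])
for params, score in zip(search.cv_results_["params"], search.cv_results_["mean_test_score"]):
    print(f"max_samples={params['max_samples']}: mean CV accuracy {score:.3f}")
```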

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

# Generate synthetic dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
                           n_redundant=0, random_state=42)

# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train with different max_samples values
max_samples_values = [None, 0.5, 0.8, 200]
accuracies = []

for max_samples in max_samples_values:
    rf = RandomForestClassifier(n_estimators=100, max_samples=max_samples, random_state=42)
    rf.fit(X_train, y_train)
    y_pred = rf.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    accuracies.append(accuracy)
    print(f"max_samples={max_samples}, Accuracy: {accuracy:.3f}")
```

Running the example gives an output like:

```
max_samples=None, Accuracy: 0.920
max_samples=0.5, Accuracy: 0.925
max_samples=0.8, Accuracy: 0.925
max_samples=200, Accuracy: 0.905
```
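The differences in this output are small: with a 200-row test set, one changed prediction moves accuracy by 0.005. One way to check whether such gaps are meaningful is to repeat the comparison over several random seeds and look at the spread. The sketch below re-splits the data and re-seeds the forest on each repeat; the choice of five seeds is arbitrary.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
                           n_redundant=0, random_state=42)

seeds = range(5)  # arbitrary number of repeats
results = {}
for max_samples in [None, 0.5, 0.8, 200]:
    scores = []
    for seed in seeds:
        # New train/test split and new forest seed on every repeat
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=seed)
        rf = RandomForestClassifier(n_estimators=100, max_samples=max_samples,
                                    random_state=seed)
        rf.fit(X_train, y_train)
        scores.append(accuracy_score(y_test, rf.predict(X_test)))
    results[max_samples] = (np.mean(scores), np.std(scores))
    print(f"max_samples={max_samples}: {np.mean(scores):.3f} ± {np.std(scores):.3f}")
```

If the gaps between settings are smaller than the standard deviations, a single split is not enough to rank them.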

The key steps in this example are:

  1. Generate a synthetic binary classification dataset
  2. Split the data into train and test sets
  3. Train RandomForestClassifier models with different max_samples values
  4. Evaluate the accuracy of each model on the test set

Some tips and heuristics for setting max_samples:

Issues to consider:


