
Configure ExtraTreesClassifier "max_features" Parameter

The max_features parameter in scikit-learn’s ExtraTreesClassifier controls the number of features to consider when looking for the best split.

Extra Trees Classifier is an ensemble method that builds multiple randomized decision trees and combines their predictions. The max_features parameter introduces additional randomness in the tree-building process by limiting the number of features considered at each split.

Adjusting max_features affects the diversity of the trees in the ensemble. Lower values increase randomness and can help prevent overfitting, while higher values allow the algorithm to consider more features, potentially capturing more complex relationships.

The default value for max_features is 'sqrt' (in scikit-learn 1.1 and later), which uses the square root of the total number of features. Other accepted values are 'log2', an int (an absolute number of features), a float fraction of the total features (e.g., 0.5), or None (consider all features at every split).
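To make these options concrete, the following sketch approximates how each max_features setting translates into a per-split feature count for a 20-feature dataset (this mirrors the rule described in the scikit-learn docs; the exact internal computation may differ slightly between versions):

```python
import math

# For a dataset with 20 features, compute roughly how many features each
# max_features setting allows the algorithm to consider at a split.
n_features = 20

settings = {
    "sqrt": max(1, int(math.sqrt(n_features))),  # sqrt(20) ~ 4.47 -> 4
    "log2": max(1, int(math.log2(n_features))),  # log2(20) ~ 4.32 -> 4
    0.5: max(1, int(0.5 * n_features)),          # half of the features -> 10
    None: n_features,                            # all 20 features
}

for setting, count in settings.items():
    print(f"max_features={setting!r}: {count} features per split")
```

Note how 'sqrt' and 'log2' give similar counts at this dimensionality; they diverge more as the number of features grows.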

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.metrics import accuracy_score

# Generate synthetic dataset
X, y = make_classification(n_samples=1000, n_features=20, n_informative=10,
                           n_redundant=5, random_state=42)

# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train with different max_features values
max_features_values = ['sqrt', 'log2', 0.5, 0.8, None]
accuracies = []

for mf in max_features_values:
    etc = ExtraTreesClassifier(n_estimators=100, max_features=mf, random_state=42)
    etc.fit(X_train, y_train)
    y_pred = etc.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    accuracies.append(accuracy)
    print(f"max_features={mf}, Accuracy: {accuracy:.3f}")

Running the example gives an output like:

max_features=sqrt, Accuracy: 0.925
max_features=log2, Accuracy: 0.925
max_features=0.5, Accuracy: 0.930
max_features=0.8, Accuracy: 0.925
max_features=None, Accuracy: 0.930

The key steps in this example are:

  1. Generate a synthetic classification dataset with informative and redundant features
  2. Split the data into train and test sets
  3. Train ExtraTreesClassifier models with different max_features values
  4. Evaluate the accuracy of each model on the test set
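Rather than comparing settings on a single train/test split, max_features can also be tuned with cross-validation. A minimal sketch using GridSearchCV on the same synthetic dataset:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.model_selection import GridSearchCV

# Same synthetic dataset as above
X, y = make_classification(n_samples=1000, n_features=20, n_informative=10,
                           n_redundant=5, random_state=42)

# Search over candidate max_features values with 5-fold cross-validation
param_grid = {"max_features": ["sqrt", "log2", 0.5, 0.8, None]}
grid = GridSearchCV(ExtraTreesClassifier(n_estimators=100, random_state=42),
                    param_grid, cv=5, n_jobs=-1)
grid.fit(X, y)

print(f"Best max_features: {grid.best_params_['max_features']}")
print(f"Best CV accuracy: {grid.best_score_:.3f}")
```

Cross-validated scores are a more reliable basis for choosing between settings whose single-split accuracies differ by only a few tenths of a percent, as in the output above.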

Some tips and heuristics for setting max_features:

  1. Start with the default 'sqrt'; it is a strong baseline for most classification problems
  2. Lower values increase tree diversity and speed up training, which can help when individual trees overfit
  3. Higher values (or None) let each split see more features, which can help when only a few features carry signal
  4. Tune max_features with cross-validation rather than relying on a single train/test split

Issues to consider:

  1. With max_features=None, the trees become more correlated and the ensemble loses some of the randomness that gives Extra Trees its regularizing effect
  2. The best value interacts with other parameters such as n_estimators and max_depth, so tune them jointly
  3. On high-dimensional data, large max_features values substantially increase training time

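The speed trade-off can be seen directly by timing fits with the smallest and largest feature subsets; a rough sketch (absolute timings vary by machine):

```python
import time

from sklearn.datasets import make_classification
from sklearn.ensemble import ExtraTreesClassifier

X, y = make_classification(n_samples=1000, n_features=20, n_informative=10,
                           n_redundant=5, random_state=42)

# Compare wall-clock fit time for a small and a full feature subset
timings = {}
for mf in ["sqrt", None]:
    start = time.perf_counter()
    ExtraTreesClassifier(n_estimators=100, max_features=mf,
                         random_state=42).fit(X, y)
    timings[mf] = time.perf_counter() - start
    print(f"max_features={mf!r}: fit in {timings[mf]:.2f}s")
```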
