
Scikit-Learn ExtraTreesClassifier Model

ExtraTreesClassifier (Extremely Randomized Trees) is an ensemble method that constructs multiple decision trees and aggregates their predictions. Unlike Random Forest, which searches for the best split among a random subset of features, Extra Trees also draws the split thresholds at random, and by default trains each tree on the full dataset rather than bootstrap samples. This extra randomness tends to reduce variance, which can improve generalization.

The key hyperparameters of ExtraTreesClassifier include n_estimators (the number of trees in the forest), max_features (the number of features considered when choosing a split), and min_samples_split (the minimum number of samples required to split an internal node).

The algorithm is appropriate for binary and multi-class classification problems.
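Before the full worked example below, here is a short sketch of how those hyperparameters might be set explicitly. The specific values are illustrative assumptions, not recommendations; good settings depend on the dataset and are usually found via cross-validation:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.model_selection import cross_val_score

# small synthetic dataset for illustration
X, y = make_classification(n_samples=200, n_features=10, random_state=42)

# illustrative hyperparameter values
model = ExtraTreesClassifier(
    n_estimators=200,      # more trees: more stable estimates, slower training
    max_features='sqrt',   # features considered at each split
    min_samples_split=4,   # larger values regularize the trees
    random_state=42,
)

# estimate generalization accuracy with 5-fold cross-validation
scores = cross_val_score(model, X, y, cv=5)
print('Mean CV accuracy: %.3f' % scores.mean())
```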

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.metrics import accuracy_score

# generate synthetic classification dataset
X, y = make_classification(n_samples=100, n_features=10, n_classes=2, random_state=42)

# split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# create the ExtraTreesClassifier model
model = ExtraTreesClassifier(n_estimators=100, random_state=42)

# fit the model
model.fit(X_train, y_train)

# evaluate the model
yhat = model.predict(X_test)
acc = accuracy_score(y_test, yhat)
print('Accuracy: %.3f' % acc)

# make a prediction
row = [[0.5, -1.2, 0.3, 0.8, -0.9, 1.0, 0.5, -0.4, 0.7, -0.5]]
yhat = model.predict(row)
print('Predicted: %d' % yhat[0])

Running the example gives an output like:

Accuracy: 1.000
Predicted: 0

The steps are as follows:

  1. First, a synthetic binary classification dataset is generated using the make_classification() function. This creates a dataset with a specified number of samples (n_samples), features (n_features), and a fixed random seed (random_state) for reproducibility. The dataset is split into training and test sets using train_test_split().

  2. Next, an ExtraTreesClassifier model is instantiated with 100 trees (n_estimators) and a fixed random seed (random_state). The model is then fit on the training data using the fit() method.

  3. The performance of the model is evaluated by comparing the predictions (yhat) to the actual values (y_test) using the accuracy score metric.

  4. A single prediction can be made by passing a new data sample to the predict() method.
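Beyond hard class labels, the fitted model can also report class-membership probabilities via predict_proba(), which averages the probability estimates of the individual trees. A minimal sketch, reusing the same synthetic dataset as above:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import ExtraTreesClassifier

# same synthetic binary dataset as in the main example
X, y = make_classification(n_samples=100, n_features=10, n_classes=2, random_state=42)
model = ExtraTreesClassifier(n_estimators=100, random_state=42).fit(X, y)

# probability estimates averaged across the trees; shape (1, n_classes)
proba = model.predict_proba(X[:1])
print(proba)
```

Each row of the result sums to 1, and the column with the highest probability corresponds to the label returned by predict().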

This example demonstrates how to quickly set up and use an ExtraTreesClassifier model for binary classification tasks, showcasing the simplicity and effectiveness of this algorithm in scikit-learn. The model can be fit directly on the training data without the need for scaling or normalization. Once fit, the model can be used to make predictions on new data, enabling its use in real-world binary classification problems.
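Since the algorithm also handles multi-class problems, the same workflow carries over with no changes beyond the dataset. A hedged sketch with a three-class synthetic dataset (the n_informative value is an assumption needed so make_classification can place three classes):

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.metrics import accuracy_score

# three-class synthetic dataset; n_informative=5 gives make_classification
# enough informative features to separate the classes
X, y = make_classification(n_samples=300, n_features=10, n_informative=5,
                           n_classes=3, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
                                                    random_state=42)

# identical model setup to the binary case
model = ExtraTreesClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# evaluate on the held-out test set
acc = accuracy_score(y_test, model.predict(X_test))
print('Accuracy: %.3f' % acc)
```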



See Also