
Get RandomForestClassifier "n_features_in_" Attribute

The RandomForestClassifier in scikit-learn is an ensemble learning algorithm that builds a forest of decision trees for classification. By default (bootstrap=True), each tree is trained on a bootstrap sample of the training data and considers a random subset of the features at each split. The forest's prediction is obtained by averaging the class-probability estimates of the individual trees and choosing the class with the highest average probability.
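The aggregation step can be checked directly. The sketch below is a minimal illustration, not part of the original example: it assumes a small synthetic dataset and an arbitrary n_estimators=25. It averages predict_proba over the fitted trees in rf.estimators_ and compares the result with the forest's own predict_proba.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Small synthetic 3-class dataset (sizes chosen arbitrarily for illustration)
X, y = make_classification(n_samples=200, n_classes=3, n_features=10,
                           n_informative=8, n_redundant=2, random_state=0)

rf = RandomForestClassifier(n_estimators=25, random_state=0).fit(X, y)

# Average the class-probability estimates of the individual fitted trees
tree_mean = np.mean([tree.predict_proba(X) for tree in rf.estimators_], axis=0)

# The forest's probabilities should match this average (up to float rounding)
print(np.allclose(rf.predict_proba(X), tree_mean))
```

predict() then returns, for each sample, the class with the largest averaged probability.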

The n_features_in_ attribute of a fitted RandomForestClassifier stores the number of features (columns of X) seen during the fit() method. It exists only after the model has been fitted, and it records the input dimensionality the model expects at prediction time, which makes it useful for checking that new input data is consistent with the training data.
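A quick way to see that the attribute is created by fit() is to check for it before and after fitting. This is a small sketch using an arbitrary 7-feature synthetic dataset:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Arbitrary synthetic data with 7 columns
X, y = make_classification(n_samples=100, n_features=7, n_informative=5,
                           random_state=0)
rf = RandomForestClassifier(n_estimators=10, random_state=0)

# Before fit() the attribute does not exist yet
print(hasattr(rf, "n_features_in_"))  # False

rf.fit(X, y)

# After fit() it equals the number of columns in the training matrix
print(rf.n_features_in_, X.shape[1])  # 7 7
```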

Knowing the number of input features is helpful in several situations. For example, before applying the trained model to new data, you can verify that the new data has the same number of columns as the training data; scikit-learn also performs this check itself and raises a ValueError in predict() when the counts differ. The number of features also tells you the dimensionality of the problem: many features relative to the number of training samples can increase the risk of overfitting.
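A consistency check against n_features_in_ might look like the sketch below. The helper check_feature_count is hypothetical (it is not a scikit-learn function); it simply compares the column count of the new data with the fitted model's n_features_in_. The exact wording of scikit-learn's own error message varies between versions, so only the custom message is fixed here.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=200, n_features=10, random_state=0)
rf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)


def check_feature_count(model, X_new):
    """Hypothetical helper: fail early with a clear message on a column mismatch."""
    n_new = np.asarray(X_new).shape[1]
    if n_new != model.n_features_in_:
        raise ValueError(f"expected {model.n_features_in_} features, got {n_new}")


X_new = X[:5, :8]  # deliberately drop two columns
messages = []

try:
    check_feature_count(rf, X_new)
except ValueError as err:
    messages.append(f"custom check: {err}")

# scikit-learn performs its own feature-count check in predict()
try:
    rf.predict(X_new)
except ValueError as err:
    messages.append(f"predict(): {err}")

print("\n".join(messages))
```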

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Generate a synthetic multiclass classification dataset
X, y = make_classification(n_samples=1000, n_classes=3, n_features=10, n_informative=8,
                           n_redundant=2, random_state=42)

# Split the data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Initialize a RandomForestClassifier with default hyperparameters
rf = RandomForestClassifier(random_state=42)

# Fit the classifier on the training data
rf.fit(X_train, y_train)

# Access the n_features_in_ attribute and print its value
print(f"Number of features seen during fit: {rf.n_features_in_}")

Running the example gives an output like:

Number of features seen during fit: 10

The key steps in this example are:

  1. Generate a synthetic multiclass classification dataset using make_classification and split it into train and test sets.
  2. Initialize a RandomForestClassifier with default hyperparameters and fit it on the training data.
  3. Access the n_features_in_ attribute of the fitted classifier and print its value.
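Because n_features_in_ records what the estimator itself received during fit(), it can differ from the width of the raw data when the forest sits inside a Pipeline after a feature-selection step. The sketch below assumes a SelectKBest step with an arbitrary k=5; Pipeline.n_features_in_ reports the columns passed to the first step, while the forest's attribute reports the columns it was actually trained on.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.pipeline import make_pipeline

X, y = make_classification(n_samples=300, n_features=10, n_informative=5,
                           random_state=0)

# Keep only the 5 highest-scoring features before the forest (k chosen arbitrarily)
pipe = make_pipeline(SelectKBest(f_classif, k=5),
                     RandomForestClassifier(n_estimators=20, random_state=0))
pipe.fit(X, y)

print(pipe.n_features_in_)      # 10: columns passed to the pipeline
print(pipe[-1].n_features_in_)  # 5: columns the forest actually saw
```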


See Also