The RandomForestClassifier in scikit-learn is an ensemble learning algorithm that builds many decision trees for classification tasks. By default, each tree is trained on a bootstrap sample of the original training data, and the final prediction is obtained by aggregating the predictions of the individual trees.
The n_features_in_ attribute of a fitted RandomForestClassifier stores the number of features seen during fit(). It is an integer equal to the number of columns in the training data, and it only exists once the model has been fitted. This attribute is useful for checking that input data is consistent with the data the model was trained on.
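As a small illustration of the "fitted" requirement, the sketch below checks for the attribute before and after calling fit(). The dataset settings and the reduced n_estimators=10 are arbitrary choices made here to keep the example fast; they are not part of the original walkthrough.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Small synthetic dataset (sizes chosen only for illustration)
X, y = make_classification(n_samples=200, n_classes=3, n_features=10,
                           n_informative=8, n_redundant=2, random_state=42)

rf = RandomForestClassifier(n_estimators=10, random_state=42)

# The attribute is not set on an unfitted estimator
before_fit = hasattr(rf, "n_features_in_")

rf.fit(X, y)

# After fitting, the attribute holds the number of training columns
after_fit = hasattr(rf, "n_features_in_")

print(f"Has n_features_in_ before fit: {before_fit}")
print(f"Has n_features_in_ after fit: {after_fit}")
```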
Knowing the number of input features can be helpful in various scenarios. For example, when applying the trained model to new data, you can verify that the new data has the same number of features as the training data. Additionally, the number of features can give you an idea about the dimensionality of the problem and the potential risk of overfitting.
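One way to perform this kind of check is sketched below. The helper name has_expected_features is hypothetical (it is not a scikit-learn function), and the dataset settings are assumptions chosen for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Small synthetic dataset (sizes chosen only for illustration)
X, y = make_classification(n_samples=200, n_classes=3, n_features=10,
                           n_informative=8, n_redundant=2, random_state=42)
rf = RandomForestClassifier(n_estimators=10, random_state=42).fit(X, y)


def has_expected_features(model, X_new):
    """Hypothetical helper: True if X_new is 2D with the column count seen in fit."""
    X_new = np.asarray(X_new)
    return X_new.ndim == 2 and X_new.shape[1] == model.n_features_in_


# Five rows with all 10 columns match the fitted model
ok = has_expected_features(rf, X[:5])

# Five rows with only the first 7 columns do not
bad = has_expected_features(rf, X[:5, :7])

print(f"Matching input accepted: {ok}")
print(f"Truncated input accepted: {bad}")
```

A check like this lets you fail early with your own message before the data reaches predict().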
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
# Generate a synthetic multiclass classification dataset
X, y = make_classification(n_samples=1000, n_classes=3, n_features=10, n_informative=8,
                           n_redundant=2, random_state=42)
# Split the data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Initialize a RandomForestClassifier with default hyperparameters
rf = RandomForestClassifier(random_state=42)
# Fit the classifier on the training data
rf.fit(X_train, y_train)
# Access the n_features_in_ attribute and print its value
print(f"Number of features seen during fit: {rf.n_features_in_}")
Running the example gives an output like:
Number of features seen during fit: 10
The key steps in this example are:
- Generate a synthetic multiclass classification dataset using make_classification and split it into train and test sets.
- Initialize a RandomForestClassifier with default hyperparameters and fit it on the training data.
- Access the n_features_in_ attribute of the fitted classifier and print its value.
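As a follow-up to these steps, the sketch below shows what happens when a fitted forest receives input with a different number of columns: scikit-learn compares the input against the feature count recorded during fit and raises a ValueError. The variable names X_wrong and mismatch_detected are illustrative, and the dataset settings are assumptions chosen to keep the example small.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Small synthetic dataset (sizes chosen only for illustration)
X, y = make_classification(n_samples=200, n_classes=3, n_features=10,
                           n_informative=8, n_redundant=2, random_state=42)
rf = RandomForestClassifier(n_estimators=10, random_state=42).fit(X, y)

# Keep only the first 7 of the 10 columns to simulate mismatched input
X_wrong = X[:5, :7]

try:
    rf.predict(X_wrong)
    mismatch_detected = False
except ValueError as exc:
    # The fitted model expects rf.n_features_in_ columns
    mismatch_detected = True
    print(f"predict() rejected the input: {exc}")
```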