The class_weight parameter in scikit-learn's RandomForestClassifier is used to handle imbalanced datasets, where one class has significantly fewer samples than the other.
Random Forest is an ensemble learning method that combines multiple decision trees to improve classification performance. However, when trained on imbalanced data, it can be biased towards the majority class.
The class_weight parameter allows assigning higher weights to the minority class during training, effectively increasing its importance. This can help the model better learn the patterns of the underrepresented class.
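Internally, a dictionary class_weight is expanded into per-sample weights before the trees are grown. As a rough sketch (not part of the original example), training with class_weight={0: 1, 1: 10} should behave much like passing the equivalent per-sample weights through the sample_weight argument of fit:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=500, weights=[0.9, 0.1], random_state=0)

# Dictionary class weights are expanded into per-sample weights internally
clf_cw = RandomForestClassifier(class_weight={0: 1, 1: 10}, random_state=0)
clf_cw.fit(X, y)

# Effectively the same as weighting each minority-class sample by 10 directly
clf_sw = RandomForestClassifier(random_state=0)
clf_sw.fit(X, y, sample_weight=np.where(y == 1, 10.0, 1.0))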
By default, class_weight is set to None, meaning all classes have equal weight. Common values are 'balanced', which automatically adjusts weights inversely proportional to class frequencies; 'balanced_subsample', which recomputes those balanced weights on each tree's bootstrap sample; or a dictionary specifying the weight for each class.
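For intuition, the 'balanced' weights follow the formula n_samples / (n_classes * np.bincount(y)), and scikit-learn's compute_class_weight helper reproduces them. A quick sketch:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([0] * 90 + [1] * 10)  # 90/10 imbalance

# 'balanced' weight per class: n_samples / (n_classes * class_count)
weights = compute_class_weight(class_weight='balanced', classes=np.array([0, 1]), y=y)
print(weights)  # [0.5555... 5.0] -> the minority class is upweighted 9x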
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score
# Generate imbalanced synthetic dataset
X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.9, 0.1],
                           n_features=10, n_informative=5, n_redundant=0,
                           random_state=42)
# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Define class weight options
class_weights = [None, 'balanced', {0: 1, 1: 10}]
# Train and evaluate a model for each class_weight setting
for weight in class_weights:
    rf = RandomForestClassifier(n_estimators=100, class_weight=weight, random_state=42)
    rf.fit(X_train, y_train)
    y_pred = rf.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred)
    print(f"class_weight={weight}")
    print(f"Accuracy: {accuracy:.3f}, F1 Score: {f1:.3f}\n")
The output will be similar to:
class_weight=None
Accuracy: 0.960, F1 Score: 0.800
class_weight=balanced
Accuracy: 0.935, F1 Score: 0.629
class_weight={0: 1, 1: 10}
Accuracy: 0.935, F1 Score: 0.629
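Accuracy barely moves across settings while F1 shifts, because accuracy is dominated by the majority class. To make the per-class behavior visible, one option (continuing with the variables from the example above) is to inspect the confusion matrix and per-class metrics:

from sklearn.metrics import classification_report, confusion_matrix

# Re-fit one configuration and inspect per-class precision and recall
rf = RandomForestClassifier(n_estimators=100, class_weight='balanced', random_state=42)
rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)

print(confusion_matrix(y_test, y_pred))  # rows: true class, columns: predicted class
print(classification_report(y_test, y_pred, digits=3))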
The key steps in this example are:
- Generate an imbalanced binary classification dataset
- Split the data into train and test sets
- Train RandomForestClassifier models with different class_weight values
- Evaluate the accuracy and F1 score of each model on the test set
Some tips and heuristics for setting class_weight:
- Use 'balanced' as a quick way to adjust weights inversely proportional to class frequencies
- For more control, manually define the weights in a dictionary
- Higher weight for the minority class can improve its recall, but may reduce overall accuracy (see the sketch after this list)
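To see that trade-off directly, one illustrative option (continuing with the train/test split from the example above) is to sweep the minority-class weight and track recall, precision, and accuracy:

from sklearn.metrics import precision_score, recall_score

for w in [1, 5, 10, 50]:
    rf = RandomForestClassifier(n_estimators=100, class_weight={0: 1, 1: w}, random_state=42)
    rf.fit(X_train, y_train)
    y_pred = rf.predict(X_test)
    print(f"weight={w:3d}  "
          f"recall={recall_score(y_test, y_pred):.3f}  "
          f"precision={precision_score(y_test, y_pred):.3f}  "
          f"accuracy={accuracy_score(y_test, y_pred):.3f}")

In general, recall for class 1 tends to rise as its weight grows, while precision and overall accuracy often drop.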
Issues to consider:
- Extremely imbalanced datasets may require additional techniques like oversampling or undersampling (a minimal sketch follows this list)
- Careful validation is needed to avoid overfitting to the minority class
- There is often a trade-off between precision and recall for each class when adjusting weights
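As one possible resampling approach, here is a minimal sketch that assumes the third-party imbalanced-learn package is installed (it is not part of scikit-learn itself) and reuses the train/test split from the example above:

# pip install imbalanced-learn
from imblearn.over_sampling import RandomOverSampler

# Duplicate minority-class samples until the classes are balanced,
# then train an unweighted forest on the resampled training data
ros = RandomOverSampler(random_state=42)
X_res, y_res = ros.fit_resample(X_train, y_train)

rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(X_res, y_res)
print(f"F1 after oversampling: {f1_score(y_test, rf.predict(X_test)):.3f}")

Note that resampling is applied to the training split only; the test set keeps its original class distribution so the evaluation remains honest.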