The class_weight parameter in scikit-learn's DecisionTreeClassifier allows you to handle imbalanced datasets, where the classes have very different numbers of samples. Imbalanced datasets can lead to biased models that perform poorly on the minority class, because the model may simply learn to always predict the majority class. The class_weight parameter lets you assign higher weights to minority-class samples during training, effectively increasing their importance and compensating for the imbalance.
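Under the hood, class weights act as per-sample weights in the tree's impurity calculations. Here is a minimal sketch of that equivalence (the 9x weight is an arbitrary illustration, not a recommendation):
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier
X, y = make_classification(n_samples=200, weights=[0.9, 0.1], random_state=0)
# Weighting classes via the class_weight parameter...
clf_a = DecisionTreeClassifier(class_weight={0: 1, 1: 9}, random_state=0).fit(X, y)
# ...has the same effect as passing explicit per-sample weights to fit()
sample_weight = np.where(y == 1, 9.0, 1.0)
clf_b = DecisionTreeClassifier(random_state=0).fit(X, y, sample_weight=sample_weight)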
The default value for class_weight is None, which gives equal weight to all classes. A common setting is 'balanced', which automatically adjusts the class weights inversely proportional to their frequencies in the input data, computed as n_samples / (n_classes * np.bincount(y)). You can also pass a dictionary mapping class labels to custom weights.
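A minimal sketch of the three forms the parameter accepts (the dictionary values here are illustrative):
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.class_weight import compute_class_weight
DecisionTreeClassifier(class_weight=None)          # default: all classes weighted equally
DecisionTreeClassifier(class_weight='balanced')    # weights inversely proportional to frequency
DecisionTreeClassifier(class_weight={0: 1, 1: 9})  # explicit per-class weights
# Inspect the weights that 'balanced' would compute for a 90/10 label array
y = np.array([0] * 90 + [1] * 10)
print(compute_class_weight('balanced', classes=np.unique(y), y=y))
# [0.55555556 5.        ]  i.e. n_samples / (n_classes * np.bincount(y))
The complete example below compares the default and 'balanced' settings on a synthetic imbalanced dataset: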
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
# Generate imbalanced synthetic dataset
X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.9, 0.1],
                           random_state=42)
# Split into train and test sets (stratify=y would preserve the class ratio)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Train default and balanced DecisionTreeClassifier models
dt_default = DecisionTreeClassifier(random_state=42)
dt_balanced = DecisionTreeClassifier(class_weight='balanced', random_state=42)
dt_default.fit(X_train, y_train)
dt_balanced.fit(X_train, y_train)
# Evaluate performance metrics
y_pred_default = dt_default.predict(X_test)
y_pred_balanced = dt_balanced.predict(X_test)
print("Default Model:")
print(f"Accuracy: {accuracy_score(y_test, y_pred_default):.3f}")
print(f"Precision: {precision_score(y_test, y_pred_default):.3f}")
print(f"Recall: {recall_score(y_test, y_pred_default):.3f}")
print(f"F1 Score: {f1_score(y_test, y_pred_default):.3f}\n")
print("Balanced Model:")
print(f"Accuracy: {accuracy_score(y_test, y_pred_balanced):.3f}")
print(f"Precision: {precision_score(y_test, y_pred_balanced):.3f}")
print(f"Recall: {recall_score(y_test, y_pred_balanced):.3f}")
print(f"F1 Score: {f1_score(y_test, y_pred_balanced):.3f}")
Running the example gives an output like:
Default Model:
Accuracy: 0.965
Precision: 0.933
Recall: 0.700
F1 Score: 0.800
Balanced Model:
Accuracy: 0.935
Precision: 0.667
Recall: 0.700
F1 Score: 0.683
Note that weighting is not guaranteed to improve the scores: in this run the balanced model gives up precision and overall accuracy without gaining recall, so always compare metrics on your own data. The key steps in this example are:
- Generate an imbalanced binary classification dataset using make_classification
- Split the data into train and test sets
- Train a default DecisionTreeClassifier and one with class_weight='balanced'
- Evaluate and compare performance metrics like accuracy, precision, recall, and F1 score
Some tips and heuristics for using class_weight:
- Use class_weight when you have imbalanced classes and care about performance on the minority class
- 'balanced' mode is often a good starting point; you can then adjust the weights to suit your specific problem, for example by searching over candidate weightings as sketched below
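One way to tune the weights (a sketch, not part of the example above; the candidate values are arbitrary) is to grid-search over class_weight settings with a minority-focused metric such as F1:
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier
# Candidate weightings, from no weighting up to a strong minority-class boost
param_grid = {'class_weight': [None, 'balanced',
                               {0: 1, 1: 5}, {0: 1, 1: 10}, {0: 1, 1: 20}]}
# Score with F1 so the search favors minority-class performance
search = GridSearchCV(DecisionTreeClassifier(random_state=42),
                      param_grid, scoring='f1', cv=5)
search.fit(X_train, y_train)
print(search.best_params_)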
Issues to consider:
- Severely imbalanced datasets may require more advanced techniques like oversampling in addition to class weighting (see the sketch after this list)
- The ‘balanced’ heuristic is based on class frequency, which may not be the optimal set of weights for your specific problem
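As a sketch of the oversampling route (this assumes the third-party imbalanced-learn package, which the example above does not use):
from imblearn.over_sampling import RandomOverSampler
from sklearn.tree import DecisionTreeClassifier
# Duplicate minority-class samples until the classes are balanced;
# SMOTE (also in imblearn) synthesizes new samples instead of duplicating
ros = RandomOverSampler(random_state=42)
X_res, y_res = ros.fit_resample(X_train, y_train)
dt = DecisionTreeClassifier(random_state=42).fit(X_res, y_res)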