
Configure DecisionTreeClassifier "class_weight" Parameter

The class_weight parameter in scikit-learn’s DecisionTreeClassifier allows you to handle imbalanced datasets where the classes have very different numbers of samples.

Imbalanced datasets can lead to biased models that perform poorly on the minority class, as the model may simply learn to always predict the majority class.

The class_weight parameter lets you assign higher weights to minority-class samples during training, increasing their influence on the tree's split decisions and compensating for the class imbalance.

The default value for class_weight is None, which gives equal weight to all classes.

A common value is 'balanced', which automatically sets each class's weight inversely proportional to its frequency in the training data, using the formula n_samples / (n_classes * np.bincount(y)).

You can also specify a dictionary to set custom weights for each class.
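Before the full example, here is a short sketch of both options. The dictionary weights below (a 9:1 ratio) are illustrative values, not a recommendation; the second part reproduces the 'balanced' formula explicitly with scikit-learn's compute_class_weight helper:

```python
import numpy as np

from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.class_weight import compute_class_weight

# Dictionary form: make each class-1 error count 9x as much as a class-0 error
# (the 9:1 ratio here is an illustrative choice)
dt_custom = DecisionTreeClassifier(class_weight={0: 1, 1: 9}, random_state=42)

# 'balanced' assigns each class the weight n_samples / (n_classes * count);
# compute_class_weight makes that calculation explicit:
y = np.array([0] * 90 + [1] * 10)  # a 90/10 imbalance
weights = compute_class_weight('balanced', classes=np.array([0, 1]), y=y)
print(weights)  # class 1 receives 9x the weight of class 0 (~0.556 vs 5.0)
```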

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Generate imbalanced synthetic dataset
X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.9, 0.1],
                           random_state=42)

# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train default and balanced DecisionTreeClassifier models
dt_default = DecisionTreeClassifier(random_state=42)
dt_balanced = DecisionTreeClassifier(class_weight='balanced', random_state=42)

dt_default.fit(X_train, y_train)
dt_balanced.fit(X_train, y_train)

# Evaluate performance metrics
y_pred_default = dt_default.predict(X_test)
y_pred_balanced = dt_balanced.predict(X_test)

print("Default Model:")
print(f"Accuracy: {accuracy_score(y_test, y_pred_default):.3f}")
print(f"Precision: {precision_score(y_test, y_pred_default):.3f}")
print(f"Recall: {recall_score(y_test, y_pred_default):.3f}")
print(f"F1 Score: {f1_score(y_test, y_pred_default):.3f}\n")

print("Balanced Model:")
print(f"Accuracy: {accuracy_score(y_test, y_pred_balanced):.3f}")
print(f"Precision: {precision_score(y_test, y_pred_balanced):.3f}")
print(f"Recall: {recall_score(y_test, y_pred_balanced):.3f}")
print(f"F1 Score: {f1_score(y_test, y_pred_balanced):.3f}")

Running the example gives an output like:

Default Model:
Accuracy: 0.965
Precision: 0.933
Recall: 0.700
F1 Score: 0.800

Balanced Model:
Accuracy: 0.935
Precision: 0.667
Recall: 0.700
F1 Score: 0.683

The key steps in this example are:

  1. Generate an imbalanced binary classification dataset using make_classification
  2. Split the data into train and test sets
  3. Train a default DecisionTreeClassifier and one with class_weight='balanced'
  4. Evaluate and compare performance metrics like accuracy, precision, recall, and F1 score
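As a further experiment (not part of the original example), you can sweep custom dictionary weights on the same dataset to see how the minority-class weight shifts the precision/recall balance; the weight values below are arbitrary illustrative choices:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import precision_score, recall_score

# Same imbalanced dataset and split as the main example
X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.9, 0.1],
                           random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
                                                    random_state=42)

# Increasing the minority-class weight generally trades precision for recall
for w in [1, 3, 9, 27]:
    dt = DecisionTreeClassifier(class_weight={0: 1, 1: w}, random_state=42)
    dt.fit(X_train, y_train)
    y_pred = dt.predict(X_test)
    print(f"weight {w:>2}: precision={precision_score(y_test, y_pred):.3f}, "
          f"recall={recall_score(y_test, y_pred):.3f}")
```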

Some tips and heuristics for using class_weight:

  - Start with class_weight='balanced' as a baseline before hand-tuning a dictionary of weights.
  - Judge models by precision, recall, and F1 on the minority class rather than accuracy, which can look high even when the minority class is predicted poorly.
  - Treat class weights as a hyperparameter: search candidate weightings with cross-validation like any other setting.

Issues to consider:

  - Upweighting the minority class typically raises recall at the cost of precision, as the example output above illustrates; choose the trade-off that matches the real cost of each error type.
  - If you also pass sample_weight to fit, those weights are multiplied with class_weight, so the two interact.
  - Reweighting does not add information about the minority class; collecting more data or resampling the training set are alternatives worth comparing.