
Scikit-Learn confusion_matrix() Metric

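The confusion matrix is a useful tool for evaluating the performance of classification models. It summarizes the model's predictions against the actual labels in a table. For binary classification, this gives the counts of true negatives, false positives, false negatives, and true positives. For multi-class problems, it gives a count for every pair of actual and predicted classes, and these four quantities can be derived for each class from it.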
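As a minimal illustration of the binary case, the sketch below uses a handful of made-up labels, not the data from the example later on this page. For binary labels, the 2x2 matrix can be unpacked with ravel() into the four counts in the order true negatives, false positives, false negatives, true positives.

```python
from sklearn.metrics import confusion_matrix

# Small hand-made binary example (labels are illustrative only)
y_true = [0, 1, 1, 0, 1, 0, 1, 1]
y_pred = [0, 1, 0, 0, 1, 1, 1, 0]

cm = confusion_matrix(y_true, y_pred)
print(cm)

# For binary labels [0, 1], ravel() flattens the 2x2 matrix row by row,
# which yields the counts in the order tn, fp, fn, tp
tn, fp, fn, tp = cm.ravel()
print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")  # TN=2, FP=1, FN=2, TP=3
```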

The confusion_matrix() function in scikit-learn takes the true labels and predicted labels as input and returns a square matrix where the rows represent the actual classes and the columns represent the predicted classes. The diagonal elements of the matrix indicate the number of correct predictions for each class, while the off-diagonal elements show the misclassifications.
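The sketch below, using hypothetical string labels chosen only for illustration, shows how to read the orientation. The optional labels argument fixes the order of rows and columns; without it, scikit-learn uses the labels that appear in y_true or y_pred, in sorted order.

```python
from sklearn.metrics import confusion_matrix

# Illustrative string labels for a 3-class problem
y_true = ["cat", "dog", "dog", "bird", "cat", "bird"]
y_pred = ["cat", "dog", "cat", "bird", "dog", "cat"]

# labels= fixes the row/column order explicitly
labels = ["bird", "cat", "dog"]
cm = confusion_matrix(y_true, y_pred, labels=labels)
print(cm)

# cm[i, j] counts samples whose actual label is labels[i]
# and whose predicted label is labels[j]
for i, actual in enumerate(labels):
    for j, predicted in enumerate(labels):
        if i != j and cm[i, j] > 0:
            print(f"{cm[i, j]} '{actual}' sample(s) predicted as '{predicted}'")

# The diagonal holds the correct predictions for each class
print("Correct predictions per class:", dict(zip(labels, cm.diagonal().tolist())))
```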

Confusion matrices are applicable to both binary and multi-class classification problems. They offer a detailed breakdown of the model's performance, allowing us to identify which classes the model performs well on and where it struggles. However, the confusion matrix does not condense performance into a single number like accuracy or F1-score, although such metrics can be computed from its entries.
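Although the matrix itself is not a single score, common metrics can be derived from it. The sketch below, using small made-up labels, computes accuracy from the diagonal and per-class recall and precision from the row and column sums, then cross-checks the results against accuracy_score(), recall_score(), and precision_score() with average=None.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score

# Made-up labels for a 3-class problem
y_true = [0, 0, 0, 1, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 1, 1, 1, 2, 2, 2, 0, 2]

cm = confusion_matrix(y_true, y_pred)

# Accuracy: correct predictions (the diagonal) divided by all predictions
accuracy = np.trace(cm) / cm.sum()
# Per-class recall: diagonal divided by row sums (actual count of each class)
recall = np.diag(cm) / cm.sum(axis=1)
# Per-class precision: diagonal divided by column sums (predicted count of each class)
precision = np.diag(cm) / cm.sum(axis=0)

print("accuracy:", accuracy)
print("recall:", recall)
print("precision:", precision)

# Cross-check against scikit-learn's own metric functions
assert np.isclose(accuracy, accuracy_score(y_true, y_pred))
assert np.allclose(recall, recall_score(y_true, y_pred, average=None))
assert np.allclose(precision, precision_score(y_true, y_pred, average=None))
```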

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix

# Generate synthetic dataset
X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4, random_state=42)

# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train a random forest classifier
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Predict on test set
y_pred = clf.predict(X_test)

# Calculate confusion matrix
cm = confusion_matrix(y_test, y_pred)
print("Confusion Matrix:")
print(cm)

Running the example gives an output like:

Confusion Matrix:
[[43 14  7]
 [ 1 56  5]
 [ 8 13 53]]

The steps are as follows:

  1. Generate a synthetic multi-class classification dataset using make_classification() with 1000 samples, 3 classes, and 4 informative features.

  2. Split the dataset into training and test sets using train_test_split(), reserving 20% for testing.

  3. Train a RandomForestClassifier with 100 trees on the training set using the fit() method.

  4. Use the trained classifier to make predictions on the test set with predict().

  5. Calculate the confusion matrix using confusion_matrix() by passing the true labels (y_test) and predicted labels (y_pred).

  6. Print the confusion matrix to inspect the counts of correct and incorrect predictions for each class.
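When classes differ in size, raw counts are harder to compare from row to row. The minimal sketch below repeats the setup of the example above and passes normalize="true" (a confusion_matrix() parameter available in scikit-learn 0.22 and later; "pred" and "all" are the other options), so each row shows the fraction of that actual class assigned to each predicted class.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

# Same setup as the example above
X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
clf = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train)
y_pred = clf.predict(X_test)

# normalize="true" divides each row by its total, so every row sums to 1
# and the diagonal holds the per-class recall
cm_norm = confusion_matrix(y_test, y_pred, normalize="true")
print(cm_norm.round(2))
```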

The confusion matrix shows not only how often the classifier is right, but also which classes it confuses. In the output above, class 1 is classified most reliably (56 of its 62 test samples are correct), while the single most common mistake is predicting class 1 for an actual class 0 sample (14 of 64). Patterns like this show where to focus when improving the model.
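To find such patterns programmatically, the sketch below takes the matrix from the example output above as hard-coded input (your numbers may differ) and reports per-class recall along with the largest off-diagonal cell, i.e. the most frequent misclassification.

```python
import numpy as np

# Confusion matrix copied from the example output above
# (rows = actual class, columns = predicted class)
cm = np.array([[43, 14, 7],
               [1, 56, 5],
               [8, 13, 53]])

# Per-class recall: fraction of each actual class predicted correctly
recall = cm.diagonal() / cm.sum(axis=1)
for cls, r in enumerate(recall):
    print(f"class {cls}: {r:.2f} correctly classified")

# Zero out the diagonal so only misclassifications remain, then find the largest
errors = cm.copy()
np.fill_diagonal(errors, 0)
actual, predicted = np.unravel_index(errors.argmax(), errors.shape)
print(f"Most common error: class {actual} predicted as class {predicted} "
      f"({errors[actual, predicted]} samples)")
```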
