
Scikit-Learn top_k_accuracy_score() Metric

Top-K Accuracy Score is a useful metric for evaluating classification models, particularly in multiclass settings. It measures the fraction of samples whose true label is among the K classes with the highest predicted scores. For instance, with K=3 a prediction counts as correct whenever the true label is one of the 3 highest-scoring classes. With K=1 the metric reduces to ordinary accuracy.
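To make the definition concrete, here is a small case you can check by hand. It uses made-up scores for four samples and three classes, and the numbers are illustrative only. Each comment states where the true class ranks in its row, which is all the metric looks at.

```python
from sklearn.metrics import top_k_accuracy_score

# Four samples, three classes; each row holds one score per class (0, 1, 2)
y_true = [0, 1, 2, 2]
y_score = [
    [0.5, 0.2, 0.2],  # true class 0 ranks 1st -> hit for k=1 and k=2
    [0.3, 0.4, 0.2],  # true class 1 ranks 1st -> hit for k=1 and k=2
    [0.2, 0.4, 0.3],  # true class 2 ranks 2nd -> hit only for k=2
    [0.7, 0.2, 0.1],  # true class 2 ranks 3rd -> miss for k=1 and k=2
]

top1 = top_k_accuracy_score(y_true, y_score, k=1)
top2 = top_k_accuracy_score(y_true, y_score, k=2)
print(f"Top-1: {top1:.2f}")  # 2 of 4 samples -> 0.50
print(f"Top-2: {top2:.2f}")  # 3 of 4 samples -> 0.75
```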

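The top_k_accuracy_score() function in scikit-learn computes this metric from the true labels and a matrix of predicted scores. The scores can be probability estimates such as those from predict_proba(), or non-thresholded decision values. By default it returns a float between 0 and 1, where values closer to 1 indicate better performance. Passing normalize=False returns the number of correctly classified samples instead. A high Top-K Accuracy Score means the true label frequently appears in the top K predictions. This is especially useful in multiclass problems with many similar classes, or when a short list of candidates containing the correct class is acceptable even if it is not the single top prediction.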
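The computation itself is simple. The sketch below is a hand-rolled NumPy version and checks it against top_k_accuracy_score() on random scores. The function name top_k_accuracy is our own and not part of scikit-learn. The sketch assumes integer labels 0..n_classes-1, so each label doubles as a column index. It also skips sample weights and does not try to reproduce how scikit-learn breaks ties between equal scores. With random continuous scores, ties are very unlikely here.

```python
import numpy as np
from sklearn.metrics import top_k_accuracy_score


def top_k_accuracy(y_true, y_score, k):
    """Fraction of samples whose true label is among the k highest-scoring classes.

    Hypothetical helper for illustration. Assumes labels are the integers
    0..n_classes-1, so a label doubles as the column index into y_score.
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    # Column indices of the k largest scores in each row (order within the k does not matter)
    top_k = np.argsort(y_score, axis=1)[:, -k:]
    # A sample is a hit if its true label appears anywhere in its top-k set
    hits = (top_k == y_true[:, None]).any(axis=1)
    return hits.mean()


rng = np.random.default_rng(0)
n_samples, n_classes = 500, 6
y_true = rng.integers(0, n_classes, size=n_samples)
# Random probability-like scores; continuous values make ties very unlikely
y_score = rng.dirichlet(np.ones(n_classes), size=n_samples)

results = {}
for k in (1, 2, 3):
    ours = top_k_accuracy(y_true, y_score, k)
    sk = top_k_accuracy_score(y_true, y_score, k=k)
    results[k] = (ours, sk)
    print(f"k={k}: hand-rolled={ours:.3f}  scikit-learn={sk:.3f}")
```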

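However, the metric is uninformative for binary classification. With k=1 it reduces to ordinary accuracy, and with k=2 every sample counts as a hit, so the score is always 1.0. More generally, any k greater than or equal to the number of classes gives a perfect score. With a very large number of classes, the metric also needs a full score matrix and a per-sample ranking of the class scores, which adds memory and compute cost.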
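The binary caveat is easy to see on a toy example with made-up scores. For binary targets, scikit-learn expects a 1-D array of positive-class scores. With k=1 the result matches plain accuracy after thresholding. With k=2 it is 1.0 regardless of the scores.

```python
import warnings

from sklearn.metrics import accuracy_score, top_k_accuracy_score

# Binary problem: y_score is 1-D and holds the score of the positive class
y_true = [0, 1, 1, 0]
y_score = [0.2, 0.3, 0.9, 0.6]

# k=1: predictions come from thresholding the scores (0.5 here, since the scores lie in [0, 1])
top1 = top_k_accuracy_score(y_true, y_score, k=1)
plain_acc = accuracy_score(y_true, [int(s > 0.5) for s in y_score])

# k=2: both classes are always "in the top 2", so every sample counts as a hit.
# scikit-learn may warn that k >= n_classes makes the score meaningless; silence it here.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    top2 = top_k_accuracy_score(y_true, y_score, k=2)

print(f"Top-1: {top1:.2f}  accuracy: {plain_acc:.2f}")  # both 0.50
print(f"Top-2: {top2:.2f}")  # always 1.00
```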

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import top_k_accuracy_score

# Generate synthetic dataset
X, y = make_classification(n_samples=1000, n_classes=5, n_informative=10, random_state=42)

# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train a neural network classifier
clf = MLPClassifier(random_state=42, max_iter=300)
clf.fit(X_train, y_train)

# Predict class probabilities on the test set (one column per class)
y_pred = clf.predict_proba(X_test)

# Calculate top-3 accuracy
top3_acc = top_k_accuracy_score(y_test, y_pred, k=3)
print(f"Top-3 Accuracy: {top3_acc:.2f}")

Running the example gives an output like:

Top-3 Accuracy: 0.97
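The example above reports only k=3. A natural follow-up is to sweep k and compare top-1 with ordinary accuracy, sketched here with the same data and model. The two should agree because both take the highest-probability class as the prediction. The exact numbers depend on your environment, and the MLP may emit a ConvergenceWarning at max_iter=300.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score, top_k_accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

# Same data and model as the example above
X, y = make_classification(n_samples=1000, n_classes=5, n_informative=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
clf = MLPClassifier(random_state=42, max_iter=300)
clf.fit(X_train, y_train)
y_proba = clf.predict_proba(X_test)

# Scores for k = 1..4; k = 5 would equal the number of classes and always give 1.0
scores = {k: top_k_accuracy_score(y_test, y_proba, k=k) for k in range(1, 5)}
for k, score in scores.items():
    print(f"Top-{k} Accuracy: {score:.2f}")

# Top-1 accuracy should match ordinary accuracy of the hard predictions
acc = accuracy_score(y_test, clf.predict(X_test))
print(f"accuracy_score:   {acc:.2f}")
```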

The steps are as follows:

  1. Generate a synthetic multiclass classification dataset using make_classification().
  2. Split the dataset into training and test sets using train_test_split().
  3. Train an MLPClassifier on the training set.
  4. Use the trained classifier to predict probabilities for each class on the test set with predict_proba().
  5. Calculate the Top-3 accuracy with top_k_accuracy_score(), which checks for each test sample whether the true label is among the 3 classes with the highest predicted probability.
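Step 5 uses probabilities, but top_k_accuracy_score() only looks at how the scores rank the classes, so non-thresholded decision values work as well. MLPClassifier does not provide decision_function(). This sketch therefore swaps in LogisticRegression, purely to get both kinds of scores from one model. Its probabilities are a monotonic transform of its decision values within each row, so the two Top-3 scores should match.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import top_k_accuracy_score
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, n_classes=5, n_informative=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# LogisticRegression stands in for the MLP here because it exposes both kinds of scores
clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, y_train)

proba_scores = clf.predict_proba(X_test)         # probability estimates
decision_scores = clf.decision_function(X_test)  # raw, non-thresholded scores

top3_proba = top_k_accuracy_score(y_test, proba_scores, k=3)
top3_decision = top_k_accuracy_score(y_test, decision_scores, k=3)
print(f"Top-3 from probabilities:   {top3_proba:.2f}")
print(f"Top-3 from decision values: {top3_decision:.2f}")
```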

First, we generate a synthetic multiclass classification dataset using the make_classification() function from scikit-learn. We request 1000 samples spread across 5 classes, with 10 informative features out of the default 20. This gives us a classification problem to work with without needing real-world data.
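Here is a quick check of what make_classification() returns with these arguments. Since n_features is not passed, the default of 20 features applies. The labels are close to balanced but not exactly, because the default flip_y=0.01 randomly reassigns a small fraction of them.

```python
import numpy as np
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=1000, n_classes=5, n_informative=10, random_state=42)

# n_features was not set, so the default of 20 features is used
print(X.shape)  # (1000, 20)

# Samples per class; roughly balanced, but flip_y (default 0.01) randomly reassigns about 1% of labels
counts = np.bincount(y)
print(counts)
```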

Next, we split the dataset into training and test sets using the train_test_split() function. This step is crucial for evaluating the performance of our classifier on unseen data. We use 80% of the data for training and reserve 20% for testing.
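One optional variation, not used in the original example, is to pass stratify=y so both splits keep the same class proportions. This matters for top_k_accuracy_score(), which by default infers the classes from y_true. If a test set happens to miss a class, its labels no longer match the columns of the score matrix, and the call raises an error unless labels is given.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, n_classes=5, n_informative=10, random_state=42)

# stratify=y keeps each class's share (nearly) the same in both splits,
# so every class with enough samples shows up in the test set
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(X_train.shape, X_test.shape)  # (800, 20) (200, 20)
print(np.bincount(y_test))          # about 40 samples of each class
```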

With our data prepared, we train a neural network classifier using the MLPClassifier class from scikit-learn. We set the random state for reproducibility and limit the maximum number of iterations to 300. The fit() method is called on the classifier object, passing in the training features (X_train) and labels (y_train) to learn the underlying patterns in the data.
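scikit-learn's user guide notes that multi-layer perceptrons are sensitive to feature scaling. The sketch below builds the same model with a StandardScaler in front, using make_pipeline(). It makes no claim about how much scaling changes the score on this particular dataset.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import top_k_accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_classification(n_samples=1000, n_classes=5, n_informative=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Standardize features before the MLP; the scaler is fit on the training data only
model = make_pipeline(StandardScaler(), MLPClassifier(random_state=42, max_iter=300))
model.fit(X_train, y_train)

# The pipeline exposes predict_proba() from its final step, so the metric is computed the same way
top3 = top_k_accuracy_score(y_test, model.predict_proba(X_test), k=3)
print(f"Top-3 Accuracy (scaled): {top3:.2f}")
```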

After training, we call predict_proba() on X_test. It returns an array of shape (n_samples, n_classes) holding the predicted probability of each class for each test sample, with columns ordered as in clf.classes_.
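To see what predict_proba() hands to the metric, this sketch retrains the same model and inspects the output. It checks the array's shape, the class order in clf.classes_, and that each row sums to 1.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

X, y = make_classification(n_samples=1000, n_classes=5, n_informative=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
clf = MLPClassifier(random_state=42, max_iter=300).fit(X_train, y_train)

y_proba = clf.predict_proba(X_test)
print(y_proba.shape)  # (200, 5): one row per test sample, one column per class
print(clf.classes_)   # [0 1 2 3 4]: column j holds the probability of clf.classes_[j]
print(np.allclose(y_proba.sum(axis=1), 1.0))  # True: each row is a probability distribution
```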

Finally, we evaluate the Top-3 accuracy of our classifier using the top_k_accuracy_score() function. It takes the true labels (y_test) and the predicted probabilities (y_pred) and returns the fraction of test samples whose true label is among the 3 classes with the highest predicted probability. By default, the columns of the score matrix are matched to the sorted labels found in y_true. This lines up with the column order of predict_proba() as long as every class appears in the test set. The printed Top-3 accuracy gives a quantitative measure of how well the classifier ranks the correct class in this multiclass setting.
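Two parameters of top_k_accuracy_score() matter when the test labels do not cover every class, for example when scoring a small batch. The made-up batch below has no sample of class 3, even though the scores have four columns. Without labels the call raises a ValueError, while labels=[0, 1, 2, 3] tells the function which class each column belongs to. normalize=False returns the number of hits instead of the fraction.

```python
import numpy as np
from sklearn.metrics import top_k_accuracy_score

# A tiny batch where class 3 never occurs in y_true, although y_score has 4 columns
y_true = np.array([0, 1, 2, 1])
y_score = np.array([
    [0.6, 0.2, 0.1, 0.1],  # true class 0 ranks 1st -> hit
    [0.1, 0.3, 0.5, 0.1],  # true class 1 ranks 2nd -> hit
    [0.2, 0.2, 0.1, 0.5],  # true class 2 ranks last -> miss
    [0.1, 0.7, 0.1, 0.1],  # true class 1 ranks 1st -> hit
])

# Without labels, the classes are inferred from y_true (3) and clash with the 4 score columns
try:
    top_k_accuracy_score(y_true, y_score, k=2)
    raised = False
except ValueError as err:
    raised = True
    print(f"ValueError: {err}")

# labels lists every class in the order of the y_score columns
all_labels = [0, 1, 2, 3]
frac = top_k_accuracy_score(y_true, y_score, k=2, labels=all_labels)
count = top_k_accuracy_score(y_true, y_score, k=2, labels=all_labels, normalize=False)
print(f"Top-2 fraction: {frac:.2f}")  # 0.75
print(f"Top-2 count: {int(count)}")   # 3
```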

This example demonstrates how to use the top_k_accuracy_score() function from scikit-learn to evaluate the performance of a multiclass classification model.