
Scikit-Learn ClassifierChain Model

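ClassifierChain is a method for multi-label classification, where each instance can belong to several classes at once. It trains one binary classifier per label and arranges them in a chain: each classifier receives the original features plus the labels for the targets earlier in the chain, which lets it exploit dependencies between labels.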
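To make the chaining idea concrete, here is a minimal hand-rolled sketch. The helpers `fit_chain` and `predict_chain` are hypothetical and are not scikit-learn API. They mimic what ClassifierChain does by default: the true values of earlier labels are appended to the features during training, and at prediction time each classifier's output is fed forward to the next one. This is an illustration of the idea, not scikit-learn's actual source.

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import LogisticRegression

# Hypothetical helpers illustrating the chaining idea (not scikit-learn internals).
def fit_chain(X, Y, order):
    """Fit one binary classifier per label, in the given order.
    Classifier k sees the original features plus the TRUE values of the
    labels that come before it in the chain."""
    models = []
    for k, label in enumerate(order):
        previous = Y[:, order[:k]]          # labels earlier in the chain (empty for k=0)
        X_aug = np.hstack([X, previous])    # original features + earlier labels
        models.append(LogisticRegression(max_iter=1000).fit(X_aug, Y[:, label]))
    return models

def predict_chain(models, X, order):
    """Predict labels one at a time, feeding each prediction forward."""
    Y_pred = np.zeros((X.shape[0], len(order)), dtype=int)
    for k, (model, label) in enumerate(zip(models, order)):
        X_aug = np.hstack([X, Y_pred[:, order[:k]]])  # features + earlier PREDICTIONS
        Y_pred[:, label] = model.predict(X_aug)
    return Y_pred

X, Y = make_multilabel_classification(n_samples=100, n_features=5, n_classes=3, random_state=1)
order = [0, 1, 2]
models = fit_chain(X, Y, order)
Y_hat = predict_chain(models, X, order)
print(Y_hat[:5])
```

The mismatch between training (true labels) and prediction (predicted labels) is what the `cv` option of ClassifierChain is meant to address.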

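The key hyperparameters of ClassifierChain are the base estimator used for every link in the chain, the order in which labels are predicted (order, which can be an explicit list of label indices or 'random'), and cv, which controls whether the earlier labels fed to later classifiers during training are the true labels (the default) or cross-validated predictions. The original features are always passed along with the earlier labels; this is not configurable.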
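A short sketch of these options, reusing the same synthetic dataset settings as the main example. The base estimator is passed positionally, as in the main example; `order`, `random_state` and `cv` are ClassifierChain parameters, and the fitted `order_` attribute records the order actually used. The specific values (`[2, 0, 1]`, `cv=3`) are arbitrary choices for illustration.

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import ClassifierChain

X, y = make_multilabel_classification(n_samples=100, n_features=5, n_classes=3, random_state=1)

# Explicit label order: predict label 2 first, then label 0, then label 1.
chain_fixed = ClassifierChain(LogisticRegression(), order=[2, 0, 1])
chain_fixed.fit(X, y)
print(chain_fixed.order_)

# Random label order, made reproducible with random_state.
chain_random = ClassifierChain(LogisticRegression(), order='random', random_state=1)
chain_random.fit(X, y)
print(chain_random.order_)

# cv=3: during training, earlier labels are replaced by 3-fold
# cross-validated predictions instead of the true labels.
chain_cv = ClassifierChain(LogisticRegression(), cv=3)
chain_cv.fit(X, y)
print(chain_cv.predict(X[:2]))
```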

The algorithm is appropriate for multi-label classification problems, particularly when the labels are correlated.

from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import ClassifierChain
from sklearn.metrics import accuracy_score

# generate multi-label classification dataset
X, y = make_multilabel_classification(n_samples=100, n_features=5, n_classes=3, random_state=1)

# split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# create base model
base_model = LogisticRegression()

# create classifier chain
model = ClassifierChain(base_model)

# fit model
model.fit(X_train, y_train)

# evaluate model (for multi-label targets, accuracy_score is subset accuracy)
yhat = model.predict(X_test)
acc = accuracy_score(y_test, yhat)
print('Accuracy: %.3f' % acc)

# make a prediction
row = [[-1.10325445, -0.49821356, -0.05962247, -0.89224592, -0.70158632]]
yhat = model.predict(row)
print('Predicted: %s' % yhat[0])

Running the example gives an output like:

Accuracy: 0.650
Predicted: [0. 0. 0.]
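For multi-label targets, the accuracy reported above is subset accuracy: a test sample counts as correct only if every one of its labels is predicted correctly. The sketch below checks this on small made-up arrays (not the example's data) by comparing accuracy_score with an explicit row-wise exact-match rate.

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Made-up ground truth and predictions for 4 samples and 3 labels.
y_true = np.array([[1, 0, 1],
                   [0, 1, 0],
                   [1, 1, 0],
                   [0, 0, 1]])
y_pred = np.array([[1, 0, 1],   # all labels match
                   [0, 1, 1],   # one label wrong
                   [1, 1, 0],   # all labels match
                   [0, 0, 0]])  # one label wrong

# accuracy_score on multi-label indicator arrays = fraction of rows matching exactly.
subset_acc = accuracy_score(y_true, y_pred)
exact_match = np.mean(np.all(y_true == y_pred, axis=1))
print(subset_acc, exact_match)  # 0.5 0.5
```

Two of the four samples have one wrong label each, so both values are 0.5 even though only 2 of the 12 individual label assignments are wrong.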

The steps are as follows:

  1. First, a synthetic multi-label classification dataset is generated using the make_multilabel_classification() function. This creates a dataset with a specified number of samples (n_samples), features (n_features), possible labels (n_classes), and a fixed random seed (random_state) for reproducibility. The dataset is split into training and test sets using train_test_split().

  2. Next, a LogisticRegression model is instantiated as the base classifier. The ClassifierChain is then created from this base model; when fitted, it trains a separate copy (clone) of the base model for each label.

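  3. The chain model is fit on the training data using the fit() method. Its performance is evaluated by comparing the predictions (yhat) with the actual labels (y_test) using accuracy_score(), which for multi-label data computes subset accuracy: a sample counts as correct only if all of its labels are predicted correctly.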

  4. A single prediction can be made by passing a new data sample to the predict() method of the chain model, which returns one 0/1 value per label.
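Because subset accuracy is strict, it can help to also look at per-label metrics and at the predicted probabilities. The sketch below refits the same model as the main example and reports hamming_loss and a sample-averaged F1 score, then calls predict_proba(), which ClassifierChain provides when the base estimator supports it. The choice of these two metrics is an assumption about what is useful here, not part of the original example.

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import ClassifierChain
from sklearn.metrics import hamming_loss, f1_score

X, y = make_multilabel_classification(n_samples=100, n_features=5, n_classes=3, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

model = ClassifierChain(LogisticRegression())
model.fit(X_train, y_train)
yhat = model.predict(X_test)

# Fraction of individual label assignments that are wrong (lower is better).
hl = hamming_loss(y_test, yhat)
# F1 computed over each sample's labels, then averaged across samples.
f1 = f1_score(y_test, yhat, average='samples', zero_division=0)
print('Hamming loss: %.3f' % hl)
print('Sample-averaged F1: %.3f' % f1)

# Per-label probabilities for one row, returned in the original label order.
row = [[-1.10325445, -0.49821356, -0.05962247, -0.89224592, -0.70158632]]
proba = model.predict_proba(row)
print('Probabilities: %s' % proba[0])
```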

This example demonstrates how to set up and use a ClassifierChain for multi-label classification in scikit-learn. Because each classifier in the chain sees the labels for earlier targets, the model can exploit correlations between labels; whether this improves accuracy over predicting each label independently depends on the data and on the chosen label order.
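One way to check whether chaining helps on a given problem is to compare it against independent per-label classifiers. The sketch below fits a MultiOutputClassifier baseline, a single ClassifierChain, and an ensemble of 10 chains with random label orders whose probabilities are averaged and thresholded at 0.5. The larger dataset (n_samples=1000), the ensemble size, and the 0.5 threshold are assumptions for illustration; no particular ranking of the three approaches is implied, and on synthetic data the differences may be small or reversed.

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier
from sklearn.metrics import accuracy_score, hamming_loss

X, y = make_multilabel_classification(n_samples=1000, n_features=5, n_classes=3, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# Baseline: one independent LogisticRegression per label (ignores label dependencies).
independent = MultiOutputClassifier(LogisticRegression(max_iter=1000)).fit(X_train, y_train)

# Ensemble of chains with different random label orders.
chains = [
    ClassifierChain(LogisticRegression(max_iter=1000), order='random', random_state=i).fit(X_train, y_train)
    for i in range(10)
]
# Average per-label probabilities across chains, then threshold at 0.5.
ensemble_proba = np.mean([c.predict_proba(X_test) for c in chains], axis=0)
ensemble_pred = (ensemble_proba >= 0.5).astype(int)

results = {
    'independent': independent.predict(X_test),
    'single chain': chains[0].predict(X_test),
    'chain ensemble': ensemble_pred,
}
for name, pred in results.items():
    print('%-15s subset accuracy=%.3f  hamming loss=%.3f'
          % (name, accuracy_score(y_test, pred), hamming_loss(y_test, pred)))
```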
