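`ClassifierChain` is a scikit-learn meta-estimator for multi-label classification, where each instance can carry several labels at once. It fits one binary classifier per label and links them in a chain. Each classifier receives the original features plus the labels handled earlier in the chain: by default, their true values during training and their predicted values at prediction time. This lets the model exploit correlations between labels.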
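To make the chaining concrete, here is a minimal from-scratch sketch. It assumes scikit-learn's default settings (`order=None`, `cv=None`): classifier `j` is trained on the features plus the true values of labels `0` to `j-1`, and at prediction time it receives the chain's predictions for those labels instead. The helpers `fit_chain` and `predict_chain` are hypothetical names used only for illustration. The last lines compare the sketch with `ClassifierChain` on the same data.

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import ClassifierChain


def fit_chain(X, Y):
    """Hypothetical helper: fit one classifier per label, in column order."""
    models = []
    for j in range(Y.shape[1]):
        # Original features plus the TRUE values of the earlier labels
        X_aug = np.hstack((X, Y[:, :j]))
        models.append(LogisticRegression().fit(X_aug, Y[:, j]))
    return models


def predict_chain(models, X):
    """Hypothetical helper: predict labels in order, feeding each prediction forward."""
    Y_pred = np.zeros((X.shape[0], len(models)))
    for j, model in enumerate(models):
        # Original features plus the PREDICTED values of the earlier labels
        X_aug = np.hstack((X, Y_pred[:, :j]))
        Y_pred[:, j] = model.predict(X_aug)
    return Y_pred


X, Y = make_multilabel_classification(n_samples=100, n_features=5, n_classes=3, random_state=1)
manual = predict_chain(fit_chain(X, Y), X)
library = ClassifierChain(LogisticRegression()).fit(X, Y).predict(X)
print('Predictions identical: %s' % np.array_equal(manual, library))
```

Because both versions fit the same deterministic `LogisticRegression` on the same augmented features, their predictions should agree exactly.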
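The key hyperparameters of `ClassifierChain` are the base estimator (the classifier fitted for each label, passed as the first argument), `order`, which sets the sequence in which labels are predicted (column order by default, an explicit list of indices, or `"random"` together with `random_state`), and `cv`, which controls whether each classifier is trained on the true values of the preceding labels (`cv=None`, the default) or on cross-validated predictions of them. The original features are always passed along the chain.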
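A sketch of how `order` and `cv` might be set, reusing the dataset settings from the example below. The values `order=[2, 0, 1]`, `random_state=1` and `cv=5` are arbitrary illustrative choices, and `max_iter=1000` is an assumption added only to reduce the chance of convergence warnings from `LogisticRegression`.

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import ClassifierChain

X, y = make_multilabel_classification(n_samples=100, n_features=5, n_classes=3, random_state=1)

# Explicit order: predict label 2 first, then label 0, then label 1
fixed = ClassifierChain(LogisticRegression(max_iter=1000), order=[2, 0, 1]).fit(X, y)
print('Fixed order:', [int(i) for i in fixed.order_])

# Random order, made reproducible with random_state
shuffled = ClassifierChain(LogisticRegression(max_iter=1000), order='random', random_state=1).fit(X, y)
print('Random order:', [int(i) for i in shuffled.order_])

# cv=5: later classifiers are trained on 5-fold cross-validated predictions
# of the earlier labels rather than on their true values
cv_chain = ClassifierChain(LogisticRegression(max_iter=1000), cv=5).fit(X, y)
print('Predictions shape:', cv_chain.predict(X).shape)
```

Training on cross-validated predictions makes the features seen during training resemble the (imperfect) predictions the chain will feed forward at prediction time, at the cost of extra fits.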
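`ClassifierChain` is best suited to multi-label problems in which the labels are correlated; when the labels are independent, the chain is likely to offer little over fitting one classifier per label.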
```python
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import ClassifierChain
from sklearn.metrics import accuracy_score
# generate multi-label classification dataset
X, y = make_multilabel_classification(n_samples=100, n_features=5, n_classes=3, random_state=1)
# split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
# create base model
base_model = LogisticRegression()
# create classifier chain
model = ClassifierChain(base_model)
# fit model
model.fit(X_train, y_train)
# evaluate model
yhat = model.predict(X_test)
acc = accuracy_score(y_test, yhat)
print('Accuracy: %.3f' % acc)
# make a prediction
row = [[-1.10325445, -0.49821356, -0.05962247, -0.89224592, -0.70158632]]
yhat = model.predict(row)
print('Predicted: %s' % yhat[0])
```
Running the example produces output like the following:

```
Accuracy: 0.650
Predicted: [0. 0. 0.]
```
The steps are as follows:

1. A synthetic multi-label dataset is generated with `make_multilabel_classification()`, specifying the number of samples (`n_samples`), features (`n_features`) and labels (`n_classes`), plus a fixed seed (`random_state`) for reproducibility. The target `y` is a binary indicator matrix with one column per label.
2. The data is split into training and test sets with `train_test_split()`, holding out 20% of the samples for testing.
3. A `LogisticRegression` model is instantiated as the base classifier, and a `ClassifierChain` is created from it.
4. The chain is fit on the training data with `fit()`.
5. Performance is evaluated by comparing the predictions (`yhat`) with the true labels (`y_test`) using `accuracy_score()`. For multi-label targets this is subset accuracy: a test instance counts as correct only if every one of its labels is predicted correctly.
6. A single prediction is made by passing a new data sample to the chain's `predict()` method.
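Because subset accuracy is strict, it can help to look at it alongside a per-label metric. The sketch below, which rebuilds the same data, split and model as the example above, checks that `accuracy_score()` matches the exact-match rate, reports `hamming_loss()` (the fraction of individual label assignments that are wrong), and uses `predict_proba()` to get per-label probabilities rather than hard 0/1 labels. Choosing these two extra outputs is an illustrative assumption, not part of the original example.

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import ClassifierChain
from sklearn.metrics import accuracy_score, hamming_loss

# Same data, split and model as the example above
X, y = make_multilabel_classification(n_samples=100, n_features=5, n_classes=3, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
model = ClassifierChain(LogisticRegression()).fit(X_train, y_train)
yhat = model.predict(X_test)

# Subset accuracy: fraction of rows where ALL labels are predicted correctly
subset_acc = accuracy_score(y_test, yhat)
exact_match = np.mean(np.all(y_test == yhat, axis=1))
print('Subset accuracy: %.3f (exact-match rate: %.3f)' % (subset_acc, exact_match))

# Hamming loss: fraction of individual label predictions that are wrong
h_loss = hamming_loss(y_test, yhat)
print('Hamming loss: %.3f' % h_loss)

# Probability of each label being present for the first test row
probs = model.predict_proba(X_test[:1])[0]
print('Label probabilities:', probs)
```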
This example demonstrates how to set up and use a `ClassifierChain` for multi-label classification in scikit-learn. Because each classifier in the chain sees the labels handled before it, the model can exploit dependencies between labels, which may improve accuracy over treating each label independently when those labels are correlated.
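One way to check whether the chain helps on a given dataset is to compare it with a baseline that fits an independent classifier per label. The sketch below uses `MultiOutputClassifier` as that baseline (an illustrative choice, not something the example above prescribes) and reports subset accuracy for both on the same split.

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier
from sklearn.metrics import accuracy_score

X, y = make_multilabel_classification(n_samples=100, n_features=5, n_classes=3, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# Baseline: one independent classifier per label, no information shared between labels
independent = MultiOutputClassifier(LogisticRegression()).fit(X_train, y_train)
# Chain: each classifier also sees the labels earlier in the chain
chain = ClassifierChain(LogisticRegression()).fit(X_train, y_train)

scores = {
    'independent': accuracy_score(y_test, independent.predict(X_test)),
    'chain': accuracy_score(y_test, chain.predict(X_test)),
}
for name, score in scores.items():
    print('%s subset accuracy: %.3f' % (name, score))
```

With only 20 test samples, the difference between the two scores is noisy; cross-validation or a larger dataset would give a more reliable comparison.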