
Scikit-Learn MultiOutputClassifier Model

The MultiOutputClassifier extends standard classification algorithms to multi-label classification problems. It fits one separate classifier per target label, allowing any estimator that supports binary classification to be used as the base model.

Key hyperparameters include the base estimator (e.g., LogisticRegression) and n_jobs, which parallelizes fitting the per-label classifiers across multiple CPU cores.
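
As a quick sketch of how these are set (using LogisticRegression and n_jobs=-1 purely as illustrative choices), both are passed when constructing the wrapper:

from sklearn.multioutput import MultiOutputClassifier
from sklearn.linear_model import LogisticRegression

# wrap a single-output estimator; one copy is fit per target label,
# with the copies trained in parallel across all available CPU cores
clf = MultiOutputClassifier(estimator=LogisticRegression(), n_jobs=-1)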

This approach is appropriate when you have a multi-label classification problem, where each sample can belong to multiple classes simultaneously.
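
In this setting the target is a 2D binary indicator matrix with one column per class. The toy values below are made up purely to illustrate the format:

import numpy as np

# three samples, three classes; a 1 marks membership in that class
y = np.array([
    [1, 0, 1],  # sample 0 belongs to classes 0 and 2
    [0, 1, 0],  # sample 1 belongs to class 1 only
    [1, 1, 1],  # sample 2 belongs to all three classes
])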

from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# generate multi-label classification dataset
X, y = make_multilabel_classification(n_samples=100, n_classes=3, n_labels=2, random_state=1)

# split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# create multi-output classifier
estimator = LogisticRegression()
classifier = MultiOutputClassifier(estimator)

# fit classifier
classifier.fit(X_train, y_train)

# evaluate classifier
yhat = classifier.predict(X_test)
acc = accuracy_score(y_test, yhat)
print('Accuracy: %.3f' % acc)

# make a prediction
row = [[1, 2, 1, 7, 0, 1, 1, 4, 0, 3, 1, 3, 6, 8, 1, 0, 2, 1, 9, 2]]
yhat = classifier.predict(row)
print('Predicted: %s' % yhat[0])

Running the example gives an output like:

Accuracy: 0.500
Predicted: [0 1 0]

The steps are as follows:

  1. Generate a synthetic multi-label classification dataset using make_multilabel_classification(), specifying the desired number of samples (n_samples), classes (n_classes), labels per sample (n_labels), and a fixed random seed (random_state). Split the dataset into training and test sets.

  2. Create an instance of the base estimator (LogisticRegression) and pass it to the MultiOutputClassifier constructor to create the multi-output classifier.

  3. Fit the classifier on the training data using the fit() method.

  4. Evaluate the classifier’s performance by comparing the predicted labels (yhat) to the actual labels (y_test) using the accuracy_score metric (see the metrics note after this list).

  5. Demonstrate making a prediction on a new sample by passing it to the predict() method.
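
A note on step 4: for multi-label targets, accuracy_score reports subset accuracy, i.e. the fraction of samples whose entire label vector is predicted exactly. A minimal sketch of complementary multi-label metrics, reusing the classifier and test split from the example above, might look like this:

from sklearn.metrics import hamming_loss, f1_score

yhat = classifier.predict(X_test)

# fraction of individual label assignments that are wrong (lower is better)
print('Hamming loss: %.3f' % hamming_loss(y_test, yhat))

# F1 score averaged across the three labels
print('Macro F1: %.3f' % f1_score(y_test, yhat, average='macro'))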

This example showcases how to use MultiOutputClassifier to extend a binary classifier for multi-label problems, enabling the use of familiar algorithms like LogisticRegression in more complex scenarios.
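
Any single-output classifier can serve as the base estimator. The sketch below swaps in an SVC purely as an illustration, reusing the training and test data from the example above:

from sklearn.svm import SVC
from sklearn.multioutput import MultiOutputClassifier

# wrap a support vector classifier instead of logistic regression
svc_classifier = MultiOutputClassifier(SVC(), n_jobs=-1)
svc_classifier.fit(X_train, y_train)

# predict label vectors for the first three test samples
print(svc_classifier.predict(X_test[:3]))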


