
Scikit-Learn OneVsRestClassifier Model

OneVsRestClassifier handles multi-class classification by fitting one binary classifier per class, each trained to separate its class from all the others. To predict, it picks the class whose classifier reports the highest confidence score.
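To make the one-vs-rest idea concrete, here is a minimal sketch that rebuilds the strategy by hand: one LogisticRegression per class trained on a "this class vs. everything else" target, followed by an argmax over their decision_function scores. It then checks the result against OneVsRestClassifier on the same data. The manual loop is an illustration written for this page, not the library's internal code, and max_iter=1000 is an assumption added to avoid convergence warnings.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier

# same synthetic data setup as the main example
X, y = make_classification(n_samples=100, n_features=20, n_classes=3,
                           n_informative=10, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

classes = np.unique(y_train)
scores = np.zeros((X_test.shape[0], len(classes)))

# one binary problem per class: "this class" (1) versus "all other classes" (0)
for i, c in enumerate(classes):
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X_train, (y_train == c).astype(int))
    # confidence score for the positive ("this class") side
    scores[:, i] = clf.decision_function(X_test)

# predict the class whose binary classifier is most confident
manual_pred = classes[np.argmax(scores, axis=1)]

# the library version of the same strategy
ovr = OneVsRestClassifier(LogisticRegression(max_iter=1000)).fit(X_train, y_train)
ovr_pred = ovr.predict(X_test)

print('Binary classifiers fitted:', len(ovr.estimators_))
print('Manual and library predictions agree:', np.array_equal(manual_pred, ovr_pred))
```

The fitted OneVsRestClassifier keeps its per-class models in the estimators_ attribute, so there is one entry per class.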

The key parameter of OneVsRestClassifier is estimator, the base classifier that is cloned and fitted once per class. Common choices include LogisticRegression and SVC.
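Swapping the base classifier only means passing a different object as estimator. The sketch below compares LogisticRegression with a linear-kernel SVC on the same split; the kernel choice and max_iter=1000 are assumptions for illustration, and no particular accuracy ranking is implied.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC

X, y = make_classification(n_samples=100, n_features=20, n_classes=3,
                           n_informative=10, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# candidate base estimators; OneVsRestClassifier clones each one once per class
base_estimators = {
    'LogisticRegression': LogisticRegression(max_iter=1000),
    'SVC (linear kernel)': SVC(kernel='linear'),
}

results = {}
for name, estimator in base_estimators.items():
    model = OneVsRestClassifier(estimator=estimator)
    model.fit(X_train, y_train)
    results[name] = accuracy_score(y_test, model.predict(X_test))
    print('%s: accuracy %.3f, %d binary models' % (name, results[name], len(model.estimators_)))
```

Because the base estimator is cloned, the objects in base_estimators stay unfitted; the trained copies live in model.estimators_.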

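The strategy is appropriate for multi-class classification problems, and it also supports multilabel classification, where each sample can belong to several classes at once.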
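OneVsRestClassifier also accepts a multilabel target: a binary indicator matrix with one column per label, where each column gets its own binary classifier. The following is a small sketch using make_multilabel_classification; the dataset sizes and max_iter=1000 are assumptions chosen to mirror the main example.

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# Y is a binary indicator matrix: Y[i, j] == 1 means sample i carries label j
X, Y = make_multilabel_classification(n_samples=100, n_features=20, n_classes=3,
                                      random_state=1)

model = OneVsRestClassifier(LogisticRegression(max_iter=1000))
model.fit(X, Y)

# predictions come back as an indicator matrix with one column per label
Y_pred = model.predict(X[:5])
print(Y_pred)
print('Labels per sample:', Y_pred.sum(axis=1))
```

Here a row of predictions may contain zero, one, or several 1s, unlike the multi-class case where exactly one class is returned.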

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# generate multi-class classification dataset
X, y = make_classification(n_samples=100, n_features=20, n_classes=3, n_informative=10, random_state=1)

# split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# create base classifier
base_model = LogisticRegression()

# create OneVsRestClassifier model
model = OneVsRestClassifier(base_model)

# fit model
model.fit(X_train, y_train)

# evaluate model
yhat = model.predict(X_test)
acc = accuracy_score(y_test, yhat)
print('Accuracy: %.3f' % acc)

# make a prediction
row = [[-2.343, 0.547, 1.144, -1.537, 0.922, -0.982, 1.402, -0.344, 0.334, -0.403,
        0.732, -0.673, -0.523, 0.684, 0.799, -0.679, 0.949, -0.512, 1.204, -0.434]]
yhat = model.predict(row)
print('Predicted: %d' % yhat[0])

Running the example gives an output like:

Accuracy: 0.550
Predicted: 2
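With test_size=0.2 on 100 samples, the test set holds only 20 rows, so each misclassified sample moves the accuracy by 0.05 and a single split gives a coarse estimate. A steadier figure can come from k-fold cross-validation, sketched below with 5 stratified folds; the fold count, shuffling, and max_iter=1000 are assumptions, not part of the original example.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.multiclass import OneVsRestClassifier

X, y = make_classification(n_samples=100, n_features=20, n_classes=3,
                           n_informative=10, random_state=1)

model = OneVsRestClassifier(LogisticRegression(max_iter=1000))

# 5 stratified folds of roughly 20 samples each; every sample is tested exactly once
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)
scores = cross_val_score(model, X, y, scoring='accuracy', cv=cv)

print('Fold accuracies:', scores.round(3))
print('Mean accuracy: %.3f (std %.3f)' % (scores.mean(), scores.std()))
```

The spread across folds gives a rough sense of how much the single-split accuracy above could vary.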

The steps are as follows:

  1. First, a synthetic multi-class dataset is generated with make_classification(): 100 samples (n_samples), 20 features (n_features, of which n_informative=10 carry class information), and 3 classes (n_classes). The data is then split into training and test sets with train_test_split(), holding out 20% of the rows for testing.

  2. Next, a LogisticRegression model is instantiated as the base classifier. This is passed to OneVsRestClassifier to create the multi-class classification model.

  3. The OneVsRestClassifier model is fit on the training data using the fit() method.

  4. The model is evaluated by comparing its predictions on the test set (yhat) with the true labels (y_test) using the accuracy_score() metric.

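  5. A single prediction is made by passing a new sample to the predict() method. The sample must be a 2D array-like (here, a list containing one row) with the same 20 features used for training.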
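Beyond the class label, the fitted model can report how confident each per-class classifier is. The sketch below, which reuses the sample row from the example, prints decision_function scores (one per binary classifier) and predict_proba output, which OneVsRestClassifier normalizes to sum to 1 in the multi-class case. It assumes a LogisticRegression base estimator, which provides both methods.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier

X, y = make_classification(n_samples=100, n_features=20, n_classes=3,
                           n_informative=10, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
model = OneVsRestClassifier(LogisticRegression()).fit(X_train, y_train)

# one new sample: a 2D array with a single row of 20 features
row = np.array([[-2.343, 0.547, 1.144, -1.537, 0.922, -0.982, 1.402, -0.344, 0.334, -0.403,
                 0.732, -0.673, -0.523, 0.684, 0.799, -0.679, 0.949, -0.512, 1.204, -0.434]])

scores = model.decision_function(row)  # shape (1, 3): one score per binary classifier
proba = model.predict_proba(row)       # shape (1, 3): per-class probabilities, rows sum to 1
pred = model.predict(row)

print('Scores:', scores.round(3))
print('Probabilities:', proba.round(3))
print('Predicted:', pred[0])
print('Argmax matches predict():', model.classes_[np.argmax(proba, axis=1)][0] == pred[0])
```

For a logistic base model, the probability of each class rises monotonically with its score, so the highest score, the highest probability, and the predicted label all point to the same class.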

This example demonstrates how to set up and use OneVsRestClassifier for multi-class classification: it turns a base classifier into a multi-class model by training one binary copy per class.



See Also