OneVsRestClassifier is a strategy for handling multi-class classification problems by fitting one binary classifier per class, each trained to separate that class from all the others. At prediction time, it selects the class whose classifier gives the highest confidence score.
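To make the "one classifier per class" idea concrete, here is a minimal sketch that builds the one-vs-rest scheme by hand and compares it with the library wrapper. The dataset settings and the variable names (binary_models, scores, manual_pred) are illustrative assumptions, not part of the library.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# small synthetic 3-class dataset (illustrative settings)
X, y = make_classification(n_samples=100, n_features=20, n_classes=3,
                           n_informative=10, random_state=1)
classes = np.unique(y)

# hand-rolled one-vs-rest: one binary model per class (class k vs. everything else)
binary_models = [LogisticRegression().fit(X, (y == k).astype(int)) for k in classes]

# stack the per-class confidence scores into shape (n_samples, n_classes)
scores = np.column_stack([m.decision_function(X) for m in binary_models])

# predict the class whose binary model is most confident
manual_pred = classes[np.argmax(scores, axis=1)]

# the library wrapper, fit on the same data for comparison
ovr = OneVsRestClassifier(LogisticRegression()).fit(X, y)
ovr_pred = ovr.predict(X)

print('Binary models fitted by the wrapper: %d' % len(ovr.estimators_))
print('Agreement with the manual version: %.3f' % np.mean(manual_pred == ovr_pred))
```

The fitted binary models are available on the wrapper through its estimators_ attribute, which can be handy for inspecting the coefficients of an individual class.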
The key parameter of OneVsRestClassifier is the estimator, which specifies the base classifier to use. Common choices include LogisticRegression and SVC.
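As a rough illustration of the estimator parameter, the sketch below wraps each of the two base classifiers mentioned above and checks what the wrapper fitted. The dataset settings and the candidates and fitted dictionaries are assumptions made for this example.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC

# small synthetic 3-class dataset (illustrative settings)
X, y = make_classification(n_samples=100, n_features=20, n_classes=3,
                           n_informative=10, random_state=1)

# hypothetical set of base classifiers to try
candidates = {
    'LogisticRegression': LogisticRegression(),
    'SVC': SVC(),
}

fitted = {}
for name, base in candidates.items():
    # the base classifier is passed as the estimator; one copy is fit per class
    model = OneVsRestClassifier(estimator=base)
    model.fit(X, y)
    fitted[name] = model
    print('%s: %d binary classifiers of type %s' % (
        name, len(model.estimators_), type(model.estimators_[0]).__name__))
```

Any classifier that exposes decision_function() or predict_proba() can serve as the base estimator, since the wrapper needs a per-class confidence score to choose between classes.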
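The algorithm is appropriate for multi-class classification problems, and it can also be used for multi-label classification, where each sample may belong to several classes at once.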
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
# generate multi-class classification dataset
X, y = make_classification(n_samples=100, n_features=20, n_classes=3, n_informative=10, random_state=1)
# split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
# create base classifier
base_model = LogisticRegression()
# create OneVsRestClassifier model
model = OneVsRestClassifier(base_model)
# fit model
model.fit(X_train, y_train)
# evaluate model
yhat = model.predict(X_test)
acc = accuracy_score(y_test, yhat)
print('Accuracy: %.3f' % acc)
# make a prediction
row = [[-2.343, 0.547, 1.144, -1.537, 0.922, -0.982, 1.402, -0.344, 0.334, -0.403,
0.732, -0.673, -0.523, 0.684, 0.799, -0.679, 0.949, -0.512, 1.204, -0.434]]
yhat = model.predict(row)
print('Predicted: %d' % yhat[0])
Running the example gives an output like:
Accuracy: 0.550
Predicted: 2
The steps are as follows:
First, a synthetic multi-class classification dataset is generated using the make_classification() function. This creates a dataset with a specified number of samples (n_samples), features (n_features), and classes (n_classes). The dataset is split into training and test sets using train_test_split().
Next, a LogisticRegression model is instantiated as the base classifier. This is passed to OneVsRestClassifier to create the multi-class classification model.
The OneVsRestClassifier model is fit on the training data using the fit() method.
The performance of the model is evaluated by comparing the predictions (yhat) to the actual values (y_test) using the accuracy score metric.
A single prediction can be made by passing a new data sample to the predict() method.
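Beyond the predicted label, it can be useful to look at the per-class scores behind a single prediction. The following is a minimal sketch, assuming the same kind of synthetic dataset; the first row of X is used as a hypothetical stand-in for a new sample.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# small synthetic 3-class dataset (illustrative settings)
X, y = make_classification(n_samples=100, n_features=20, n_classes=3,
                           n_informative=10, random_state=1)
model = OneVsRestClassifier(LogisticRegression()).fit(X, y)

# hypothetical "new" sample: reuse the first row of X, kept 2D as predict() expects
row = X[:1]

# one confidence score per class, taken from each binary classifier
scores = model.decision_function(row)
# per-class probabilities, normalized across classes by the wrapper
proba = model.predict_proba(row)
# predicted label: the class with the highest score
pred = model.predict(row)

print('Scores:', np.round(scores[0], 3))
print('Probabilities:', np.round(proba[0], 3))
print('Predicted: %d' % pred[0])
```

Both decision_function() and predict_proba() are only available when the base classifier provides the corresponding method, which LogisticRegression does.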
This example demonstrates how to set up and use a OneVsRestClassifier model for multi-class classification tasks, showcasing its ability to leverage simple binary classifiers for more complex problems.