
Scikit-Learn NearestCentroid Model

Nearest Centroid is a simple, intuitive classification algorithm. During training it computes one centroid (by default, the mean feature vector) for each class. At prediction time it assigns each sample the label of the closest centroid. It is easy to understand and implement.
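To make the two phases concrete, here is a minimal from-scratch sketch in NumPy that reuses the same make_classification settings as the main example. The helper predict_nearest_centroid is a hypothetical name for illustration only, and the comparison assumes NearestCentroid's default Euclidean metric and default class priors.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.neighbors import NearestCentroid

# Same synthetic data settings as the main example
X, y = make_classification(n_samples=100, n_clusters_per_class=1, n_features=5,
                           n_classes=3, random_state=1)

# "Training": one centroid per class, the mean of that class's samples
classes = np.unique(y)
centroids = np.array([X[y == c].mean(axis=0) for c in classes])

def predict_nearest_centroid(X_new):
    # Hypothetical helper: Euclidean distance from every sample to every centroid,
    # giving an array of shape (n_samples, n_classes)
    dists = np.linalg.norm(X_new[:, None, :] - centroids[None, :, :], axis=2)
    # Label of the closest centroid for each sample
    return classes[np.argmin(dists, axis=1)]

# Compare with scikit-learn's implementation (default euclidean metric)
model = NearestCentroid().fit(X, y)
print(np.allclose(centroids, model.centroids_))
print(np.array_equal(predict_nearest_centroid(X), model.predict(X)))
```

Both checks should print True if the fitted centroids_ are the class means and predict() picks the closest one; the fitted model stores nothing else that predictions depend on under these defaults.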

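The key hyperparameters of NearestCentroid are metric, which sets the distance used to compare samples with centroids, and shrink_threshold, which shrinks each class centroid toward the overall data centroid to suppress uninformative features. The default metric is 'euclidean'; with 'manhattan', each class centroid is the per-feature median instead of the mean. Recent scikit-learn releases accept only these two metric values.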
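The sketch below illustrates both hyperparameters on the same synthetic data as the main example. It checks that metric='manhattan' yields per-class medians and fits models with a few shrink_threshold values; the particular thresholds (0.5 and 1.0) are arbitrary choices for illustration, and the reported scores are training accuracy, not a proper evaluation.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.neighbors import NearestCentroid

X, y = make_classification(n_samples=100, n_clusters_per_class=1, n_features=5,
                           n_classes=3, random_state=1)

# metric='manhattan': each centroid is the per-feature median of its class
manhattan_model = NearestCentroid(metric='manhattan').fit(X, y)
medians = np.array([np.median(X[y == c], axis=0) for c in manhattan_model.classes_])
print('Manhattan centroids are medians:', np.allclose(manhattan_model.centroids_, medians))

# shrink_threshold: shrink centroids toward the overall centroid
# (None disables shrinking; the other values are arbitrary examples)
for threshold in [None, 0.5, 1.0]:
    model = NearestCentroid(shrink_threshold=threshold).fit(X, y)
    print('shrink_threshold=%s, training accuracy: %.3f' % (threshold, model.score(X, y)))
```

In practice, shrink_threshold is usually tuned with cross-validation rather than chosen by hand.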

The algorithm works for both binary and multi-class classification problems.

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestCentroid
from sklearn.metrics import accuracy_score

# generate a multi-class classification dataset
X, y = make_classification(n_samples=100, n_clusters_per_class=1, n_features=5, n_classes=3, random_state=1)

# split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# create model
model = NearestCentroid()

# fit model
model.fit(X_train, y_train)

# evaluate model
yhat = model.predict(X_test)
acc = accuracy_score(y_test, yhat)
print('Accuracy: %.3f' % acc)

# make a prediction
row = [[-1.10325445, -0.49821356, -0.05962247, -0.89224592, -0.70158632]]
yhat = model.predict(row)
print('Predicted: %d' % yhat[0])

Running the example gives an output like:

Accuracy: 0.950
Predicted: 2

The steps are as follows:

  1. First, a synthetic multi-class classification dataset is generated using the make_classification() function, with a specified number of samples (n_samples), features (n_features), and classes (n_classes), and a fixed random seed (random_state) for reproducibility. The dataset is then split into training and test sets using train_test_split(), holding out 20% of the samples for testing (test_size=0.2).

  2. Next, a NearestCentroid model is instantiated with default hyperparameters. The model is then fit on the training data using the fit() method.

  3. The model's performance is evaluated by comparing its predictions on the test set (yhat) with the true labels (y_test) using accuracy_score(), which reports the fraction of test samples classified correctly.

  4. A single prediction is made by passing a new data sample to the predict() method as a 2D array-like containing one row (here, a list holding one list of five feature values).
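To see why the single prediction comes out the way it does, the following sketch recreates the fitted model from the example and computes the Euclidean distance from the new row to each learned centroid. It assumes the default metric and default class priors, under which predict() should return the label of the closest centroid.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestCentroid

# Recreate the data and fitted model from the example above
X, y = make_classification(n_samples=100, n_clusters_per_class=1, n_features=5,
                           n_classes=3, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
model = NearestCentroid().fit(X_train, y_train)

row = np.array([[-1.10325445, -0.49821356, -0.05962247, -0.89224592, -0.70158632]])

# Euclidean distance from the new row to each learned class centroid
distances = np.linalg.norm(model.centroids_ - row, axis=1)
for label, dist in zip(model.classes_, distances):
    print('class %d: distance %.3f' % (label, dist))

# The label of the closest centroid should match predict()
manual = model.classes_[np.argmin(distances)]
print('Manual: %d, predict(): %d' % (manual, model.predict(row)[0]))
```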

This example shows how to set up, fit, evaluate, and use a NearestCentroid model for a multi-class classification task in a few lines of scikit-learn code.

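In this example the model is fit directly on the training data without scaling, which works here because make_classification produces features on comparable scales. Because predictions depend on distances to the centroids, features with very different ranges can dominate those distances, so it is usually worth standardizing them first (for example with StandardScaler). Once fit, the model can make predictions on new data, which makes it usable for real-world multi-class classification problems.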
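A minimal sketch of the scaling advice: it deliberately multiplies one feature by 1000 (a hypothetical change to mimic features measured in different units) and compares a plain NearestCentroid with a pipeline that standardizes features first. The pipeline fits StandardScaler on the training split only, so no test information leaks into the scaling. No particular accuracy values are claimed here; run it to see the effect on your own data.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestCentroid
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_classification(n_samples=100, n_clusters_per_class=1, n_features=5,
                           n_classes=3, random_state=1)
# Hypothetical: inflate one feature's scale so it dominates Euclidean distances
X[:, 0] *= 1000.0
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# Unscaled model vs. scaler + model in one pipeline (scaler fit on training data only)
raw = NearestCentroid().fit(X_train, y_train)
scaled = make_pipeline(StandardScaler(), NearestCentroid()).fit(X_train, y_train)

print('Unscaled accuracy: %.3f' % raw.score(X_test, y_test))
print('Scaled accuracy:   %.3f' % scaled.score(X_test, y_test))
```

Wrapping the scaler and model in a pipeline keeps preprocessing and prediction together, so new samples passed to predict() are scaled with the training statistics automatically.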



See Also