
Scikit-Learn LabelSpreading Model

Label Spreading is a semi-supervised learning algorithm that builds a similarity graph over all data points and propagates labels from the labeled points to the unlabeled ones along that graph. It is useful when a dataset has a small number of labeled instances and a large number of unlabeled instances.
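To make the propagation step concrete, here is a simplified NumPy sketch of the label spreading iteration described by Zhou et al., on which scikit-learn's LabelSpreading is based. It builds an RBF affinity matrix W with a zero diagonal, normalizes it as S = D^(-1/2) W D^(-1/2), and repeats F ← αSF + (1 − α)Y until F stops changing. The function name `label_spreading` is hypothetical, and the defaults only mirror scikit-learn's defaults (gamma=20, alpha=0.2, max_iter=30, tol=1e-3). Treat it as an illustration, not as scikit-learn's actual implementation, which differs in details such as sparse kernels and output normalization.

```python
import numpy as np


def label_spreading(X, y, gamma=20.0, alpha=0.2, max_iter=30, tol=1e-3):
    """Hypothetical, simplified label spreading sketch; -1 in y marks unlabeled points."""
    n = X.shape[0]
    classes = np.unique(y[y != -1])

    # RBF affinity matrix W, with self-similarity removed from the diagonal
    sq_dists = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    W = np.exp(-gamma * sq_dists)
    np.fill_diagonal(W, 0.0)

    # symmetric normalization S = D^{-1/2} W D^{-1/2}
    d_inv_sqrt = 1.0 / np.sqrt(W.sum(axis=1))
    S = W * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

    # one-hot initial label matrix Y; rows of unlabeled points stay all zero
    Y = np.zeros((n, len(classes)))
    for j, c in enumerate(classes):
        Y[y == c, j] = 1.0

    # iterate F <- alpha * S F + (1 - alpha) * Y until the change is below tol
    F = Y.copy()
    for _ in range(max_iter):
        F_new = alpha * S @ F + (1 - alpha) * Y
        converged = np.abs(F_new - F).sum() < tol
        F = F_new
        if converged:
            break

    # each point takes the class with the largest spread score
    return classes[np.argmax(F, axis=1)]


# two well-separated clusters, one labeled point in each
X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
              [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]])
y = np.array([0, -1, -1, 1, -1, -1])
print(label_spreading(X, y))  # expected: [0 0 0 1 1 1]
```

Because the unlabeled rows of Y are zero, those points only receive label mass through S, which is why the labels flow along the graph from the labeled points to their neighbors.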

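The key hyperparameters of LabelSpreading include kernel (the similarity function used to build the graph, either 'rbf' or 'knn'), gamma (the RBF kernel coefficient, used only when kernel='rbf'), n_neighbors (the number of neighbors, used only when kernel='knn'), alpha (the clamping factor, a value between 0 and 1 that controls how much each point adopts label information from its neighbors rather than keeping its initial label), and max_iter (the maximum number of iterations).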
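The following sketch shows how the kernel-specific hyperparameters are set: gamma is passed for the 'rbf' kernel and n_neighbors for the 'knn' kernel. It uses the Iris dataset with roughly 70% of the labels hidden. The dataset, the 70% ratio, the random seed, and the specific values (gamma=20 and n_neighbors=7 are scikit-learn's defaults) are all illustrative assumptions and are not tuned. Accuracy is measured on the hidden labels through the fitted model's transduction_ attribute.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.semi_supervised import LabelSpreading

X, y = load_iris(return_X_y=True)

# hide about 70% of the labels by setting them to -1 (ratio chosen for illustration)
rng = np.random.RandomState(0)
unlabeled = rng.rand(len(y)) < 0.7
y_partial = y.copy()
y_partial[unlabeled] = -1

# gamma only affects the 'rbf' kernel; n_neighbors only affects the 'knn' kernel
models = {
    "rbf": LabelSpreading(kernel="rbf", gamma=20, alpha=0.2, max_iter=30),
    "knn": LabelSpreading(kernel="knn", n_neighbors=7, alpha=0.2, max_iter=30),
}

scores = {}
for name, model in models.items():
    model.fit(X, y_partial)
    # transduction_ holds the label assigned to every training point
    scores[name] = np.mean(model.transduction_[unlabeled] == y[unlabeled])
    print(f"{name}: accuracy on hidden labels = {scores[name]:.3f}")
```

Which kernel works better depends on the data. The 'knn' kernel produces a sparse graph and tends to scale better to large datasets, while 'rbf' builds a dense affinity matrix.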

The algorithm is appropriate for classification problems with partially labeled data. In scikit-learn, unlabeled samples are marked by giving them the label -1.
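When starting from a fully labeled dataset, you need to choose which labels to hide. One option is to keep a fixed number of labels per class so that every class still has labeled examples to spread from. The helper `mask_labels` below is hypothetical and written for illustration only. It uses -1 as the unlabeled marker that scikit-learn expects.

```python
import numpy as np


def mask_labels(y, labeled_per_class, seed=0):
    """Hypothetical helper: keep `labeled_per_class` labels per class, set the rest to -1."""
    rng = np.random.default_rng(seed)
    y_partial = np.full_like(y, -1)  # start with every point unlabeled
    for c in np.unique(y):
        idx = np.flatnonzero(y == c)
        # randomly choose which points of this class keep their label
        keep = rng.choice(idx, size=min(labeled_per_class, len(idx)), replace=False)
        y_partial[keep] = c
    return y_partial


y = np.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2])
y_partial = mask_labels(y, labeled_per_class=1)
print(y_partial)  # one labeled point per class, the rest are -1
```

Contiguous masking, such as setting a slice of labels to -1, can leave a class with very few labeled examples when the data are ordered by class. Masking per class avoids that.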

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.semi_supervised import LabelSpreading
from sklearn.metrics import accuracy_score

# generate partially labeled dataset
X, y = make_classification(n_samples=200, n_features=5, n_classes=2, n_informative=2, n_redundant=0, random_state=1)
y[100:160] = -1  # set some labels to -1 to represent unlabeled data

# separate labeled and unlabeled points
X_labeled, y_labeled = X[y != -1], y[y != -1]
X_unlabeled = X[y == -1]

# split labeled data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X_labeled, y_labeled, test_size=0.2, random_state=1)

# combine the labeled training points with the unlabeled points (label -1);
# the test points are left out so their labels are never seen during fitting
X_fit = np.vstack([X_train, X_unlabeled])
y_fit = np.concatenate([y_train, np.full(len(X_unlabeled), -1)])

# create model
model = LabelSpreading(gamma=0.25, max_iter=20)

# fit model on labeled and unlabeled data
model.fit(X_fit, y_fit)

# evaluate model on the held-out labeled test points
yhat = model.predict(X_test)
acc = accuracy_score(y_test, yhat)
print('Accuracy: %.3f' % acc)

# make a prediction on a new, unlabeled data point
new_data_point = [[-1.10325445, -0.49821356, -0.05962247, -0.89224592, -0.70158632]]
yhat = model.predict(new_data_point)
print('Predicted: %d' % yhat[0])

Running the example gives output like the following (exact values may vary):

Accuracy: 0.857
Predicted: 0

The steps are as follows:

  1. First, a synthetic dataset is generated using the make_classification() function. The labels of samples 100 to 159 are then set to -1, the value scikit-learn uses to mark unlabeled data points.

  2. The labeled data is split into training and test sets using train_test_split().

  3. A LabelSpreading model is instantiated with the chosen hyperparameters (gamma=0.25, max_iter=20) and then fit on both labeled and unlabeled data using the fit() method.

  4. The performance of the model is evaluated by comparing the predictions (yhat) on the test set to the actual values (y_test) using the accuracy score metric.

  5. A prediction is made on a new, unlabeled data point by passing it to the predict() method.
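Beyond predict(), a fitted LabelSpreading model exposes the labels it inferred for every training point (transduction_) and per-class probabilities through predict_proba(). The sketch below fits a model with the same settings on all 200 points, with samples 100 to 159 unlabeled. Because the data are synthetic, it keeps a copy of the true labels (`y_true`, an addition for illustration) so that the inferred labels can be checked against the hidden ones.

```python
from sklearn.datasets import make_classification
from sklearn.semi_supervised import LabelSpreading

# same synthetic setup as the example; keep the true labels before hiding some
X, y = make_classification(n_samples=200, n_features=5, n_classes=2, n_informative=2, n_redundant=0, random_state=1)
y_true = y.copy()
y[100:160] = -1

model = LabelSpreading(gamma=0.25, max_iter=20)
model.fit(X, y)

# labels the model inferred for the points that were unlabeled during fitting
inferred = model.transduction_[100:160]
print('Agreement with hidden labels: %.3f' % (inferred == y_true[100:160]).mean())

# class membership probabilities for the new data point from the example
new_data_point = [[-1.10325445, -0.49821356, -0.05962247, -0.89224592, -0.70158632]]
proba = model.predict_proba(new_data_point)
print('Classes:', model.classes_, 'probabilities:', proba[0])
```

The probabilities are often more informative than the hard prediction. A value close to 0.5 indicates that the point lies between the two classes in the similarity graph.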

This example demonstrates how to use the LabelSpreading model for semi-supervised classification, fitting it on both labeled and unlabeled data. Whether the unlabeled points actually improve accuracy depends on the dataset.
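To check whether the unlabeled points help on a given dataset, one option is to compare the same model fit with and without them and evaluate both on the same held-out labeled test points. This sketch reuses the synthetic setup from the example. The "labeled only" baseline is an illustrative choice, not a standard protocol. On this data the semi-supervised model is not guaranteed to score higher, and with only 28 test points small differences are not meaningful.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.semi_supervised import LabelSpreading
from sklearn.metrics import accuracy_score

X, y = make_classification(n_samples=200, n_features=5, n_classes=2, n_informative=2, n_redundant=0, random_state=1)
y[100:160] = -1

# labeled points are split into train/test; unlabeled points are kept aside
X_lab, y_lab = X[y != -1], y[y != -1]
X_unl = X[y == -1]
X_train, X_test, y_train, y_test = train_test_split(X_lab, y_lab, test_size=0.2, random_state=1)

# baseline: fit on the labeled training points only
labeled_only = LabelSpreading(gamma=0.25, max_iter=20).fit(X_train, y_train)

# semi-supervised: labeled training points plus unlabeled points marked -1
X_semi = np.vstack([X_train, X_unl])
y_semi = np.concatenate([y_train, np.full(len(X_unl), -1)])
semi = LabelSpreading(gamma=0.25, max_iter=20).fit(X_semi, y_semi)

acc_lab = accuracy_score(y_test, labeled_only.predict(X_test))
acc_semi = accuracy_score(y_test, semi.predict(X_test))
print('Labeled only:        %.3f' % acc_lab)
print('Labeled + unlabeled: %.3f' % acc_semi)
```

For a more reliable comparison, repeat the split with several random_state values, or use cross-validation over the labeled points, and compare the average accuracies.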



See Also