
Scikit-Learn SelfTrainingClassifier Model

SelfTrainingClassifier is a meta-estimator for semi-supervised learning that allows unlabeled data to improve a supervised classifier. It wraps a base classifier and trains it iteratively: at each iteration, the most confident predictions on the unlabeled data are added to the labeled training set as pseudo-labels.

The key hyperparameters of SelfTrainingClassifier include base_estimator (the wrapped classifier, which must implement predict_proba), threshold (the probability above which a prediction is added to the training set when criterion='threshold'), and criterion (how confident predictions are selected: 'threshold' or 'k_best').
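For example, instead of a fixed probability threshold, criterion can be set to 'k_best' so that the k most confident predictions are pseudo-labeled in each iteration. A minimal sketch (the parameter values here are illustrative, not recommendations):

from sklearn.linear_model import LogisticRegression
from sklearn.semi_supervised import SelfTrainingClassifier

# pseudo-label the 10 most confident predictions per iteration, for at most 20 iterations
clf = SelfTrainingClassifier(LogisticRegression(), criterion='k_best', k_best=10, max_iter=20)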

The algorithm is appropriate for classification problems where labeled data is scarce but unlabeled data is abundant.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.semi_supervised import SelfTrainingClassifier
from sklearn.metrics import accuracy_score

# generate synthetic dataset; the first 250 samples keep their labels,
# the remaining 250 are treated as unlabeled
X, y = make_classification(n_samples=500, n_classes=2, n_informative=2, n_redundant=0, random_state=42)
n_labeled = 250

# split the labeled portion into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X[:n_labeled], y[:n_labeled], test_size=0.2, random_state=42)

# combine the labeled training data with the unlabeled samples (marked with -1),
# keeping the test set out of the training data
X_mixed = np.concatenate([X_train, X[n_labeled:]])
y_mixed = np.concatenate([y_train, np.full(len(X) - n_labeled, -1)])

# create base classifier
base_classifier = LogisticRegression(random_state=42)

# instantiate SelfTrainingClassifier with a confidence threshold of 0.75
self_training_clf = SelfTrainingClassifier(base_classifier, threshold=0.75)

# fit SelfTrainingClassifier on labeled and unlabeled data
self_training_clf.fit(X_mixed, y_mixed)

# evaluate model on the held-out labeled test set
y_pred = self_training_clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.3f}")

# make a prediction on a new, unseen sample
new_sample = [[-1.09891625, 1.6728682, 1.57239821, -0.86397985, -0.20020125, -0.90690287,
               -0.60726141, -0.12653258, -2.35045585, -0.45389142, -0.94895283, 1.53914235,
               -2.09968164, 0.24785287, -0.08854421, 1.2307029, 0.89005773, 1.39922795,
                0.68742676, -0.28212578]]
print(f"Predictions: {self_training_clf.predict(new_sample)}")

Running the example gives output like the following (exact values may vary):

Accuracy: 0.940
Predictions: [0]

The steps are as follows:

  1. First, a synthetic dataset is generated using make_classification(), with the second half of the samples designated as unlabeled by giving them the label -1.

  2. The labeled portion of the data is split into training and test sets using train_test_split().

  3. The labeled training data is concatenated with the unlabeled samples to form the training set, keeping the test set held out.

  4. A base classifier, LogisticRegression, is created.

  5. SelfTrainingClassifier is instantiated, wrapping the base classifier and specifying a confidence threshold of 0.75.

  6. The SelfTrainingClassifier is fit on the combined labeled and unlabeled data using the fit() method.

  7. The performance of the model is evaluated on the held-out labeled test set using the accuracy score metric; a supervised-only baseline for comparison is sketched after this list.

  8. The trained model is used to make a prediction on a new, unseen sample.
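To check that the unlabeled data actually helps, it is worth comparing against the base classifier trained on the labeled training set alone. A quick sketch, reusing the variables from the listing above (exact scores will vary):

# supervised-only baseline: fit the base classifier on the labeled training data only
baseline = LogisticRegression(random_state=42)
baseline.fit(X_train, y_train)
baseline_accuracy = accuracy_score(y_test, baseline.predict(X_test))
print(f"Supervised-only accuracy: {baseline_accuracy:.3f}")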

This example demonstrates how to leverage unlabeled data using SelfTrainingClassifier to improve a supervised learning classifier. By iteratively adding confident predictions to the training set, the model can learn from a larger dataset and potentially improve its performance.
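After fitting, the pseudo-labeling process itself can be inspected. A minimal sketch, reusing self_training_clf and X_train from the example above (labeled_iter_, transduction_, and termination_condition_ are attributes of scikit-learn's SelfTrainingClassifier):

# count unlabeled samples that received a pseudo-label
# (labeled_iter_ is 0 for originally labeled samples, -1 for never-labeled ones)
n_pseudo = np.sum(self_training_clf.labeled_iter_ > 0)
print(f"Pseudo-labeled samples: {n_pseudo}")

# labels used in the final fit, including pseudo-labels for the unlabeled samples
print(f"Pseudo-labels for the first unlabeled samples: {self_training_clf.transduction_[len(X_train):len(X_train) + 5]}")

# why self-training stopped: 'max_iter', 'no_change', or 'all_labeled'
print(f"Termination condition: {self_training_clf.termination_condition_}")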


