
Get LogisticRegression "classes_" Attribute

The LogisticRegression classifier in scikit-learn is a linear model for classification. It handles binary problems and, as in the example below, multiclass problems.
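In the binary case, the classes_ attribute also tells you which label the model's coefficients describe. The sketch below assumes a small synthetic dataset (the sizes and random_state are arbitrary choices). It shows that coef_ has a single row and that a positive decision_function value is mapped to classes_[1].

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Binary problem: make_classification defaults to n_classes=2
X, y = make_classification(n_samples=200, n_features=5, random_state=0)

clf = LogisticRegression(random_state=0).fit(X, y)

print(clf.classes_)     # [0 1]
print(clf.coef_.shape)  # (1, 5): a single row of weights

# The single row of coef_ scores classes_[1]; a positive decision
# value selects classes_[1], otherwise classes_[0] is predicted
scores = clf.decision_function(X)
manual = clf.classes_[(scores > 0).astype(int)]
print(np.array_equal(manual, clf.predict(X)))  # True
```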

The classes_ attribute of a fitted LogisticRegression object is a NumPy array that holds the unique class labels found in the target variable during fitting, in sorted order.
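Labels do not have to be integers. The following minimal sketch uses a tiny hypothetical dataset with string labels listed in unsorted order. It shows that classes_ contains the same sorted unique values that np.unique returns.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Tiny hypothetical dataset with string labels given in unsorted order
X = np.array([[0.0], [0.2], [1.0], [1.2], [2.0], [2.2]])
y = np.array(["dog", "dog", "cat", "cat", "bird", "bird"])

clf = LogisticRegression().fit(X, y)

# classes_ holds the unique labels from y, sorted like np.unique(y)
print(clf.classes_)                                # ['bird' 'cat' 'dog']
print(np.array_equal(clf.classes_, np.unique(y)))  # True
```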

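Accessing the classes_ attribute shows which labels the model works with and in what order. The order matters most in a multiclass setting, because the columns of predict_proba follow classes_, and so do the columns of decision_function and the rows of coef_ when there are more than two classes. Knowing this helps you interpret the model's predictions and check that the target labels were identified correctly.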
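To see how the column order of predict_proba lines up with classes_, here is a short sketch on an assumed 3-class synthetic dataset. The dataset sizes and max_iter=1000 are illustrative choices. Taking the most probable column and indexing classes_ with it gives back the labels returned by predict().

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Assumed 3-class synthetic dataset for illustration
X, y = make_classification(n_samples=300, n_features=8, n_informative=5,
                           n_classes=3, random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X, y)

# Column j of predict_proba is the probability of clf.classes_[j]
proba = clf.predict_proba(X[:5])
print(proba.shape)  # (5, 3)

# Index classes_ with the most probable column to recover predict()
labels_from_proba = clf.classes_[proba.argmax(axis=1)]
print(np.array_equal(labels_from_proba, clf.predict(X[:5])))  # True
```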

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Generate a synthetic 3-class classification dataset
X, y = make_classification(n_samples=1000, n_features=20, n_classes=3, n_informative=10, random_state=42)

# Split the data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

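# Initialize a LogisticRegression classifier; with the lbfgs solver a multiclass
# target is fitted as one multinomial model (multi_class is deprecated in recent releases)
lr = LogisticRegression(solver='lbfgs', random_state=42)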

# Fit the classifier on the training data
lr.fit(X_train, y_train)

# Access the classes_ attribute and print the class labels
class_labels = lr.classes_
print(f"Class labels: {class_labels}")

# Predict on the test set; every predicted label is one of the values in classes_
predictions = lr.predict(X_test)
print(f"Predictions: {predictions[:10]}")

Running the example gives an output like:

Class labels: [0 1 2]
Predictions: [2 1 2 2 1 0 0 1 1 2]
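The order of classes_ also identifies the coefficients. The sketch below rebuilds the same 3-class setup so that it runs on its own (max_iter=1000 is an added assumption to avoid convergence warnings). It pairs each label in classes_ with its row of coef_ and its intercept. The printed feature indices depend on the fitted model, so they are not shown here.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Same 3-class setup as the example above
X, y = make_classification(n_samples=1000, n_features=20, n_classes=3,
                           n_informative=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
lr = LogisticRegression(max_iter=1000, random_state=42).fit(X_train, y_train)

# With 3 classes, coef_ has one row per entry of classes_
print(lr.coef_.shape)  # (3, 20)

# Pair each label with its row of weights and its intercept
for label, weights, bias in zip(lr.classes_, lr.coef_, lr.intercept_):
    top = np.argsort(np.abs(weights))[::-1][:3]  # indices of the 3 largest |weights|
    print(f"class {label}: intercept={bias:.3f}, largest |weights| at features {top}")
```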

The key steps in this example are:

  1. Generate a synthetic dataset with make_classification, creating a multiclass problem with 3 classes.
  2. Split the dataset into training and test sets using train_test_split.
  3. Initialize a LogisticRegression classifier with the lbfgs solver, which fits a multinomial model when the target has more than two classes.
  4. Fit the classifier on the training data.
  5. Access the classes_ attribute to retrieve the unique class labels and store them in the class_labels variable.
  6. Print the class labels to verify the output.
  7. Make predictions on the test set with the fitted model and print the first ten, each of which is one of the labels in classes_.
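As a quick sanity check on steps 4 to 7, the sketch below uses an assumed synthetic dataset. It confirms that classes_ exists only after fitting, that every prediction belongs to classes_, and it counts the predictions per label in classes_ order.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Assumed 3-class synthetic dataset for illustration
X, y = make_classification(n_samples=300, n_features=8, n_informative=5,
                           n_classes=3, random_state=0)

clf = LogisticRegression(max_iter=1000)

# classes_ is set by fit(); an unfitted estimator does not have it
print(hasattr(clf, "classes_"))  # False

clf.fit(X, y)
predictions = clf.predict(X)

# Every prediction is one of the labels stored in classes_
print(np.isin(predictions, clf.classes_).all())  # True

# Count how often each label is predicted, in classes_ order
counts = {label: int((predictions == label).sum()) for label in clf.classes_}
print(counts)
```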


See Also