
Scikit-Learn LinearDiscriminantAnalysis Model

Linear Discriminant Analysis (LDA) is a classification algorithm that finds a linear combination of features that best separates two or more classes. Because it can project data onto at most n_classes - 1 discriminant axes, it is also commonly used for supervised dimensionality reduction that preserves class separability.
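The dimensionality-reduction use can be illustrated with fit_transform(). The sketch below reuses the synthetic settings of the main example. The choice of n_components=2 is ours: three classes allow at most two discriminant axes. explained_variance_ratio_ reports the share of variance explained by each selected component.

```python
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Synthetic 3-class data with 5 features (same settings as the main example)
X, y = make_classification(n_samples=100, n_features=5, n_classes=3,
                           n_informative=3, n_redundant=0, random_state=1)

# With 3 classes, LDA can keep at most 3 - 1 = 2 discriminant axes
lda = LinearDiscriminantAnalysis(n_components=2)
X_proj = lda.fit_transform(X, y)

print(X.shape, '->', X_proj.shape)  # (100, 5) -> (100, 2)
# Share of variance explained by each of the two selected components
print(lda.explained_variance_ratio_)
```

The projected X_proj can be plotted or passed to another model in place of the original five features.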

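The key hyperparameters of LinearDiscriminantAnalysis include solver, the algorithm used to fit the model ('svd' by default, 'lsqr', or 'eigen'), and shrinkage, a regularization of the covariance estimate that can improve stability on high-dimensional datasets (None by default, 'auto' for Ledoit-Wolf shrinkage, or a float between 0 and 1). Shrinkage works only with the 'lsqr' and 'eigen' solvers.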
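The following sketch compares the default configuration with 'lsqr' plus automatic shrinkage. It uses a hypothetical dataset with more features than samples, the setting where shrinkage is meant to help. The dataset sizes and the 5-fold cross-validation are illustrative choices, not recommendations, and which configuration scores higher depends on the data.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Hypothetical high-dimensional data: 200 features but only 60 samples
X, y = make_classification(n_samples=60, n_features=200, n_informative=5,
                           n_redundant=0, random_state=1)

models = [
    # Default: 'svd' solver, no shrinkage (shrinkage is not supported by 'svd')
    ('svd, no shrinkage', LinearDiscriminantAnalysis()),
    # 'lsqr' solver with the shrinkage amount chosen automatically (Ledoit-Wolf)
    ('lsqr, auto shrinkage', LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')),
]

results = {}
for name, model in models:
    scores = cross_val_score(model, X, y, cv=5, scoring='accuracy')
    results[name] = scores.mean()
    print('%s: mean accuracy %.3f' % (name, results[name]))
```

To set the amount of regularization by hand, pass a float such as shrinkage=0.5 instead of 'auto'.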

The algorithm is suitable for both binary and multi-class classification tasks.
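The same estimator handles both cases without extra configuration. What changes is the shape of the scores it returns. The short sketch below reuses the main example's dataset settings and varies only n_classes. Looking at the first four samples is an arbitrary choice.

```python
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

shapes = {}
for n_classes in (2, 3):
    X, y = make_classification(n_samples=100, n_features=5, n_classes=n_classes,
                               n_informative=3, n_redundant=0, random_state=1)
    model = LinearDiscriminantAnalysis().fit(X, y)
    scores = model.decision_function(X[:4])  # binary: 1-D, multi-class: one column per class
    proba = model.predict_proba(X[:4])       # always one column per class
    shapes[n_classes] = (scores.shape, proba.shape)
    print(n_classes, 'classes -> decision_function', scores.shape,
          'predict_proba', proba.shape)
```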

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import accuracy_score

# generate multi-class classification dataset
X, y = make_classification(n_samples=100, n_features=5, n_classes=3, n_informative=3, n_redundant=0, random_state=1)

# split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# create model
model = LinearDiscriminantAnalysis()

# fit model
model.fit(X_train, y_train)

# evaluate model
yhat = model.predict(X_test)
acc = accuracy_score(y_test, yhat)
print('Accuracy: %.3f' % acc)

# make a prediction
row = [[-0.331073, -0.851501, 1.065461, 0.689164, -1.201857]]
yhat = model.predict(row)
print('Predicted: %d' % yhat[0])

Running the example gives an output like:

Accuracy: 0.750
Predicted: 1

The steps are as follows:

  1. First, a synthetic multi-class classification dataset is generated using the make_classification() function. It contains 100 samples (n_samples) and 5 features (n_features), of which 3 are informative (n_informative), spread over 3 classes (n_classes). The dataset is then split into training and test sets with train_test_split(), holding out 20% of the samples for testing.

  2. Next, a LinearDiscriminantAnalysis model is instantiated with default hyperparameters. The model is then fit on the training data using the fit() method.

  3. The model is evaluated on the held-out test set by comparing its predictions (yhat) with the true labels (y_test) using accuracy_score().

  4. Finally, a prediction for a single new sample is made by passing it to the predict() method as a 2D input with one row (a list containing one list of 5 feature values).
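The four steps can be checked in more detail with a few extra calls: the split sizes, the estimator's built-in score() method (which computes the same accuracy as accuracy_score()), and predict_proba() for the single row. This sketch reuses the main example's data. Storing the new sample as a 1-D NumPy array is our assumption, added to show the reshape to one row that predict() and predict_proba() expect.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Step 1: same data and split as the main example
X, y = make_classification(n_samples=100, n_features=5, n_classes=3,
                           n_informative=3, n_redundant=0, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
print(X_train.shape, X_test.shape)  # (80, 5) (20, 5)

# Steps 2-3: fit, then score() returns mean accuracy on the test set
model = LinearDiscriminantAnalysis().fit(X_train, y_train)
print('Accuracy: %.3f' % model.score(X_test, y_test))

# Step 4: a 1-D sample must be reshaped to (1, n_features) before predicting
sample = np.array([-0.331073, -0.851501, 1.065461, 0.689164, -1.201857])
row = sample.reshape(1, -1)
proba = model.predict_proba(row)[0]
print('Class probabilities:', np.round(proba, 3))
# The class with the highest probability matches model.predict(row)
print('Predicted:', model.classes_[np.argmax(proba)])
```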

This example shows how to set up, train, evaluate, and make predictions with a LinearDiscriminantAnalysis model for multi-class classification in scikit-learn, using only a few lines of code.


