SGDClassifier uses stochastic gradient descent for learning linear classifiers such as SVM and logistic regression. It is suitable for binary and multi-class classification problems.
The key hyperparameters of SGDClassifier include loss (the loss function to be used), penalty (the regularization term), and alpha (the constant that multiplies the regularization term).
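As a rough illustration of how these three hyperparameters are passed to the constructor, the sketch below fits a few SGDClassifier configurations on a small synthetic dataset. The specific values are arbitrary choices for demonstration, not tuned recommendations, and random_state is set on the model only so that repeated runs match. The "hinge" loss corresponds to a linear SVM; logistic regression is selected with "log_loss" in recent scikit-learn releases (named "log" before version 1.1).

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import SGDClassifier

# small synthetic binary classification dataset
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# illustrative, untuned settings (assumptions chosen for this sketch)
configs = [
    {"loss": "hinge", "penalty": "l2", "alpha": 0.0001},  # the library defaults: a linear SVM
    {"loss": "modified_huber", "penalty": "l1", "alpha": 0.001},
    {"loss": "hinge", "penalty": "elasticnet", "alpha": 0.01},
]

scores = []
for params in configs:
    # random_state fixes the shuffling of training data between epochs
    model = SGDClassifier(random_state=1, **params)
    model.fit(X_train, y_train)
    score = model.score(X_test, y_test)  # mean accuracy on the test set
    scores.append(score)
    print(params, 'Accuracy: %.3f' % score)
```

Larger alpha values apply stronger regularization, so comparing a few settings like this is a quick way to see how sensitive the model is on a given dataset. The complete example that follows uses the default hyperparameters instead.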
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score
# generate binary classification dataset
X, y = make_classification(n_samples=100, n_features=5, n_classes=2, random_state=1)
# split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
# create model
model = SGDClassifier()
# fit model
model.fit(X_train, y_train)
# evaluate model
yhat = model.predict(X_test)
acc = accuracy_score(y_test, yhat)
print('Accuracy: %.3f' % acc)
# make a prediction
row = [[-1.10325445, -0.49821356, -0.05962247, -0.89224592, -0.70158632]]
yhat = model.predict(row)
print('Predicted: %d' % yhat[0])
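Running the example gives output like the following. Exact values can vary between runs because SGDClassifier shuffles the training data and no random_state is set on the model: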
Accuracy: 0.950
Predicted: 0
The steps are as follows:
1. Generate a synthetic binary classification dataset using make_classification(). This function creates a dataset with a specified number of samples (n_samples), features (n_features), and classes (n_classes), and sets a fixed random seed (random_state) for reproducibility. Split the dataset into training and test sets using train_test_split().
2. Instantiate an SGDClassifier model with default hyperparameters. Fit the model on the training data using the fit() method.
3. Evaluate the model's performance by comparing the predictions (yhat) to the actual test values (y_test) using the accuracy score metric.
4. Make a single prediction by passing a new data sample to the predict() method.
This example demonstrates how to set up and use an SGDClassifier for binary classification tasks, highlighting the ease of implementation and evaluation of this algorithm in scikit-learn.
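Since SGDClassifier is also described above as suitable for multi-class problems, the following minimal sketch adapts the same workflow to three classes. The dataset settings (n_samples=300, n_informative=3, n_redundant=0) are assumptions made for this illustration; make_classification needs more informative features than its default of two to generate three classes. Multi-class problems are handled with a one-versus-all scheme, so the fitted model holds one weight vector per class.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score

# generate a three-class dataset (settings are illustrative assumptions)
X, y = make_classification(n_samples=300, n_features=5, n_informative=3,
                           n_redundant=0, n_classes=3, random_state=1)
# stratify so every class appears in both the train and test sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=1, stratify=y)

# fit with default hyperparameters; random_state only makes runs repeatable
model = SGDClassifier(random_state=1)
model.fit(X_train, y_train)

# evaluate on the held-out data
yhat = model.predict(X_test)
acc = accuracy_score(y_test, yhat)
print('Accuracy: %.3f' % acc)

# one binary classifier per class: coef_ has shape (n_classes, n_features)
print('Classes:', model.classes_)
print('Coefficient shape:', model.coef_.shape)
```

The calling code is unchanged from the binary case; only the dataset and the shape of the learned coefficients differ.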