The warm_start parameter in scikit-learn's SGDClassifier enables incremental learning, allowing the model to continue training from its current state rather than resetting its weights.
Stochastic Gradient Descent (SGD) is an online learning algorithm that updates model parameters based on individual training samples or mini-batches. SGDClassifier implements this approach for classification tasks.
When warm_start is set to True, the model retains its learned coefficients from previous fit calls, allowing incremental training on new data without losing previously learned information. The default value for warm_start is False, which means the model resets its coefficients at each fit call. Note that warm_start only affects repeated calls to fit; partial_fit always continues from the model's current state regardless of this setting.
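As a minimal sketch of this distinction (using a small synthetic dataset, separate from the example below), calling fit twice with warm_start=True resumes optimization from the coefficients learned on the first call:
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier

X, y = make_classification(n_samples=1000, n_features=20, random_state=0)

# warm_start=True: the second fit call starts from the learned coefficients
clf = SGDClassifier(warm_start=True, random_state=0)
clf.fit(X, y)  # first pass: coefficients initialized from scratch
clf.fit(X, y)  # resumes from the coefficients above instead of reinitializing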
warm_start is commonly used with large datasets or streaming data, where training occurs in batches or as new data becomes available. The example below trains two SGDClassifier instances batch by batch, one with warm_start=True and one with warm_start=False, and compares their test accuracy after each batch:
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score
# Generate synthetic dataset
X, y = make_classification(n_samples=10000, n_features=20, n_classes=2, random_state=42)
# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Create two SGDClassifier instances
sgd_warm = SGDClassifier(warm_start=True, random_state=42)
sgd_cold = SGDClassifier(warm_start=False, random_state=42)
# Train incrementally on batches
batch_size = 1000
n_batches = len(X_train) // batch_size
for i in range(n_batches):
    X_batch = X_train[i*batch_size:(i+1)*batch_size]
    y_batch = y_train[i*batch_size:(i+1)*batch_size]
    sgd_warm.partial_fit(X_batch, y_batch, classes=np.unique(y))
    sgd_cold.partial_fit(X_batch, y_batch, classes=np.unique(y))
    # Evaluate both models on the held-out test set after each batch
    warm_acc = accuracy_score(y_test, sgd_warm.predict(X_test))
    cold_acc = accuracy_score(y_test, sgd_cold.predict(X_test))
    print(f"Batch {i+1}: Warm accuracy: {warm_acc:.4f}, Cold accuracy: {cold_acc:.4f}")
Running the example gives an output like:
Batch 1: Warm accuracy: 0.6485, Cold accuracy: 0.6485
Batch 2: Warm accuracy: 0.8450, Cold accuracy: 0.8450
Batch 3: Warm accuracy: 0.8530, Cold accuracy: 0.8530
Batch 4: Warm accuracy: 0.8410, Cold accuracy: 0.8410
Batch 5: Warm accuracy: 0.8280, Cold accuracy: 0.8280
Batch 6: Warm accuracy: 0.7690, Cold accuracy: 0.7690
Batch 7: Warm accuracy: 0.8530, Cold accuracy: 0.8530
Batch 8: Warm accuracy: 0.8210, Cold accuracy: 0.8210
Note that the warm and cold accuracies are identical at every batch. This is expected: the example trains with partial_fit, which always continues from the model's current state, so the warm_start setting has no effect here. The flag changes behavior only when fit is called repeatedly, as in the earlier sketch.
The key steps in this example are:
- Generate a synthetic binary classification dataset
- Split the data into train and test sets
- Create two SGDClassifier instances, one with warm_start=True and one with warm_start=False
- Train both models incrementally on batches of data
- Evaluate and compare the accuracy of both models after each batch
Some tips for using warm_start effectively:
- Use warm_start when dealing with large datasets that don't fit in memory, fitting on successive chunks
- Use partial_fit for true online learning scenarios, keeping in mind that it retains state regardless of warm_start
- Shuffle your data between epochs so the model does not overfit to a fixed sample order (see the sketch after this list)
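As an illustrative sketch of the shuffling tip (the permutation scheme here is an assumption, not part of the original example), each epoch can visit the training data in a fresh random order while partial_fit accumulates updates:
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier

X, y = make_classification(n_samples=5000, n_features=20, random_state=42)
clf = SGDClassifier(random_state=42)
rng = np.random.default_rng(42)

for epoch in range(5):
    # Visit the samples in a new random order each epoch
    order = rng.permutation(len(X))
    for start in range(0, len(X), 500):
        idx = order[start:start + 500]
        clf.partial_fit(X[idx], y[idx], classes=np.unique(y))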
Issues to consider:
- Learning rate may need adjustment for effective incremental learning
- Model performance can be sensitive to the order of training data
- Incremental learning may converge more slowly than batch learning in some cases
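For the learning-rate point, one option (shown as a sketch, with illustrative rather than tuned values) is to switch SGDClassifier from its default 'optimal' schedule to a constant rate, which keeps update sizes stable across incremental calls:
from sklearn.linear_model import SGDClassifier

# 'constant' keeps the step size fixed at eta0 across all incremental updates;
# the default 'optimal' schedule decays the step size as training progresses
clf = SGDClassifier(learning_rate='constant', eta0=0.01, random_state=42)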