
Configure MLPClassifier "warm_start" Parameter

The warm_start parameter in scikit-learn’s MLPClassifier determines whether to reuse the solution of the previous call to fit as initialization for the next fit.

MLPClassifier is a multi-layer perceptron neural network model used for classification tasks. It learns non-linear decision boundaries using backpropagation.
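The "solution" that warm_start carries over is the fitted network's parameters, which scikit-learn stores in the coefs_ attribute (one weight matrix per layer transition) and the intercepts_ attribute (one bias vector per layer transition). A minimal sketch to inspect them, assuming a small synthetic binary problem with 10 features and one hidden layer of 10 units. The sizes are arbitrary and chosen only for illustration.

```python
import warnings

from sklearn.datasets import make_classification
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

# Illustrative synthetic binary problem with 10 features
X, y = make_classification(n_samples=200, n_features=10, random_state=0)

mlp = MLPClassifier(hidden_layer_sizes=(10,), max_iter=50, random_state=0)
with warnings.catch_warnings():
    # A short max_iter may stop before convergence; silence that warning here
    warnings.simplefilter("ignore", category=ConvergenceWarning)
    mlp.fit(X, y)

# Layers: input (10 features) -> hidden (10 units) -> output (1 unit for binary)
weight_shapes = [w.shape for w in mlp.coefs_]
bias_shapes = [b.shape for b in mlp.intercepts_]
print(weight_shapes)  # [(10, 10), (10, 1)]
print(bias_shapes)    # [(10,), (1,)]
```

These arrays are what a warm-started fit uses as its starting point instead of drawing new random values.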

When warm_start is set to True, calling fit again starts from the weights and biases learned in the previous call instead of re-initializing them. This lets you continue training an existing model, for example on new data, without starting from scratch.

The default value for warm_start is False, which means the model re-initializes its weights each time fit is called, discarding anything learned in earlier calls.
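The difference between the two settings can be checked directly. The sketch below fits the same model twice on the same data and compares its first-layer weights with those of a model fitted once. fit_twice is a hypothetical helper written only for this illustration, and max_iter=20 is an arbitrary small value. With an integer random_state, a cold restart reproduces the single-fit weights exactly, while a warm start keeps training from them.

```python
import warnings

import numpy as np
from sklearn.datasets import make_classification
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

X, y = make_classification(n_samples=300, n_features=10, random_state=0)


def fit_twice(warm_start):
    """Hypothetical helper: fit one model twice and return its first-layer weights."""
    mlp = MLPClassifier(hidden_layer_sizes=(10,), max_iter=20,
                        random_state=42, warm_start=warm_start)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=ConvergenceWarning)
        mlp.fit(X, y)
        mlp.fit(X, y)  # second call: re-initialize (False) or continue (True)
    return mlp.coefs_[0]


# Reference: a single fit from a fresh initialization
reference = MLPClassifier(hidden_layer_sizes=(10,), max_iter=20, random_state=42)
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=ConvergenceWarning)
    reference.fit(X, y)

cold = fit_twice(warm_start=False)
warm = fit_twice(warm_start=True)

# warm_start=False: the second fit starts over, so it matches a single fit
print(np.allclose(cold, reference.coefs_[0]))  # True
# warm_start=True: the second fit continues from the first, so weights differ
print(np.allclose(warm, reference.coefs_[0]))  # False
```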

In practice, warm_start=True is used when you want to train the model incrementally or fine-tune an existing model with new data.
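A minimal sketch of two ways to train on data that arrives in chunks. The chunks here are assumed to be 200-row slices of one synthetic dataset, and max_iter=20 is arbitrary. With warm_start=True, each fit call runs up to max_iter epochs on the current chunk. MLPClassifier also provides partial_fit, which performs a single pass over the given data and needs the full set of classes on its first call.

```python
import warnings

import numpy as np
from sklearn.datasets import make_classification
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
                           n_redundant=0, random_state=42)
classes = np.unique(y)

# Illustrative chunks standing in for batches that arrive over time
chunks = [(X[i:i + 200], y[i:i + 200]) for i in range(0, len(X), 200)]

# Style 1: warm_start=True, each fit() runs up to max_iter epochs on one chunk
warm = MLPClassifier(hidden_layer_sizes=(10,), max_iter=20,
                     warm_start=True, random_state=42)
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=ConvergenceWarning)
    for X_chunk, y_chunk in chunks:
        warm.fit(X_chunk, y_chunk)

# Style 2: partial_fit, a single pass over each chunk.
# classes is required on the first call; passing it on every call is also accepted.
online = MLPClassifier(hidden_layer_sizes=(10,), random_state=42)
for X_chunk, y_chunk in chunks:
    online.partial_fit(X_chunk, y_chunk, classes=classes)

# Scored on the training data only, to show both models are usable
print(f"warm_start accuracy: {warm.score(X, y):.3f}")
print(f"partial_fit accuracy: {online.score(X, y):.3f}")
```

The choice mostly comes down to how much each chunk should be trained on. The warm_start approach can fit a chunk closely, while partial_fit spreads training more evenly across chunks.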

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score

# Generate synthetic dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
                           n_redundant=0, random_state=42)

# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Initial training on the first 500 training samples (warm_start defaults to False)
mlp = MLPClassifier(hidden_layer_sizes=(10,), max_iter=100, random_state=42)
mlp.fit(X_train[:500], y_train[:500])
y_pred = mlp.predict(X_test)
print(f"Accuracy without warm_start: {accuracy_score(y_test, y_pred):.3f}")

# Continue from the learned weights, training on the remaining 300 training samples
mlp.set_params(warm_start=True)
mlp.fit(X_train[500:], y_train[500:])
y_pred = mlp.predict(X_test)
print(f"Accuracy with warm_start: {accuracy_score(y_test, y_pred):.3f}")

# For comparison, train a fresh model on the full training set (800 samples)
mlp_full = MLPClassifier(hidden_layer_sizes=(10,), max_iter=100, random_state=42)
mlp_full.fit(X_train, y_train)
y_pred = mlp_full.predict(X_test)
print(f"Accuracy training on full dataset: {accuracy_score(y_test, y_pred):.3f}")

Running the example gives an output like:

Accuracy without warm_start: 0.785
Accuracy with warm_start: 0.785
Accuracy training on full dataset: 0.805

The key steps in this example are:

  1. Generate a synthetic classification dataset
  2. Split the data into train and test sets
  3. Train an initial MLPClassifier on a subset of the training data
  4. Continue training the model with warm_start=True on the remaining data
  5. Compare the performance with a model trained from scratch on the full training set
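Because the warm-started fit in step 4 sees only the remaining 300 samples, it can be worth checking whether it drifts away from the first 500. A minimal sketch that repeats the same split and scores the model on each chunk and on the test set, both before and after the warm-started fit. Whether accuracy on the first chunk drops depends on the data, so no particular result is implied.

```python
import copy
import warnings

from sklearn.datasets import make_classification
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
                           n_redundant=0, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_a, y_a = X_train[:500], y_train[:500]   # first chunk
X_b, y_b = X_train[500:], y_train[500:]   # remaining 300 samples

with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=ConvergenceWarning)
    mlp = MLPClassifier(hidden_layer_sizes=(10,), max_iter=100, random_state=42)
    mlp.fit(X_a, y_a)
    before = copy.deepcopy(mlp)  # snapshot of the fitted model after chunk A

    mlp.set_params(warm_start=True)
    mlp.fit(X_b, y_b)  # continue training on chunk B only

for name, model in [("before", before), ("after", mlp)]:
    print(f"{name}: chunk A {model.score(X_a, y_a):.3f}, "
          f"chunk B {model.score(X_b, y_b):.3f}, "
          f"test {model.score(X_test, y_test):.3f}")
```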

Some tips and heuristics for using warm_start:
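One pattern that warm_start makes possible is to call fit repeatedly with a small max_iter, score a held-out set between rounds, and keep a copy of the best model. The sketch below assumes 20 rounds of at most 10 epochs each and an 80/20 train/validation split. These values are arbitrary. When the default accuracy-based criterion and internal validation split are enough, MLPClassifier's built-in early_stopping=True is the simpler option.

```python
import copy
import warnings

from sklearn.datasets import make_classification
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
                           n_redundant=0, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# Each fit() call runs at most 10 epochs, continuing from the previous weights
mlp = MLPClassifier(hidden_layer_sizes=(10,), max_iter=10,
                    warm_start=True, random_state=42)

best_score, best_model, history = -1.0, None, []
with warnings.catch_warnings():
    # Hitting max_iter every round is expected here, so silence the warning
    warnings.simplefilter("ignore", category=ConvergenceWarning)
    for _ in range(20):
        mlp.fit(X_train, y_train)
        score = mlp.score(X_val, y_val)
        history.append(score)
        if score > best_score:
            best_score, best_model = score, copy.deepcopy(mlp)

print(f"best validation accuracy {best_score:.3f} "
      f"after round {history.index(best_score) + 1}")
```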

Issues to consider:
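One issue that is easy to hit: in current scikit-learn releases, a warm-started MLPClassifier raises a ValueError if the new y does not contain the same set of classes as the previous fit. The sketch below triggers it with a contrived second chunk that contains only class 0.

```python
import warnings

from sklearn.datasets import make_classification
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

X, y = make_classification(n_samples=300, n_features=10, random_state=0)

mlp = MLPClassifier(hidden_layer_sizes=(10,), max_iter=20,
                    warm_start=True, random_state=0)
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=ConvergenceWarning)
    mlp.fit(X, y)  # first fit sees both classes, 0 and 1

# Contrived later chunk that happens to contain only class 0
mask = y == 0
try:
    mlp.fit(X[mask], y[mask])
    error = None
except ValueError as exc:
    error = exc
print(type(error).__name__)  # ValueError
```

If later chunks may be missing classes, partial_fit with an explicit classes argument on the first call avoids this check, because it only rejects labels outside the known set.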


