
Configure MLPClassifier "nesterovs_momentum" Parameter

The nesterovs_momentum parameter in scikit-learn’s MLPClassifier controls whether to use Nesterov’s momentum during training.

Neural networks can be challenging to train because their loss landscapes are complex. Momentum methods accumulate a running velocity from past gradients, which helps accelerate training, damp oscillations, and move through flat regions and shallow local minima. Nesterov’s momentum is a variant that applies a look-ahead correction to the update and often converges faster than classical momentum.

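The nesterovs_momentum parameter determines whether to use Nesterov’s momentum (True) or classical momentum (False) when the weights are updated during training. It only takes effect when solver='sgd' and momentum is greater than 0; other solvers, including the default 'adam', ignore it.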
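The difference between the two settings can be illustrated with a small NumPy sketch. The function names `momentum_step` and `minimize` and the toy quadratic loss are illustrative assumptions, and the update rule follows one common formulation of Nesterov's momentum; it is not scikit-learn's internal code.

```python
import numpy as np

def momentum_step(w, v, grad, lr=0.1, mu=0.9, nesterov=True):
    """Apply one momentum update and return the new (weights, velocity)."""
    # Both variants first refresh the velocity with the current gradient.
    v_new = mu * v - lr * grad
    if nesterov:
        # Nesterov: apply the momentum term once more to the refreshed velocity.
        step = mu * v_new - lr * grad
    else:
        # Classical: step directly along the refreshed velocity.
        step = v_new
    return w + step, v_new

def minimize(nesterov, n_steps=200):
    """Minimize the toy loss f(w) = 0.5 * w**2, whose gradient is w."""
    w, v = np.array([1.0]), np.array([0.0])
    for _ in range(n_steps):
        w, v = momentum_step(w, v, grad=w, nesterov=nesterov)
    return float(abs(w[0]))

print("classical:", minimize(nesterov=False))
print("nesterov: ", minimize(nesterov=True))
```

Both variants drive w toward zero on this toy problem; they differ only in how the step is derived from the velocity.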

By default, nesterovs_momentum is set to True. The alternative is to set it to False, which uses classical momentum instead.
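As a quick check of the defaults, the sketch below reads them from `get_params()` and then compares loss curves under two solvers. The helper `fit_loss_curve`, the small dataset, and the network size are illustrative assumptions chosen to keep the run short.

```python
import warnings

import numpy as np
from sklearn.datasets import make_classification
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

# Inspect the relevant defaults of an unconfigured estimator.
defaults = MLPClassifier().get_params()
print(defaults["nesterovs_momentum"], defaults["solver"], defaults["momentum"])

# Small hypothetical dataset, used only for this check.
X_demo, y_demo = make_classification(n_samples=200, n_features=10, random_state=0)

def fit_loss_curve(solver, nesterov):
    """Fit a small MLP and return its per-epoch training loss."""
    model = MLPClassifier(hidden_layer_sizes=(10,), solver=solver,
                          nesterovs_momentum=nesterov, max_iter=20,
                          random_state=0)
    with warnings.catch_warnings():
        # The short run may stop before convergence; that is fine here.
        warnings.simplefilter("ignore", category=ConvergenceWarning)
        model.fit(X_demo, y_demo)
    return np.array(model.loss_curve_)

# With the 'adam' solver the flag is ignored, so the curves coincide.
adam_same = np.array_equal(fit_loss_curve("adam", True),
                           fit_loss_curve("adam", False))
# With the 'sgd' solver the flag changes the weight updates.
sgd_same = np.array_equal(fit_loss_curve("sgd", True),
                          fit_loss_curve("sgd", False))
print("adam curves identical:", adam_same)
print("sgd curves identical:", sgd_same)
```

If the two 'adam' curves are identical while the two 'sgd' curves are not, that confirms the flag is only consulted by the 'sgd' solver.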

from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# Generate synthetic dataset
X, y = make_classification(n_samples=1000, n_features=20, n_classes=2, random_state=42)

# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train with Nesterov's momentum
# (nesterovs_momentum only takes effect when solver='sgd' and momentum > 0)
mlp_nesterov = MLPClassifier(hidden_layer_sizes=(100,), solver='sgd', max_iter=500,
                             random_state=42, nesterovs_momentum=True)
mlp_nesterov.fit(X_train, y_train)

# Train with classical momentum (Nesterov's momentum disabled)
mlp_classic = MLPClassifier(hidden_layer_sizes=(100,), solver='sgd', max_iter=500,
                            random_state=42, nesterovs_momentum=False)
mlp_classic.fit(X_train, y_train)

# Evaluate models
y_pred_nesterov = mlp_nesterov.predict(X_test)
y_pred_classic = mlp_classic.predict(X_test)

print(f"Accuracy with Nesterov's momentum: {accuracy_score(y_test, y_pred_nesterov):.3f}")
print(f"Accuracy with classical momentum: {accuracy_score(y_test, y_pred_classic):.3f}")

# Plot learning curves
plt.plot(mlp_nesterov.loss_curve_, label="Nesterov's momentum")
plt.plot(mlp_classic.loss_curve_, label="Classical momentum")
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.show()

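Running the example prints the test accuracy of each model and displays the two loss curves. The output looks like the following, though exact values depend on the solver settings and the scikit-learn version: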

Accuracy with Nesterov's momentum: 0.820
Accuracy with classical momentum: 0.820


The key steps in this example are:

  1. Generate a synthetic binary classification dataset
  2. Split the data into train and test sets
  3. Train two MLPClassifier models, one with Nesterov’s momentum and one without
  4. Evaluate the accuracy of each model on the test set
  5. Plot learning curves to visualize the training process
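A single run with one random seed says little about which variant converges faster, so the steps above can be repeated over several seeds. The sketch below is a minimal version of that idea; the helper `summarize`, the seed values, and the smaller dataset and network are assumptions made to keep it quick to run.

```python
import warnings

from sklearn.datasets import make_classification
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

# Smaller hypothetical dataset so the repeated fits stay fast.
X_s, y_s = make_classification(n_samples=500, n_features=20, n_classes=2,
                               random_state=42)
Xs_train, Xs_test, ys_train, ys_test = train_test_split(
    X_s, y_s, test_size=0.2, random_state=42)

def summarize(nesterov, seeds=(0, 1, 2)):
    """Return (seed, epochs run, final training loss, test accuracy) per seed."""
    rows = []
    for seed in seeds:
        model = MLPClassifier(hidden_layer_sizes=(50,), solver="sgd",
                              momentum=0.9, nesterovs_momentum=nesterov,
                              max_iter=200, random_state=seed)
        with warnings.catch_warnings():
            # Some runs may hit max_iter before converging.
            warnings.simplefilter("ignore", category=ConvergenceWarning)
            model.fit(Xs_train, ys_train)
        rows.append((seed, model.n_iter_, model.loss_,
                     model.score(Xs_test, ys_test)))
    return rows

results = {"Nesterov": summarize(True), "Classical": summarize(False)}
for label, rows in results.items():
    for seed, n_iter, loss, acc in rows:
        print(f"{label:9s} seed={seed} epochs={n_iter} loss={loss:.4f} acc={acc:.3f}")
```

Comparing the number of epochs and the final loss across seeds gives a more reliable picture than a single accuracy figure.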

Tips for using Nesterov’s momentum:

Issues to consider:


