
Configure ExtraTreesRegressor "warm_start" Parameter

The warm_start parameter in scikit-learn’s ExtraTreesRegressor allows for incremental learning, enabling the addition of more trees to an existing forest without retraining from scratch.

Extra Trees Regressor is an ensemble method that builds a forest of unpruned decision trees. It differs from Random Forest in its splitting strategy, using random thresholds for each feature rather than searching for the best possible thresholds.
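
For intuition, here is a brief sketch (separate from the worked example below; the dataset, tree count, and scoring choices are illustrative assumptions) that fits both regressors on the same synthetic data so the two splitting strategies can be compared:

from sklearn.datasets import make_regression
from sklearn.ensemble import ExtraTreesRegressor, RandomForestRegressor
from sklearn.model_selection import cross_val_score

# Small synthetic dataset used only for this comparison
X, y = make_regression(n_samples=500, n_features=10, noise=0.1, random_state=0)

# RandomForestRegressor searches for the best threshold at each split;
# ExtraTreesRegressor draws candidate thresholds at random, which is typically faster
for Model in (RandomForestRegressor, ExtraTreesRegressor):
    scores = cross_val_score(Model(n_estimators=50, random_state=0), X, y,
                             scoring="neg_mean_squared_error", cv=3)
    print(f"{Model.__name__}: mean MSE = {-scores.mean():.3f}")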

The warm_start parameter, when set to True, reuses the solution of the previous call to fit and adds more estimators to the ensemble. This can significantly reduce training time when working with large datasets or when fine-tuning the number of trees.

By default, warm_start is set to False. It’s commonly used in scenarios where you want to incrementally grow the forest or when performing early stopping based on out-of-bag error.
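
The example below compares training time; as a separate sketch of the out-of-bag early-stopping pattern mentioned above (the step size and tolerance are arbitrary assumptions), the forest can be grown until its OOB score stops improving. Note that oob_score requires bootstrap=True, which ExtraTreesRegressor does not enable by default:

from sklearn.datasets import make_regression
from sklearn.ensemble import ExtraTreesRegressor

X, y = make_regression(n_samples=1000, n_features=10, noise=0.1, random_state=42)

# Grow the forest 25 trees at a time; stop once the out-of-bag R^2 stops improving
model = ExtraTreesRegressor(n_estimators=25, warm_start=True,
                            bootstrap=True, oob_score=True, random_state=42)
best_oob = -float("inf")
while model.n_estimators <= 300:
    model.fit(X, y)
    if model.oob_score_ <= best_oob + 1e-4:
        break
    best_oob = model.oob_score_
    model.n_estimators += 25

print(f"Stopped at {len(model.estimators_)} trees, OOB R^2: {model.oob_score_:.3f}")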

from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.metrics import mean_squared_error
import time

# Generate synthetic dataset
X, y = make_regression(n_samples=1000, n_features=10, noise=0.1, random_state=42)

# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train without warm_start
start_time = time.time()
etr_cold = ExtraTreesRegressor(n_estimators=100, random_state=42)
etr_cold.fit(X_train, y_train)
cold_time = time.time() - start_time
cold_mse = mean_squared_error(y_test, etr_cold.predict(X_test))

# Train with warm_start
start_time = time.time()
etr_warm = ExtraTreesRegressor(n_estimators=10, warm_start=True, random_state=42)
warm_mse_list = []

# Incrementally grow the forest: each fit call trains only the newly added trees
for _ in range(10):
    etr_warm.fit(X_train, y_train)
    warm_mse_list.append(mean_squared_error(y_test, etr_warm.predict(X_test)))
    etr_warm.n_estimators += 10  # request 10 more trees for the next fit call

warm_time = time.time() - start_time

print(f"Cold start time: {cold_time:.3f}s, MSE: {cold_mse:.3f}")
print(f"Warm start time: {warm_time:.3f}s, Final MSE: {warm_mse_list[-1]:.3f}")

Running the example gives an output like:

Cold start time: 0.220s, MSE: 2036.183
Warm start time: 0.273s, Final MSE: 2036.183

The final MSE is identical for both approaches because each model ends up with 100 trees grown from the same random_state; the warm-start run takes slightly longer here because it also evaluates the model on the test set after every 10-tree increment.

The key steps in this example are:

  1. Generate a synthetic regression dataset
  2. Split the data into train and test sets
  3. Train an ExtraTreesRegressor without warm_start (cold start)
  4. Train an ExtraTreesRegressor with warm_start, incrementally adding trees
  5. Compare training times and model performance for both approaches

Some tips for using warm_start:

  1. Set warm_start=True when constructing the estimator, then increase n_estimators (directly or via set_params) before each additional call to fit; only the new trees are trained and the existing ones are kept.
  2. Call fit with the same training data each time: warm_start refits the added trees on whatever data you pass, so it is not a substitute for out-of-core or streaming learning.
  3. To grow the forest until performance plateaus, combine warm_start with bootstrap=True and oob_score=True and monitor the out-of-bag score after each increment.

Issues to consider:

  1. Calling fit again without increasing n_estimators does not add any trees, and requesting fewer trees than have already been fitted raises an error (see the sketch below).
  2. Other hyperparameters such as max_depth or max_features should not be changed between warm-started calls, since trees built earlier will not reflect the new settings.
  3. Memory usage and prediction time grow with every batch of trees added to the ensemble.
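
As a minimal sketch of the first issue (the dataset size and tree counts here are arbitrary), the forest can be grown by raising n_estimators between calls to fit, but it cannot be shrunk once trees have been built:

from sklearn.datasets import make_regression
from sklearn.ensemble import ExtraTreesRegressor

X, y = make_regression(n_samples=200, n_features=5, random_state=0)

model = ExtraTreesRegressor(n_estimators=20, warm_start=True, random_state=0)
model.fit(X, y)
print(len(model.estimators_))  # 20 trees

# Growing the forest: only the 30 new trees are trained on this call
model.set_params(n_estimators=50)
model.fit(X, y)
print(len(model.estimators_))  # 50 trees

# Shrinking is not allowed once trees have been fitted with warm_start=True
model.set_params(n_estimators=10)
try:
    model.fit(X, y)
except ValueError as err:
    print(err)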



See Also