
Scikit-Learn MultiTaskElasticNet Model

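MultiTaskElasticNet is a linear model for multi-output regression that fits all target variables jointly rather than as separate problems. Its regularizer combines two terms: a mixed L1/L2 norm on the coefficient matrix (for each feature, the L2 norm of its coefficients across all targets, summed over features) and a squared L2 (Frobenius) norm. The mixed-norm term tends to keep or drop each feature for all targets at once, while the squared L2 term adds ridge-style shrinkage, so the model balances the two penalties.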
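The joint feature selection can be checked directly on the fitted coefficients. The sketch below assumes a synthetic dataset with 10 features of which only 3 are informative, and arbitrary illustrative settings alpha=1.0 and l1_ratio=0.9. Because coef_ has shape (n_targets, n_features), each column holds one feature's coefficients across all targets, and under the mixed-norm penalty each column should be either entirely zero or entirely non-zero.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import MultiTaskElasticNet

# Assumed setup for illustration: 10 features, only 3 of which drive the targets
X, Y = make_regression(n_samples=100, n_features=10, n_informative=3,
                       n_targets=3, noise=0.1, random_state=1)

model = MultiTaskElasticNet(alpha=1.0, l1_ratio=0.9).fit(X, Y)
coef = model.coef_  # shape (n_targets, n_features)

# For each feature (column), check whether it is zero or non-zero for every target
zero_for_all = (coef == 0).all(axis=0)
nonzero_for_all = (coef != 0).all(axis=0)
print("features dropped for all targets:", np.flatnonzero(zero_for_all))
print("features kept for all targets:   ", np.flatnonzero(nonzero_for_all))
```

Fitting a separate ElasticNet per target would not enforce this shared pattern; each target could select a different subset of features.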

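The key hyperparameters of MultiTaskElasticNet are alpha, a constant that multiplies the penalty terms (default 1.0), and l1_ratio, the ElasticNet mixing parameter, which ranges from 0 to 1 (default 0.5). With l1_ratio=1 the penalty is purely the L1/L2 mixed norm (equivalent to MultiTaskLasso); with l1_ratio=0 it is purely the squared L2 penalty.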
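To make the roles of alpha and l1_ratio concrete, the sketch below evaluates the objective given in the scikit-learn documentation, (1 / (2 * n_samples)) * ||Y - XW||_Fro^2 + alpha * l1_ratio * ||W||_21 + 0.5 * alpha * (1 - l1_ratio) * ||W||_Fro^2, where ||W||_21 is the sum of the L2 norms of the rows of W (one row per feature). The helper name mtenet_objective is hypothetical, fit_intercept=False is assumed so the formula applies to the raw data, and the shift of 5.0 is an arbitrary perturbation used only as a sanity check that the fitted coefficients score lower.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import MultiTaskElasticNet


def mtenet_objective(X, Y, W, alpha, l1_ratio):
    """Hypothetical helper: documented objective, W of shape (n_features, n_targets)."""
    n_samples = X.shape[0]
    residual = Y - X @ W
    data_term = 0.5 / n_samples * np.sum(residual ** 2)
    l21 = np.sum(np.sqrt(np.sum(W ** 2, axis=1)))  # sum of per-feature row norms
    fro2 = np.sum(W ** 2)                          # squared Frobenius norm
    return data_term + alpha * l1_ratio * l21 + 0.5 * alpha * (1 - l1_ratio) * fro2


X, Y = make_regression(n_samples=100, n_features=5, n_targets=3, noise=0.1, random_state=1)
alpha, l1_ratio = 1.0, 0.5

# fit_intercept=False (assumed) so that no centering is applied before the objective
model = MultiTaskElasticNet(alpha=alpha, l1_ratio=l1_ratio, fit_intercept=False).fit(X, Y)
W = model.coef_.T  # coef_ is (n_targets, n_features); transpose to (n_features, n_targets)

fitted = mtenet_objective(X, Y, W, alpha, l1_ratio)
perturbed = mtenet_objective(X, Y, W + 5.0, alpha, l1_ratio)
print("objective at fitted W:  %.3f" % fitted)
print("objective at W + 5.0:   %.3f" % perturbed)
```

Raising alpha scales both penalty terms up; raising l1_ratio moves weight from the squared L2 term to the mixed-norm term, which drops more features across all targets.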

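This algorithm is appropriate for multi-output regression problems in which the targets are expected to depend on a shared subset of the input features. It requires a 2-D target array of shape (n_samples, n_targets); for a single target, ElasticNet is the corresponding single-output model.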
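The 2-D target requirement can be seen with a single-target dataset. In this sketch, make_regression with n_targets=1 returns a 1-D y, which MultiTaskElasticNet rejects with a ValueError; reshaping y into a single column is accepted. The exact wording of the error message depends on the scikit-learn version, so it is printed rather than assumed.

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import MultiTaskElasticNet

X, y = make_regression(n_samples=100, n_features=5, n_targets=1, noise=0.1, random_state=1)
print(y.shape)  # (100,)

model = MultiTaskElasticNet(alpha=1.0, l1_ratio=0.5)
try:
    model.fit(X, y)  # 1-D target is rejected
    rejected = False
except ValueError as err:
    rejected = True
    print("ValueError:", err)

# A single-column 2-D target is accepted
model.fit(X, y.reshape(-1, 1))
print(model.coef_.shape)  # (1, 5)
```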

from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.linear_model import MultiTaskElasticNet
from sklearn.metrics import mean_squared_error

# generate multi-output regression dataset
X, y = make_regression(n_samples=100, n_features=5, n_targets=3, noise=0.1, random_state=1)

# split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# create model
model = MultiTaskElasticNet(alpha=1.0, l1_ratio=0.5)

# fit model
model.fit(X_train, y_train)

# evaluate model
yhat = model.predict(X_test)
mse = mean_squared_error(y_test, yhat)
print('Mean Squared Error: %.3f' % mse)

# make a prediction
row = [[-0.1382643, 0.64768854, 1.52302986, -0.23415337, -0.23413696]]
yhat = model.predict(row)
print('Predicted: %s' % yhat)

Running the example gives an output like:

Mean Squared Error: 1201.401
Predicted: [[16.76234334 14.75708437 34.49206694]]
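The reported Mean Squared Error is a single number averaged over the three targets. As a sketch, the same evaluation can be broken down per target with the multioutput='raw_values' option of mean_squared_error; the uniform average of those values equals the single score printed above.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.linear_model import MultiTaskElasticNet
from sklearn.metrics import mean_squared_error

# Same data, split, and model settings as the example above
X, y = make_regression(n_samples=100, n_features=5, n_targets=3, noise=0.1, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
model = MultiTaskElasticNet(alpha=1.0, l1_ratio=0.5).fit(X_train, y_train)
yhat = model.predict(X_test)

# One MSE value per target, and the default uniform average across targets
per_target = mean_squared_error(y_test, yhat, multioutput='raw_values')
overall = mean_squared_error(y_test, yhat)
print('Per-target MSE:', per_target)
print('Average MSE: %.3f' % overall)
```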

The steps are as follows:

  1. Generate a synthetic multi-output regression dataset using the make_regression() function with a specified number of samples (n_samples), features (n_features), and target variables (n_targets), fixing the random seed (random_state) for reproducibility. Split the dataset into training and test sets using train_test_split().

  2. Instantiate the MultiTaskElasticNet model with hyperparameters alpha and l1_ratio. The model is fit on the training data using the fit() method.

  3. Evaluate the model by predicting on the test set and comparing the predictions (yhat) to the actual values (y_test) using the mean squared error, which by default is averaged uniformly over the three targets.

  4. Make a prediction for a new data sample using the predict() method. The input is a 2-D array containing one row of five feature values, and the output contains one predicted value for each of the three targets.
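In step 2, alpha and l1_ratio are fixed by hand. One way to choose them is cross-validation with MultiTaskElasticNetCV, which searches a path of alpha values for each candidate l1_ratio. The sketch below assumes an illustrative candidate list and 5-fold cross-validation on the same training split; the selected values are stored in the alpha_ and l1_ratio_ attributes.

```python
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.linear_model import MultiTaskElasticNetCV
from sklearn.metrics import mean_squared_error

X, y = make_regression(n_samples=100, n_features=5, n_targets=3, noise=0.1, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# Illustrative candidates; alpha values are generated automatically for each l1_ratio
candidate_ratios = [0.1, 0.5, 0.9, 1.0]
model = MultiTaskElasticNetCV(l1_ratio=candidate_ratios, cv=5)
model.fit(X_train, y_train)
print('Chosen alpha: %.4f' % model.alpha_)
print('Chosen l1_ratio: %.2f' % model.l1_ratio_)

# Evaluate the tuned model on the held-out test set
yhat = model.predict(X_test)
mse = mean_squared_error(y_test, yhat)
print('Mean Squared Error: %.3f' % mse)
```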



See Also