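MultiTaskElasticNet is a linear model for multi-output regression, trained with a mixed L1/L2 norm plus a squared L2 norm on the coefficients as regularizer. It fits all regression targets jointly. The L1/L2 term penalizes each feature's coefficients across all targets as one group, so a feature is either kept for every target or dropped for every target. The L2 term adds ridge-style shrinkage on top, and the balance between the two penalties is adjustable.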
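For reference, the scikit-learn documentation states the training objective as follows. Here $\rho$ stands for l1_ratio, $n$ for the number of samples, $Y$ for the $n \times T$ target matrix and $W$ for the $p \times T$ coefficient matrix. The symbol names $\rho$, $T$ and $p$ are our own shorthand.

$$
\min_{W}\; \frac{1}{2n}\,\lVert Y - XW \rVert_{\mathrm{Fro}}^{2}
\;+\; \alpha\,\rho\,\lVert W \rVert_{2,1}
\;+\; \frac{\alpha\,(1-\rho)}{2}\,\lVert W \rVert_{\mathrm{Fro}}^{2},
\qquad
\lVert W \rVert_{2,1} = \sum_{i=1}^{p} \sqrt{\sum_{j=1}^{T} W_{ij}^{2}}
$$

The $\lVert W \rVert_{2,1}$ term sums the Euclidean norm of each row of $W$ (one row per feature), and this is what pushes whole rows to zero. The fitted coef_ attribute stores $W^\top$, so its shape is (n_targets, n_features).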
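The key hyperparameters of MultiTaskElasticNet are alpha, which multiplies both penalty terms, and l1_ratio, the ElasticNet mixing parameter between 0 and 1. Values of l1_ratio near 1 favor the sparsity-inducing L1/L2 term, and values near 0 favor the L2 term. The defaults are alpha=1.0 and l1_ratio=0.5.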
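A quick way to see what alpha does is to sweep it on data where only some features matter. For each value, the sketch below counts how many features keep non-zero coefficients. The data set (10 features, 3 informative), the alpha grid and max_iter=10_000 are illustrative choices, not recommendations. Because the L1/L2 term acts on each feature's whole group of coefficients, every feature should come out either zero for all targets or non-zero for all targets.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import MultiTaskElasticNet

# Synthetic data: 10 features, only 3 of which drive the 3 targets
X, y = make_regression(n_samples=200, n_features=10, n_informative=3,
                       n_targets=3, noise=0.1, random_state=1)

results = {}
for alpha in [0.1, 1.0, 10.0, 1e6]:
    model = MultiTaskElasticNet(alpha=alpha, l1_ratio=0.5, max_iter=10_000)
    model.fit(X, y)
    # coef_ has shape (n_targets, n_features); a feature counts as selected
    # if any of its per-target coefficients is non-zero
    selected = np.any(model.coef_ != 0, axis=0)
    results[alpha] = model.coef_
    print(f"alpha={alpha:g}: {selected.sum()} of {X.shape[1]} features selected")
```

The count typically falls as alpha grows. At a very large alpha such as 1e6, every coefficient is driven to exactly zero. The exact counts depend on the data, so they are not listed here.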
The target array must be two-dimensional, with shape (n_samples, n_targets). For a single target, scikit-learn's ElasticNet is the corresponding single-output model.
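The sketch below illustrates the shape requirement. Fitting MultiTaskElasticNet on a 1-D target raises a ValueError, while the same data reshaped to a single column is accepted. ElasticNet takes the 1-D target directly. With one target, the L1/L2 norm reduces to the ordinary L1 norm, so the two models optimize the same objective. The sample sizes are arbitrary.

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import ElasticNet, MultiTaskElasticNet

# With n_targets=1, make_regression returns a 1-D y of shape (50,)
X, y = make_regression(n_samples=50, n_features=4, n_targets=1,
                       noise=0.1, random_state=0)

try:
    MultiTaskElasticNet().fit(X, y)
    raised = False
except ValueError as err:
    raised = True
    print("ValueError:", err)

# A single-column 2-D target is accepted: coef_ has one row per target
mt = MultiTaskElasticNet().fit(X, y.reshape(-1, 1))
print(mt.coef_.shape)  # (1, 4)

# The single-output model takes the 1-D target directly
en = ElasticNet().fit(X, y)
print(en.coef_.shape)  # (4,)
```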
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.linear_model import MultiTaskElasticNet
from sklearn.metrics import mean_squared_error
# generate multi-output regression dataset
X, y = make_regression(n_samples=100, n_features=5, n_targets=3, noise=0.1, random_state=1)
# split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
# create model
model = MultiTaskElasticNet(alpha=1.0, l1_ratio=0.5)
# fit model
model.fit(X_train, y_train)
# evaluate model
yhat = model.predict(X_test)
mse = mean_squared_error(y_test, yhat)
print('Mean Squared Error: %.3f' % mse)
# make a prediction
row = [[-0.1382643, 0.64768854, 1.52302986, -0.23415337, -0.23413696]]
yhat = model.predict(row)
print('Predicted: %s' % yhat)
Running the example gives an output like:
Mean Squared Error: 1201.401
Predicted: [[16.76234334 14.75708437 34.49206694]]
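The mean squared error above is large compared with the noise level used to generate the data (noise=0.1). This suggests that alpha=1.0 shrinks the coefficients heavily on this data set. The following sketch tunes both hyperparameters with scikit-learn's MultiTaskElasticNetCV on the same data and split. The l1_ratio grid and cv=5 are illustrative assumptions, and alpha is searched over the class's automatically generated path.

```python
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.linear_model import MultiTaskElasticNet, MultiTaskElasticNetCV
from sklearn.metrics import mean_squared_error

# Same data and split as the example above
X, y = make_regression(n_samples=100, n_features=5, n_targets=3, noise=0.1, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# Baseline: the fixed hyperparameters from the example
baseline = MultiTaskElasticNet(alpha=1.0, l1_ratio=0.5).fit(X_train, y_train)
mse_baseline = mean_squared_error(y_test, baseline.predict(X_test))

# Cross-validated search over the alpha path and a small l1_ratio grid
cv_model = MultiTaskElasticNetCV(l1_ratio=[0.1, 0.5, 0.9, 1.0], cv=5)
cv_model.fit(X_train, y_train)
mse_cv = mean_squared_error(y_test, cv_model.predict(X_test))

print('Chosen alpha: %.4f, l1_ratio: %.2f' % (cv_model.alpha_, cv_model.l1_ratio_))
print('Baseline MSE: %.3f, CV MSE: %.3f' % (mse_baseline, mse_cv))
```

The chosen values are stored in the alpha_ and l1_ratio_ attributes. They could then be passed to a plain MultiTaskElasticNet if a fixed model is preferred.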
The steps are as follows:
1. Generate a synthetic multi-output regression dataset using the make_regression() function. Specify the number of samples (n_samples), features (n_features) and target variables (n_targets), and fix the random seed (random_state) for reproducibility.
2. Split the dataset into training and test sets using train_test_split().
3. Instantiate the MultiTaskElasticNet model with the hyperparameters alpha and l1_ratio, then fit it on the training data using the fit() method.
4. Evaluate the model by predicting on the test set and comparing the predictions (yhat) to the actual values (y_test) using the mean squared error metric.
5. Make a prediction for a new data sample using the predict() method, showing how the trained model can be used for future predictions.
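To make the prediction step concrete, the sketch below reproduces predict() by hand. For a linear model, each prediction is the input row multiplied by the transposed coefficient matrix, plus the per-target intercepts. For brevity it fits on the full data set instead of the training split, so its numbers will differ from the output shown earlier.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import MultiTaskElasticNet

X, y = make_regression(n_samples=100, n_features=5, n_targets=3, noise=0.1, random_state=1)
model = MultiTaskElasticNet(alpha=1.0, l1_ratio=0.5).fit(X, y)

row = np.array([[-0.1382643, 0.64768854, 1.52302986, -0.23415337, -0.23413696]])

# coef_ is (n_targets, n_features) and intercept_ is (n_targets,),
# so the linear prediction is row @ coef_.T + intercept_
manual = row @ model.coef_.T + model.intercept_
print(manual)
print(model.predict(row))
```

Both lines should print the same three values, one per target.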