Scikit-Learn OrthogonalMatchingPursuitCV Model

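Orthogonal Matching Pursuit (OMP) is a greedy algorithm for sparse linear regression, often used with high-dimensional data. It builds the model one feature at a time, each time adding the feature most correlated with the current residual and refitting the coefficients. OrthogonalMatchingPursuitCV uses cross-validation to select the number of non-zero coefficients automatically.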
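To make the greedy selection concrete, here is a minimal NumPy sketch of the idea. It is an illustration only, not scikit-learn's implementation: the function name omp_sketch and the toy data are assumptions made for this example.

```python
import numpy as np

def omp_sketch(X, y, n_nonzero_coefs):
    """Illustrative greedy OMP (hypothetical helper, not the scikit-learn code)."""
    residual = y.astype(float).copy()
    selected = []
    coef = np.zeros(X.shape[1])
    for _ in range(n_nonzero_coefs):
        # pick the column most correlated with the current residual
        correlations = np.abs(X.T @ residual)
        correlations[selected] = 0.0  # never pick the same column twice
        selected.append(int(np.argmax(correlations)))
        # refit least squares on all selected columns (the "orthogonal" step)
        sol, *_ = np.linalg.lstsq(X[:, selected], y, rcond=None)
        residual = y - X[:, selected] @ sol
    coef[selected] = sol
    return coef

# toy data (assumed for illustration): only columns 2 and 7 influence y
rng = np.random.default_rng(0)
X = rng.standard_normal((50, 10))
X /= np.linalg.norm(X, axis=0)  # scale columns to unit norm
true_coef = np.zeros(10)
true_coef[[2, 7]] = [3.0, -2.0]
y = X @ true_coef

coef = omp_sketch(X, y, n_nonzero_coefs=2)
print('Selected columns:', np.flatnonzero(coef))
```

In this sketch the number of non-zero coefficients is passed in by hand; OrthogonalMatchingPursuitCV instead chooses that number by cross-validation.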

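The key hyperparameters of OrthogonalMatchingPursuitCV include cv (the cross-validation strategy, for example an integer number of folds), fit_intercept (whether to fit an intercept), and max_iter (the maximum number of features the search may include). The algorithm suits regression problems where only a small subset of the features is expected to matter.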
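As a small sketch, these hyperparameters are passed as keyword arguments to the constructor. The values shown are arbitrary choices for illustration, not recommendations.

```python
from sklearn.linear_model import OrthogonalMatchingPursuitCV

# illustrative settings; the specific values are assumptions for this sketch
model = OrthogonalMatchingPursuitCV(
    cv=5,                # number of cross-validation folds
    fit_intercept=True,  # estimate an intercept term (the default)
    max_iter=10,         # cap on how many features the search may include
)

# get_params() returns the configured hyperparameters as a dict
params = model.get_params()
print(params['cv'], params['fit_intercept'], params['max_iter'])
```

Calling get_params() before fitting is a quick way to confirm which settings an estimator will use.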

from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.linear_model import OrthogonalMatchingPursuitCV
from sklearn.metrics import mean_squared_error

# generate synthetic regression dataset
X, y = make_regression(n_samples=100, n_features=20, noise=0.1, random_state=1)

# split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# create model
model = OrthogonalMatchingPursuitCV(cv=5)

# fit model
model.fit(X_train, y_train)

# evaluate model
yhat = model.predict(X_test)
mse = mean_squared_error(y_test, yhat)
print('Mean Squared Error: %.3f' % mse)

# make a prediction
row = [0.5] * 20
yhat = model.predict([row])
print('Predicted: %.3f' % yhat[0])

Running the example gives an output like:

Mean Squared Error: 1678.366
Predicted: 187.995

The steps are as follows:

  1. A synthetic regression dataset is generated using make_regression(), specifying the number of samples (n_samples), features (n_features), and noise level (noise), with a fixed random seed for reproducibility. The dataset is split into training and test sets using train_test_split().

  2. An OrthogonalMatchingPursuitCV model is instantiated with cross-validation folds set to 5 (cv=5).

  3. The model is fit on the training data using the fit() method.

  4. Model performance is evaluated by comparing predictions (yhat) to actual values (y_test) using the mean squared error metric.

  5. A single prediction is made by passing a new data sample to the predict() method.

This example demonstrates how to set up and use OrthogonalMatchingPursuitCV for regression in scikit-learn, with cross-validation choosing how many features the sparse model keeps.
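To see which features the fitted model kept, the n_nonzero_coefs_ and coef_ attributes can be inspected after fitting. The sketch below reuses the same synthetic dataset and sets max_iter=15 as an illustrative assumption. According to the scikit-learn documentation, when max_iter is left unset the search considers 10% of n_features but at least 5, so it is raised here to let the search look at more candidates.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import OrthogonalMatchingPursuitCV

# same synthetic dataset as in the example above
X, y = make_regression(n_samples=100, n_features=20, noise=0.1, random_state=1)

# max_iter=15 is an arbitrary choice for illustration
model = OrthogonalMatchingPursuitCV(cv=5, max_iter=15)
model.fit(X, y)

# indices of the features with non-zero coefficients in the final model
selected = np.flatnonzero(model.coef_)
print('Non-zero coefficients chosen by CV:', model.n_nonzero_coefs_)
print('Selected feature indices:', selected)
```

The exact features selected depend on the data and on the scikit-learn version, so no specific output is shown here.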


