
Scikit-Learn SplineTransformer for Data Preprocessing

SplineTransformer is a preprocessing transformer that expands each numeric feature into a set of B-spline basis functions. A linear model can then combine these new features into a smooth, piecewise-polynomial curve.

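Important hyperparameters include n_knots (the number of knots, default 5), degree (the polynomial degree of each spline piece, default 3), and knots (where the knots are placed: "uniform", "quantile", or an explicit array of positions, in which case n_knots is ignored).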

This transformer suits regression problems where the relationship between the features and the target is non-linear. It lets an otherwise linear model fit curved relationships.

from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import SplineTransformer
from sklearn.pipeline import make_pipeline
from sklearn.metrics import mean_squared_error

# generate synthetic regression dataset
X, y = make_regression(n_samples=100, n_features=1, noise=0.1, random_state=1)

# split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# create pipeline with spline transformer and linear regression
model = make_pipeline(SplineTransformer(n_knots=5, degree=3), LinearRegression())

# fit model
model.fit(X_train, y_train)

# evaluate model
yhat = model.predict(X_test)
mse = mean_squared_error(y_test, yhat)
print('Mean Squared Error: %.3f' % mse)

# make a prediction
row = [[0.5]]
yhat = model.predict(row)
print('Predicted: %.3f' % yhat[0])

Running the example gives an output like:

Mean Squared Error: 0.010
Predicted: 40.357

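The steps are as follows:

1. Generate a synthetic regression dataset with make_regression (100 samples, one feature, noise=0.1) and split it into training and test sets (80/20) with train_test_split.
2. Build a pipeline with make_pipeline that applies SplineTransformer(n_knots=5, degree=3) and then fits LinearRegression on the resulting spline features.
3. Fit the pipeline on the training data.
4. Evaluate it on the test set with mean_squared_error.
5. Predict the target for a single new input, [[0.5]].

This example shows how to use SplineTransformer inside a pipeline so that a linear regression model is fitted to spline features rather than to the raw input. Because those features are piecewise polynomials of the input, the linear model can represent non-linear patterns. Note that make_regression generates a target that is a linear function of the features plus noise. This particular dataset therefore does not need the spline expansion. The benefit shows up on data with a genuinely non-linear relationship.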


