SplineTransformer is a scikit-learn preprocessing transformer that expands each original feature into a set of B-spline basis features.
Important hyperparameters include n_knots (number of knots), degree (polynomial degree of the spline), and knots (how the knots are positioned: 'uniform', 'quantile', or an explicit array of positions).
This transformer is suitable for regression problems where the relationship between variables is non-linear.
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import SplineTransformer
from sklearn.pipeline import make_pipeline
from sklearn.metrics import mean_squared_error
# generate synthetic regression dataset
X, y = make_regression(n_samples=100, n_features=1, noise=0.1, random_state=1)
# split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
# create pipeline with spline transformer and linear regression
model = make_pipeline(SplineTransformer(n_knots=5, degree=3), LinearRegression())
# fit model
model.fit(X_train, y_train)
# evaluate model
yhat = model.predict(X_test)
mse = mean_squared_error(y_test, yhat)
print('Mean Squared Error: %.3f' % mse)
# make a prediction
row = [[0.5]]
yhat = model.predict(row)
print('Predicted: %.3f' % yhat[0])
Running the example gives an output like:
Mean Squared Error: 0.010
Predicted: 40.357
The steps are as follows:
1. A synthetic regression dataset is generated using the make_regression() function, with a single feature, noise, and a fixed random state for reproducibility. The dataset is split into training and test sets using train_test_split().
2. A SplineTransformer is created with n_knots=5 and degree=3. This transformer is used within a pipeline along with LinearRegression to form the model. The model is then trained on the training data using the fit() method.
3. The performance of the model is evaluated using the mean squared error metric by comparing the predictions (yhat) to the actual values (y_test).
4. A single prediction is made by passing a new data sample to the predict() method.
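This example demonstrates how to use SplineTransformer in a pipeline to transform data before fitting a linear regression model. The transformed features allow the linear model to capture non-linear patterns in the data. Note that make_regression() generates a target that is linear in the feature, so the benefit over plain linear regression is most visible on data with a genuinely non-linear relationship.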