
Scikit-Learn TimeSeriesSplit Data Splitting

TimeSeriesSplit is a cross-validation technique designed for time series data. It splits the data into train and test sets while preserving the temporal order, which is crucial for evaluating machine learning models on time series problems.

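The main parameter of TimeSeriesSplit is n_splits, which sets the number of train/test splits (default 5). Each successive split trains on a longer prefix of the data and tests on the block of samples that follows it. The max_train_size parameter caps the size of the train set, so that only the most recent samples before each test block are used for training.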
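To make the effect of max_train_size concrete, here is a minimal sketch that compares the default expanding window with a capped one. The 12-sample array and the cap of 4 are arbitrary values chosen for illustration.

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit

# 12 time-ordered samples (illustrative size, not from the example below)
X = np.arange(12).reshape(-1, 1)

# Default behaviour: the train window grows with each split.
expanding = TimeSeriesSplit(n_splits=3)
# With max_train_size, only the most recent samples are kept for training.
rolling = TimeSeriesSplit(n_splits=3, max_train_size=4)

expanding_sizes = [len(train) for train, _ in expanding.split(X)]
rolling_splits = [(train.tolist(), test.tolist()) for train, test in rolling.split(X)]

print("Expanding train sizes:", expanding_sizes)
for train, test in rolling_splits:
    print(f"Train: {train}, Test: {test}")
```

In the capped version the first split still has only three train samples, because fewer than max_train_size samples precede the first test block.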

This cross-validation method is appropriate for regression and classification problems involving time series data.

from sklearn.model_selection import TimeSeriesSplit
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import numpy as np

# generate synthetic time series dataset
X = np.array([[i] for i in range(10)])
y = np.array([i + np.random.randn() for i in range(10)])

# create TimeSeriesSplit object
tscv = TimeSeriesSplit(n_splits=3)

# loop through each split
for train_index, test_index in tscv.split(X):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

    # create and fit model
    model = LinearRegression()
    model.fit(X_train, y_train)

    # evaluate model
    y_pred = model.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)

    # print results
    print(f"Train: {train_index}, Test: {test_index}")
    print(f"Test MSE: {mse:.3f}")

Running the example gives output like the following (the MSE values vary between runs because the random noise is not seeded):

Train: [0 1 2 3], Test: [4 5]
Test MSE: 3.795
Train: [0 1 2 3 4 5], Test: [6 7]
Test MSE: 1.506
Train: [0 1 2 3 4 5 6 7], Test: [8 9]
Test MSE: 1.719

The steps are as follows:

  1. First, a synthetic time series dataset is generated using NumPy. The X variable represents the time steps, and y represents the corresponding values with added random noise.

  2. A TimeSeriesSplit object is created with n_splits=3, which produces three train/test pairs. With 10 samples, each test set holds 2 samples and the train set grows from 4 to 6 to 8 samples.

  3. The code then loops through each split using the split() method of the TimeSeriesSplit object. For each iteration, the train_index and test_index are used to split X and y into train and test sets.

  4. A linear regression model is created and fit on the train set using fit(). The model’s performance is then evaluated on the test set using predict() and mean_squared_error().

  5. Finally, the train and test indices and the test mean squared error are printed for each split.
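The manual loop in the steps above can also be handed to cross_val_score, which accepts a TimeSeriesSplit object as its cv argument. The following is a sketch of that alternative, not part of the original example; the seeded random generator is an assumption added so the run is repeatable.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import TimeSeriesSplit, cross_val_score

# Synthetic series; the fixed seed is an assumption for repeatability
rng = np.random.default_rng(0)
X = np.arange(10).reshape(-1, 1)
y = np.arange(10) + rng.standard_normal(10)

tscv = TimeSeriesSplit(n_splits=3)

# scikit-learn scorers follow "higher is better", so MSE comes back negated
neg_mse = cross_val_score(
    LinearRegression(), X, y, cv=tscv, scoring="neg_mean_squared_error"
)
mse_per_split = -neg_mse

print("Test MSE per split:", np.round(mse_per_split, 3))
print(f"Mean test MSE: {mse_per_split.mean():.3f}")
```

This fits a fresh model on each train set and returns one score per split, so it is equivalent in spirit to the loop, but it does not expose the indices of each split.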

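This example demonstrates how to use TimeSeriesSplit for cross-validation on time series data. Because every train set contains only observations that precede its test set, the model is never tested on data older than what it was trained on. This gives a more realistic estimate of forecasting performance than standard k-fold cross-validation, where train sets can include observations from after the test period.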
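The difference from k-fold cross-validation can be checked directly on the indices. The sketch below uses a hypothetical helper, trains_only_on_past, written for this illustration; it is not part of scikit-learn.

```python
import numpy as np
from sklearn.model_selection import KFold, TimeSeriesSplit

X = np.arange(10).reshape(-1, 1)

def trains_only_on_past(cv, X):
    """Hypothetical helper: True if, in every split, all train indices precede all test indices."""
    return all(train.max() < test.min() for train, test in cv.split(X))

ts_ok = trains_only_on_past(TimeSeriesSplit(n_splits=3), X)
kfold_ok = trains_only_on_past(KFold(n_splits=3, shuffle=True, random_state=0), X)

print("TimeSeriesSplit trains only on the past:", ts_ok)   # True
print("Shuffled KFold trains only on the past:", kfold_ok)  # False
```

KFold fails this check even without shuffling, because every sample must appear in some test fold, including the earliest one, which leaves only later samples to train on.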


