
Configure KNeighborsRegressor "n_neighbors" Parameter

The n_neighbors parameter in scikit-learn’s KNeighborsRegressor controls the number of nearest neighbors considered when making predictions.

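K-Nearest Neighbors (KNN) regression predicts the target value of a query point from the target values of its k nearest neighbors in the training set. With the default weights="uniform", the prediction is the simple average of those neighbors' targets. The n_neighbors parameter sets k, the number of neighbors used for each prediction.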
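To check the averaging described above, the sketch below fits a KNeighborsRegressor with the default uniform weights, retrieves each query point's neighbors with the fitted model's kneighbors method, and averages their targets by hand. The dataset size and the split into training and query points are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.neighbors import KNeighborsRegressor

# Small synthetic dataset; the first 150 rows are used for training
X, y = make_regression(n_samples=200, n_features=5, noise=0.1, random_state=0)
X_train, y_train = X[:150], y[:150]
X_query = X[150:]

knn = KNeighborsRegressor(n_neighbors=5)  # weights="uniform" by default
knn.fit(X_train, y_train)

# kneighbors returns the distances to and indices of the 5 nearest training points
distances, indices = knn.kneighbors(X_query)

# Manual prediction: plain average of the neighbors' target values
manual_pred = y_train[indices].mean(axis=1)
model_pred = knn.predict(X_query)

print(indices.shape)                          # (50, 5)
print(np.allclose(manual_pred, model_pred))   # True
```

With weights="distance", closer neighbors would count more and the plain mean above would no longer match.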

Increasing n_neighbors smooths the model's predictions, reducing variance but potentially increasing bias. Conversely, a smaller n_neighbors captures more local detail, but it makes each prediction sensitive to noise in individual training points, which can lead to overfitting.
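One way to see this trade-off is to compare training and validation error across several values of n_neighbors. The sketch below uses scikit-learn's validation_curve with 5-fold cross-validation on the same kind of synthetic data as the main example. The grid of values is an arbitrary choice. With n_neighbors=1, each training point is its own nearest neighbor, so the training error is zero even though the validation error is high.

```python
from sklearn.datasets import make_regression
from sklearn.model_selection import validation_curve
from sklearn.neighbors import KNeighborsRegressor

X, y = make_regression(n_samples=1000, n_features=10, noise=0.1, random_state=42)

param_range = [1, 5, 10, 20, 50]
train_scores, val_scores = validation_curve(
    KNeighborsRegressor(),
    X,
    y,
    param_name="n_neighbors",
    param_range=param_range,
    cv=5,
    scoring="neg_mean_squared_error",
)

# Scores are negated MSE with shape (n_values, n_folds); flip sign and average folds
train_mse = -train_scores.mean(axis=1)
val_mse = -val_scores.mean(axis=1)

for k, tr, va in zip(param_range, train_mse, val_mse):
    print(f"n_neighbors={k:>2}: train MSE={tr:10.3f}, validation MSE={va:10.3f}")
```

A large gap between training and validation error at small k signals high variance. Both errors rising together at large k signals growing bias.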

The default value for n_neighbors is 5. Common values range from 1 to 20, depending on the dataset’s size and noise.
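The snippet below confirms the default and shows one hard upper bound. n_neighbors cannot exceed the number of training samples, and asking for more neighbors than exist raises a ValueError. The three-sample toy arrays are made up purely for illustration.

```python
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

# The default number of neighbors
default_k = KNeighborsRegressor().get_params()["n_neighbors"]
print(default_k)  # 5

# Toy training set with only 3 samples
X_small = np.arange(6, dtype=float).reshape(3, 2)
y_small = np.array([1.0, 2.0, 3.0])

# Requesting 5 neighbors from 3 training samples raises a ValueError
raised = False
try:
    KNeighborsRegressor(n_neighbors=5).fit(X_small, y_small).predict(X_small)
except ValueError as exc:
    raised = True
    print("ValueError:", exc)
```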

from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_squared_error

# Generate synthetic dataset
X, y = make_regression(n_samples=1000, n_features=10, noise=0.1, random_state=42)

# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train with different n_neighbors values
n_neighbors_values = [1, 5, 10, 20]
mean_squared_errors = []

for n in n_neighbors_values:
    knn = KNeighborsRegressor(n_neighbors=n)
    knn.fit(X_train, y_train)
    y_pred = knn.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    mean_squared_errors.append(mse)
    print(f"n_neighbors={n}, Mean Squared Error: {mse:.3f}")

Running the example gives an output like:

n_neighbors=1, Mean Squared Error: 7296.535
n_neighbors=5, Mean Squared Error: 3728.344
n_neighbors=10, Mean Squared Error: 3606.405
n_neighbors=20, Mean Squared Error: 4244.133

The key steps in this example are:

  1. Generate a synthetic regression dataset using make_regression.
  2. Split the data into training and test sets using train_test_split.
  3. Train KNeighborsRegressor models with different n_neighbors values.
  4. Evaluate the mean squared error of each model on the test set.

Some tips and heuristics for setting n_neighbors:
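A common heuristic is to pick n_neighbors by cross-validation on the training data rather than by comparing test-set scores, which gives an optimistic estimate. The sketch below searches k from 1 to 30 with GridSearchCV on the same data and split as the main example. The range 1 to 30 is an assumption. The held-out test set is used only once, to report the chosen model's error.

```python
from sklearn.datasets import make_regression
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsRegressor

X, y = make_regression(n_samples=1000, n_features=10, noise=0.1, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 5-fold cross-validated search over k = 1..30, using the training set only
search = GridSearchCV(
    KNeighborsRegressor(),
    param_grid={"n_neighbors": list(range(1, 31))},
    cv=5,
    scoring="neg_mean_squared_error",
)
search.fit(X_train, y_train)

best_k = search.best_params_["n_neighbors"]
print(f"Best n_neighbors: {best_k}")
print(f"Cross-validated MSE: {-search.best_score_:.3f}")

# score() uses the same scorer (negated MSE) on the refit best model
print(f"Test MSE: {-search.score(X_test, y_test):.3f}")
```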

Issues to consider:
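One issue that affects which points count as "nearest" is feature scale. KNeighborsRegressor uses Minkowski distance with p=2 (Euclidean) by default, so a feature measured on a much larger scale dominates the neighbor search. The sketch below deliberately multiplies one feature by 1000 to simulate mismatched units. That rescaling is an assumption for illustration, not part of the original example. It then compares an unscaled model with one that standardizes features first in a Pipeline.

```python
from sklearn.datasets import make_regression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_regression(n_samples=1000, n_features=10, noise=0.1, random_state=42)
# Illustrative assumption: feature 0 is recorded in much larger units
X[:, 0] *= 1000
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Without scaling, distances are driven almost entirely by feature 0
unscaled = KNeighborsRegressor(n_neighbors=10).fit(X_train, y_train)

# StandardScaler puts every feature on unit variance before the neighbor search
scaled = make_pipeline(StandardScaler(), KNeighborsRegressor(n_neighbors=10))
scaled.fit(X_train, y_train)

unscaled_mse = mean_squared_error(y_test, unscaled.predict(X_test))
scaled_mse = mean_squared_error(y_test, scaled.predict(X_test))
print(f"Unscaled MSE: {unscaled_mse:.3f}")
print(f"Scaled MSE:   {scaled_mse:.3f}")
```

Because scaling changes which points are neighbors, the best n_neighbors found without scaling may not carry over once a scaler is added.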



See Also