
Configure KNeighborsClassifier "algorithm" Parameter

The algorithm parameter in scikit-learn’s KNeighborsClassifier determines the algorithm used to compute the nearest neighbors.

K-Nearest Neighbors (KNN) is a simple non-parametric classification algorithm. The choice of algorithm for finding nearest neighbors can significantly impact the performance and computational efficiency of KNN.

The algorithm parameter can be set to 'auto', 'ball_tree', 'kd_tree', or 'brute'. The default value is 'auto', which attempts to choose the most appropriate algorithm based on the values passed to the fit method.

In practice, 'brute' is often the fastest choice for small datasets and for high-dimensional data, while 'kd_tree' and 'ball_tree' scale better to large, low-dimensional datasets. The 'auto' option is a good choice when the optimal algorithm is unknown.
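You can see which algorithm 'auto' resolved to by inspecting the fitted estimator's _fit_method attribute. Note this is a private attribute (an implementation detail that may change between scikit-learn versions), so treat it as a debugging aid rather than a stable API:

```python
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier

# Small, low-dimensional dense dataset
X, y = make_classification(n_samples=1000, n_features=10, random_state=42)

knn = KNeighborsClassifier(n_neighbors=5, algorithm='auto')
knn.fit(X, y)

# _fit_method is private, but after fit it holds the algorithm
# that 'auto' actually selected: 'brute', 'kd_tree', or 'ball_tree'
print(knn._fit_method)
```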

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import time

# Generate synthetic dataset
X, y = make_classification(n_samples=10000, n_classes=5, n_features=20, n_informative=10,
                           n_redundant=5, random_state=42)

# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train and evaluate with each algorithm value, timing fit + predict together
algorithms = ['auto', 'ball_tree', 'kd_tree', 'brute']
results = []

for algo in algorithms:
    start = time.time()
    knn = KNeighborsClassifier(n_neighbors=5, algorithm=algo)
    knn.fit(X_train, y_train)      # tree construction happens here ('brute' builds nothing)
    y_pred = knn.predict(X_test)   # neighbor queries happen here
    accuracy = accuracy_score(y_test, y_pred)
    end = time.time()
    results.append((algo, accuracy, end - start))

# Print results
print(f"{'Algorithm':<10} {'Accuracy':<10} {'Time (s)':<10}")
print("-"*30)
for algo, acc, t in results:
    print(f"{algo:<10} {acc:<10.3f} {t:<10.3f}")

Running the example gives an output like:

Algorithm  Accuracy   Time (s)
------------------------------
auto       0.857      0.062
ball_tree  0.857      0.403
kd_tree    0.857      0.308
brute      0.857      0.039

The key steps in this example are:

  1. Generate a synthetic multiclass classification dataset
  2. Split the data into train and test sets
  3. Train KNeighborsClassifier models with different algorithm values
  4. Evaluate the accuracy and training time for each model
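The timings above lump model construction and prediction into a single number. Because the tree-based algorithms do their work up front at fit time while 'brute' defers all distance computation to predict, it can be informative to time the two phases separately. A sketch, reusing the same synthetic dataset:

```python
import time
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = make_classification(n_samples=10000, n_classes=5, n_features=20,
                           n_informative=10, n_redundant=5, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
                                                    random_state=42)

timings = {}
for algo in ['ball_tree', 'kd_tree', 'brute']:
    knn = KNeighborsClassifier(n_neighbors=5, algorithm=algo)

    start = time.time()
    knn.fit(X_train, y_train)      # tree construction (a no-op for 'brute')
    fit_time = time.time() - start

    start = time.time()
    knn.predict(X_test)            # neighbor queries
    predict_time = time.time() - start

    timings[algo] = (fit_time, predict_time)
    print(f"{algo:<10} fit: {fit_time:.3f}s  predict: {predict_time:.3f}s")
```

Typically the tree-based methods spend most of their time in fit, while 'brute' spends nearly all of it in predict.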

Some tips and heuristics for setting algorithm:

  - 'auto' (the default) chooses based on the number of samples, the number of features, and the distance metric; it is a safe starting point.
  - 'brute' computes distances to every training point; it is competitive on small datasets and is often the fastest option in high dimensions, where tree structures lose their advantage.
  - 'kd_tree' performs best on low-dimensional data (roughly fewer than 20 features).
  - 'ball_tree' scales to higher dimensions than 'kd_tree' and supports a wider range of distance metrics.

Issues to consider:

  - Tree-based algorithms pay a one-time construction cost at fit time in exchange for faster queries; 'brute' has essentially no fit cost but scans the full training set on every predict.
  - Fitting on sparse input overrides this parameter and falls back to brute force.
  - The leaf_size parameter affects the construction and query speed (and memory use) of the tree-based algorithms, but not the results.
  - All four options perform an exact search, so they return the same neighbors; the choice affects only speed and memory, not accuracy.
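Because all four settings perform an exact (not approximate) nearest-neighbor search, they yield identical predictions, which is why the accuracy column in the output above is constant. A quick sanity check:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier

X, y = make_classification(n_samples=500, n_features=8, random_state=0)

preds = {}
for algo in ['brute', 'kd_tree', 'ball_tree']:
    knn = KNeighborsClassifier(n_neighbors=5, algorithm=algo).fit(X, y)
    preds[algo] = knn.predict(X)

# Exact search: every algorithm finds the same neighbors,
# so the predictions match element for element
assert np.array_equal(preds['brute'], preds['kd_tree'])
assert np.array_equal(preds['brute'], preds['ball_tree'])
```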

See Also