
K Nearest Neighbors


K nearest neighbors (KNN) is a supervised learning algorithm that can be used for regression or classification. It predicts the output for a new data point by finding the data points in the training set that are closest to it. It is a simple, non-parametric algorithm.

A non-parametric algorithm is a type of machine learning algorithm that does not assume a fixed form or a predefined number of parameters for the underlying model. Unlike parametric algorithms, which rely on a specific model structure (e.g., linear or polynomial) and require estimating a set of parameters, non-parametric algorithms are more flexible and can adapt to the complexity of the data without making strong assumptions about its distribution. This adaptability makes non-parametric algorithms particularly useful when dealing with datasets where the relationship between input features and the target variable is unknown or highly complex.

Parametric Algorithms:

  • Linear Regression: Assumes a linear relationship between input features and the target variable.
  • Logistic Regression: Models the probability of a binary outcome using a logistic function.
  • Support Vector Machine (with linear kernel): Assumes a linear decision boundary between classes.
  • Naive Bayes: Assumes feature independence and uses a probabilistic approach based on Bayes' theorem.

Non-Parametric Algorithms:

  • K-Nearest Neighbors (KNN): Makes predictions based on the closest data points in the training set without assuming any underlying model structure.
  • Decision Trees: Splits data into subsets based on feature values, creating a tree structure without predefined parameters.
  • Support Vector Machine (with non-linear kernel): Uses a non-linear kernel to map data into a higher-dimensional space without assuming a specific form for the decision boundary.

By the end of this lesson, you'll have a better idea of how to answer the following commonly asked questions:

  1. What are the key advantages and disadvantages of the KNN algorithm?
  2. How do you choose the optimal value of K?
  3. What are some common distance metrics used in KNN, and when would you use one over the other?

We'll provide the answers at the end of this lesson.

Overview

KNN

  • Supervised/Unsupervised: Supervised
  • Input: A dataset where each data point is represented by a set of features or attributes. These features can be continuous or categorical, but they must be encoded numerically.
  • Output: A category in the case of classification, or a numerical value in the case of regression.
  • Use cases: Classification or regression for tabular, text, or image data. Can also be used to find similar data points.

Concepts to master by level:

  • Junior: Understands the implementation of the KNN algorithm. Familiar with the Euclidean distance. Able to use a third-party library to apply it in practice.
  • Mid-level: Can gauge the tradeoffs of using KNN in practice. Knows how to experiment to set the optimal value of K. Able to discuss different distance metrics.
  • Senior: Able to explain how KNN is implemented in a real-world system. Able to discuss advanced use cases of nearest neighbors, such as approximate nearest neighbor search, and their computational complexity.

Algorithm

For each new data point (represented by a feature vector) and a chosen hyperparameter K:

  1. Calculate the distance between the new data point and each data point in the training set (also represented as feature vectors).
  2. Sort the distances in ascending order to find the K data points that are closest (have the shortest distance) to the new point.
  3. Output the prediction as the aggregate of the ground truth labels of the K nearest data points: the average of these labels in the case of regression, or the majority vote in the case of classification.

If you are using KNN for binary classification, it’s generally recommended to set K to an odd number. That way, when doing a majority vote, you will not have a tie between classes.

Different distance metrics can be used for KNN, but the most common metric is the Euclidean distance.

Equation for Euclidean distance:

d(p,q) = \sqrt{(p_1-q_1)^2+(p_2-q_2)^2+...+(p_n-q_n)^2}
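As a quick sanity check, the formula can be evaluated directly with NumPy (a minimal sketch; the two points below are made up for illustration):

```python
import numpy as np

p = np.array([1.0, 2.0])
q = np.array([4.0, 6.0])

# Sum the squared component differences, then take the square root
distance = np.sqrt(np.sum((p - q) ** 2))
print(distance)  # sqrt(3^2 + 4^2) = 5.0
```

The same result can be obtained with `np.linalg.norm(p - q)`, which computes the Euclidean norm of the difference vector.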

Example:

Doing KNN for a single unclassified point by hand:

(Figure: KNN classification example)

While KNN focuses on using the nearest neighbors to make predictions, the underlying concept of finding the closest data points extends far beyond just classification or regression tasks. In fact, Nearest Neighbors Search (NNS) forms the backbone of many modern search and recommendation systems. Although the goals differ—KNN uses neighbors to assign labels or values, while NNS returns the data points themselves—the principle of locating the most similar items remains the same.

NNS is used extensively in search and recommendation systems, particularly when working with embeddings. While similar in concept to KNN—both involve finding the closest data points to an input—there is no ground truth label involved in search systems. Instead, the data points themselves are returned as the results.

More advanced implementations of NNS are used in practice to speed up search systems. These often rely on approximate nearest neighbors algorithms, such as FAISS (Facebook AI Similarity Search), Annoy, and HNSW (Hierarchical Navigable Small World). These implementations can be provided within databases optimized for search, like Pinecone.
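As a rough illustration of the NNS idea, here is a minimal sketch of exact nearest neighbor search using scikit-learn's `NearestNeighbors`; the toy 2D points below stand in for the embeddings a real search system would index:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

# A tiny "database" of 2D points (standing in for embeddings)
X = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [5.1, 5.0]])

# Index the data, then query for the 2 points closest to (5, 5.05).
# Note: the results are the data points themselves, not predicted labels.
nns = NearestNeighbors(n_neighbors=2).fit(X)
distances, indices = nns.kneighbors([[5.0, 5.05]])

print(indices[0])    # indices of the 2 nearest points: [2 3]
print(distances[0])  # their distances to the query
```

Libraries like FAISS, Annoy, and HNSW trade a small amount of accuracy for much faster, approximate versions of this same lookup.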

(Figure: nearest neighbor search)

Pseudocode

From the outlined algorithm, we can construct a pseudocode for KNN as follows:

```python
import numpy as np

"""
KNN algorithm: given a binary classification dataset (X_train, y_train),
predict the label of a new point at inference time by taking a majority
vote of its K nearest neighbors.
"""

def euclidean_distance(x, y):
    return np.sqrt((x - y).T.dot(x - y))

def knn(X_train, y_train, x_new, k):
    """Assume binary classification."""
    distances = [euclidean_distance(X_i, x_new) for X_i in X_train]
    # argsort sorts in ascending order, so the first k indices are the nearest
    top_k_indices = np.argsort(distances)[:k]
    top_k_labels = [y_train[idx].item() for idx in top_k_indices]
    # Majority vote: predict 1 if more than half of the K neighbors are 1
    return 1 if sum(top_k_labels) > k / 2 else 0

if __name__ == "__main__":
    N, D = 100, 4
    X_train = np.random.randn(N, D)
    y_train = np.random.randint(0, 2, (N, 1))
    to_predict = np.random.randn(D)
    predicted_label = knn(X_train, y_train, to_predict, k=3)
    print(predicted_label)
```

Evaluation

View our lesson on supervised model evaluation to find out what metrics are used to evaluate the model.

Limitations

  • Curse of dimensionality. As the number of features increases, it becomes more computationally expensive to compute distances, and the distance between points becomes less meaningful.
  • Slow inference time. KNN calculates the distance from the query point to all training points, making it inefficient for large datasets or high-dimensional data.
  • Memory intensive. KNN stores the entire training dataset, which requires significant memory resources.
  • Dependence on distance metrics. KNN may be biased toward features with larger numerical ranges if the features are on different scales, requiring feature scaling (e.g., normalization or standardization).
  • Sensitivity to noise and outliers. Noisy data points or outliers can significantly reduce accuracy, as all training instances directly influence predictions.
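The feature-scaling issue in particular is easy to demonstrate: when one feature has a much larger range than another, it dominates the Euclidean distance (a contrived example with made-up numbers):

```python
import numpy as np

# Feature 1 in [0, 1] (e.g., a ratio); feature 2 in the tens of thousands
a = np.array([0.0, 30000.0])
b = np.array([1.0, 31000.0])

# The first feature contributes almost nothing to the distance
distance = np.sqrt(np.sum((a - b) ** 2))
print(distance)  # ~1000.0, driven almost entirely by the second feature
```

Standardizing both features (e.g., with `StandardScaler`) puts them on comparable scales so each contributes meaningfully to the distance.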

Employing KNN

Here’s how KNN can be used in practice:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Load the dataset (Iris dataset for example purposes)
data = load_iris()
X, y = data.data, data.target

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Standardize the features (important for KNN)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Initialize the KNN classifier with K=3 (you can adjust this)
knn = KNeighborsClassifier(n_neighbors=3)

# Train the model
knn.fit(X_train, y_train)

# Make predictions on the test set
y_pred = knn.predict(X_test)

# Evaluate the predictions
print(accuracy_score(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))
```

Common Questions

Q: What are the key advantages and disadvantages of the KNN algorithm?

A:

Advantages:

  • Non-parametric: It doesn’t make underlying assumptions about the data distribution.
  • Simple to implement: The algorithm is straightforward to set up and use.
  • No training phase: KNN doesn’t require a separate training phase, as it uses the entire dataset for predictions.

Disadvantages:

  • Computationally expensive: Since KNN computes the distance from the query point to all points in the training set, it can be slow for large datasets.
  • Memory intensive: KNN stores all training data, which requires significant memory resources.
  • Performance degradation: KNN’s performance can degrade with irrelevant features, noise, or outliers in the data.
  • Sensitive to feature scales: Distance metrics used in KNN are sensitive to feature scales, so normalization or standardization is a necessary preprocessing step.

Q: How do you choose the optimal value of K?

A: You can choose to focus on a few aspects:

  • Model performance: treat K as a hyperparameter and choose the value that gives the highest performance on validation data. Alternatively, you can use cross-validation.
  • The bias-variance tradeoff: smaller values of K lead to low training error but tend to overfit and generalize poorly, while larger values of K reduce model variance but can lead to underfitting.
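The first approach can be sketched with scikit-learn's `cross_val_score`, sweeping over candidate values of K and keeping the best (the dataset and the K range here are illustrative):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

# Evaluate each candidate K with 5-fold cross-validation
scores = {}
for k in range(1, 20, 2):  # odd values of K to avoid ties in the majority vote
    knn = KNeighborsClassifier(n_neighbors=k)
    scores[k] = cross_val_score(knn, X, y, cv=5).mean()

# Keep the K with the highest mean cross-validated accuracy
best_k = max(scores, key=scores.get)
print(best_k, scores[best_k])
```

In practice you would run this sweep on training data only and confirm the chosen K on a held-out test set.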

Q: What are some common distance metrics used in KNN, and when would you use one over the other?

A:

Euclidean distance: Best suited for continuous variables where straight-line distance is a meaningful measure. It’s the most commonly used metric when features are of similar scale.

Manhattan distance: Works well in grid-like patterns (e.g., city blocks) or when dealing with high-dimensional spaces where individual component differences are more meaningful than straight-line distance.

Cosine similarity: Commonly used for text data and high-dimensional data where orientation, not magnitude, matters. It measures the cosine of the angle between vectors, treating them as points on a unit sphere.

Hamming distance: Best suited to binary, categorical, or discrete data of equal length, where the focus is on counting exact position-wise mismatches between data points. Typical applications include error detection, sequence alignment, comparing strings, identifying bit differences in binary data, and classification tasks involving binary feature vectors.
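All four metrics are available in `scipy.spatial.distance`; a quick comparison on toy vectors (the values are chosen purely for illustration):

```python
from scipy.spatial.distance import euclidean, cityblock, cosine, hamming

p, q = [3.0, 4.0], [0.0, 0.0]
print(euclidean(p, q))  # 5.0: straight-line distance
print(cityblock(p, q))  # 7.0: Manhattan distance, |3 - 0| + |4 - 0|

# SciPy's `cosine` is a distance: 1 minus the cosine similarity
print(cosine([1.0, 0.0], [0.0, 1.0]))  # 1.0: orthogonal vectors

# SciPy's `hamming` returns the *fraction* of positions that differ
u, v = [1, 0, 1, 1], [1, 1, 1, 0]
print(hamming(u, v))  # 0.5: 2 of 4 positions disagree
```

Note the SciPy conventions: `cosine` is a distance rather than a similarity, and `hamming` is normalized by vector length rather than returning a raw mismatch count.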