
K-nearest neighbour classification (KNN)

A brief description of the K-nearest neighbour classifier

In the videos we saw an example which used K-nearest neighbour (KNN) classification on data recording the size of Iris petals and sepals, to identify the species.

What does KNN do?

K-nearest neighbour or KNN classification works on the simple idea that objects of a particular class (e.g. a species of flower) that can be described using one or more numbers (e.g. width and length of petals and sepals) are ‘near’ to each other, or at least are nearer to each other than objects of a different class (e.g. a different species of flower).

To use it you need a set of so-called training data, with known classes or categories attached to each piece of data. Then, you can take a new piece of data with unknown class and use KNN with the training data to predict which class the new data belongs to.

[Figure: Scatter plot of Iris petal width versus sepal length, showing the decision boundary found using the K-nearest neighbour algorithm with K=7]
If your data is two-dimensional, you can do this for a whole range of values and find the class predictions for each. By plotting and colouring the regions predicted to belong to each class you can then visualise the decision boundary, where the decision to assign new datapoints to one class flips to another class.
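To make this concrete, here is a minimal sketch of the grid idea, using a hypothetical four-point training set and the simplest case of K=1. Colouring the grid cells by predicted class (for example with matplotlib's pcolormesh) would reveal the decision boundary where the labels change.

```python
# Sketch: finding the predicted class over a grid of points (K=1 case).
# The training data below is made up purely for illustration.
import math

train = [((1.0, 1.0), "A"), ((1.5, 2.0), "A"),
         ((4.0, 4.0), "B"), ((4.5, 3.5), "B")]

def predict_1nn(point):
    # Assign the class of the single nearest training point.
    return min(train, key=lambda t: math.dist(point, t[0]))[1]

# Classify every point on a coarse grid covering the data; plotting these
# labels as coloured regions would show where the decision flips.
grid = [[predict_1nn((x * 0.5, y * 0.5)) for x in range(12)]
        for y in range(12)]
```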

The KNN Algorithm

Before you use KNN, you first need to pick a value for K. This is just a (usually fairly small) whole number or integer value. There’s a bit more on choosing K below, but for now let’s say we choose K=3.
Then what KNN does, for each new piece of data, is look through the training dataset, find the K (in this case 3) pieces of data nearest the new data point, and then look which class these 3 or K pieces of data belong to. The prediction for the class of the new datapoint is just the most commonly seen class among the K neighbours.
In the sketch below the new data point (marked ?) has two points of class B, and one of class A among its three nearest neighbours. So it would be predicted by KNN to be of class B.
If there are just two classes, and you pick an odd number for K (e.g. 3), this prediction is always clear cut. There will always be more neighbours belonging to one class than the other. Often all the neighbours will belong to the same class.
Sometimes, if you have more than two classes, the prediction might not be clear cut. For example, if we have three classes and use K=3, we might find a point where each of the three nearest neighbours belongs to a different class. There is no single way to deal with such a tie, and different software implementations of KNN will treat this situation differently.
Of course, if you set your classifier up to return a set of probabilities of class membership, rather than a straight prediction, any cases where there is ambiguity will be highlighted.
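The prediction step described above can be sketched in a few lines. This is a minimal illustration, assuming a hypothetical 2D training set and plain Euclidean distance:

```python
# Minimal sketch of KNN prediction by majority vote among the K nearest
# training points. The training data here is made up for illustration.
import math
from collections import Counter

def knn_predict(train, new_point, k=3):
    # Sort training points by distance to the new point, keep the k nearest,
    # and return the most common class among them.
    nearest = sorted(train, key=lambda t: math.dist(new_point, t[0]))[:k]
    return Counter(label for _, label in nearest).most_common(1)[0][0]

train = [((1.0, 1.0), "A"), ((2.0, 1.5), "B"), ((2.2, 1.8), "B"),
         ((5.0, 5.0), "A")]
print(knn_predict(train, (2.0, 2.0), k=3))  # two of the 3 nearest are "B"
```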

How do we measure distance?

We all have an intuitive sense of how close two objects are to one another, and how we might measure it. In one dimension this means two points on a line; in two dimensions the points lie on a flat plane; in three dimensions the two points are in 3D space.
The most common way to measure the distance between points is by treating the data points as Cartesian coordinates and calculating the Euclidean distance between them. If we have two points \(a\) and \(b\) with coordinates \((a_1, a_2)\) and \((b_1, b_2)\) this is just:

\[D = \sqrt{(a_1 - b_1)^2 + (a_2 - b_2)^2}\]

It's difficult to visualise the distance between points in more than three dimensions, but mathematically the principle is the same. It's just the square root of the sum of squared differences between each of the coordinate components. So for \(n\) components:

\[D = \sqrt{(a_1 - b_1)^2 + (a_2 - b_2)^2 + \dots + (a_n - b_n)^2}\]
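The n-dimensional formula above translates directly into code:

```python
# Euclidean distance between two points with any number of coordinates:
# the square root of the sum of squared coordinate differences.
import math

def euclidean(a, b):
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))

print(euclidean((0, 0), (3, 4)))        # 5.0 (the classic 3-4-5 triangle)
print(euclidean((1, 2, 3), (1, 2, 3)))  # 0.0
```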

Different implementations of KNN might use slightly different distance measures or metrics, but the Euclidean distance is the most commonly used.

Because of the way KNN works, it's usually best to normalise your dataset before using it, so that each dimension of your dataset varies over a similar range. There's a bit more on normalisation in the article on regularisation later in the course.
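One common way to do this is min-max normalisation, which rescales each feature to the range 0 to 1. A small sketch, using made-up feature values where the second feature has a much larger range than the first:

```python
# Sketch of min-max normalisation so each feature varies over [0, 1].
# Assumes no feature is constant (which would give a zero range).
def min_max_normalise(rows):
    cols = list(zip(*rows))           # one tuple of values per feature
    lows = [min(c) for c in cols]
    highs = [max(c) for c in cols]
    return [tuple((v - lo) / (hi - lo)
                  for v, lo, hi in zip(row, lows, highs))
            for row in rows]

data = [(1.0, 100.0), (2.0, 300.0), (3.0, 200.0)]
print(min_max_normalise(data))
# [(0.0, 0.0), (0.5, 1.0), (1.0, 0.5)]
```

Without this step, the feature with the larger range would dominate the Euclidean distance and effectively drown out the other features.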

How to choose K?

The choice of K is important and can yield very different results. As already mentioned, setting an odd number for K ensures there are no ties at the boundary between two classes. More generally, a small value of K might lead to overfitting where your predictions are too sensitive to noise in the training set. The most extreme case of this is setting K=1, where each prediction is just the class of the nearest value in the training set.

On the other hand, a large value for K might result in underfitting, as the classes of datapoints further and further away from new datapoints are considered when predicting the class of those new datapoints. Often there is an optimum value for K that can only be found by repeated use of KNN with different values of K, using the techniques of model evaluation and validation we will see later in the course.
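The simplest version of that search is to hold back some labelled data and measure prediction accuracy for several values of K. A minimal sketch, with a small hypothetical training and validation set:

```python
# Sketch: comparing values of K on a held-out validation set.
# All data here is made up for illustration.
import math
from collections import Counter

def knn_predict(train, point, k):
    nearest = sorted(train, key=lambda t: math.dist(point, t[0]))[:k]
    return Counter(label for _, label in nearest).most_common(1)[0][0]

train = [((1.0, 1.0), "A"), ((1.2, 0.8), "A"), ((0.8, 1.1), "A"),
         ((3.0, 3.0), "B"), ((3.2, 2.9), "B"), ((2.8, 3.1), "B")]
validation = [((1.1, 1.0), "A"), ((3.1, 3.0), "B")]

for k in (1, 3, 5):
    correct = sum(knn_predict(train, p, k) == label
                  for p, label in validation)
    print(k, correct / len(validation))
```

In practice you would use the more robust evaluation techniques (such as cross-validation) covered later in the course, rather than a single small validation set.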

Pros and cons

KNN is a simple algorithm that, unlike most of the machine learning algorithms we will see, is non-parametric. This means there are no parameters that need to be learned in a training step before we use it. All we need to do is obtain a training dataset, choose a value of K, and KNN is ready to make predictions. This can be an advantage, since for many other methods the initial training of the model can be time consuming.

On the other hand, the calculations required, and so the computational time taken, to make predictions are more intensive than for many other machine learning algorithms, since every new prediction requires the entire training dataset. With larger and larger datasets and more and more dimensions this can rapidly become very time consuming, so KNN does not scale well to large datasets with lots of features. In some cases, however, there may be ways to reduce this problem, via dimensionality reduction (e.g. principal component analysis or PCA), or via dataset reduction (e.g. condensed nearest neighbour or CNN).

Practical

In the practical at the end of the week we will look at a KNN example using the Iris dataset within Python Scikit-Learn.