
Naive Bayes

An article describing the Naive Bayes algorithm.
[Figure: unlabelled axes showing two clusters of data points, labelled A in blue and B in red, with concentric ellipses centred on the middle of each cluster.]

As we saw in the previous video, Naive Bayes is a simple supervised learning classification method.

In this article, we will review some of the details of the algorithm, before going through a simple worked example.

The Naive Bayes algorithm

Naive Bayes is designed for numerical data for which you already have class labels. A good example of this is the Iris dataset we have seen previously in the course. This gives width and length measurements for the petals and sepals of Iris flowers, along with the species of each particular flower, so in this case our classes are the different flower species.

The main assumption of Naive Bayes is that each data feature follows a Normal (often referred to as ‘Gaussian’) distribution about a cluster centre corresponding to its class. The Normal distribution is what is known as a probability density function. You may have encountered these before, but if not, here’s a quick explanation.

Probability density functions

Probability density functions are a way of describing mathematically how likely something is to take a specific value within a continuous range of values. They can be visualised for a single variable as a 2D plot like the one shown below for the Normal distribution.

The height of the curve at a given value shows how likely the variable is to take values near that point, so the higher the curve, the more likely those values are. In the case of the Normal distribution, the curve gets closer and closer to zero the further we get from the centre, without ever actually reaching it.

Because we are sure the variable must take some value along the curve, one key feature of a probability density function is that the total area under the curve is exactly one.

A visualisation of a Normal distribution in one dimension looks like a bell-shaped curve, and shows how likely we are to see each value along the curve. So we are much more likely to see values near the centre of the curve than those further away, while values far from the centre are very unlikely to be seen at all. The centre of the curve is the mean value, while the distance from the centre to roughly halfway down each side of the bell shape is the standard deviation.
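To make this concrete, here is a minimal sketch (not part of the course code) that evaluates a Normal probability density function on a grid of values and checks numerically that the area under the curve is very close to one. It assumes NumPy and SciPy are available; the mean and standard deviation used here are arbitrary choices for illustration.

import numpy as np
from scipy.stats import norm

# Arbitrary example parameters for the Normal distribution
mean, std = 0.0, 1.0

# Evaluate the probability density function on a grid of values around the mean
x = np.linspace(-6, 6, 1001)
pdf = norm.pdf(x, loc=mean, scale=std)

# Crude numerical estimate of the area under the curve (should be very close to 1)
dx = x[1] - x[0]
print((pdf * dx).sum())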

The Normal distribution in Naive Bayes

We talked over some of the details in the video, but the key point is that Naive Bayes finds the mean and standard deviation of the distribution of every feature for every class in your dataset. Using Bayes’ theorem, and a bit of algebra, we can then assign a probability that any new data point belongs to a given class. If we want to predict a class for the new data point, we assign it to the class with the highest probability.
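As a rough illustration of that idea (not part of the course code), the sketch below scores one new two-feature data point against two classes: for each class it multiplies the one-dimensional Normal densities of the features (the ‘naive’ independence assumption) by a class prior, and predicts the class with the highest score. The means, standard deviations and priors are made-up numbers purely for illustration.

import numpy as np
from scipy.stats import norm

new_point = np.array([5.0, 3.4])

classes = {
    # class name: (feature means, feature standard deviations, prior)
    "A": (np.array([5.0, 3.4]), np.array([0.35, 0.38]), 0.5),
    "B": (np.array([5.9, 2.8]), np.array([0.52, 0.31]), 0.5),
}

scores = {}
for name, (means, stds, prior) in classes.items():
    # 'Naive' assumption: features are independent, so the likelihood is the
    # product of the one-dimensional Normal densities
    likelihood = np.prod(norm.pdf(new_point, loc=means, scale=stds))
    scores[name] = likelihood * prior   # numerator of Bayes' theorem

# The shared denominator in Bayes' theorem does not change which class wins,
# so we can simply pick the class with the highest score
print(max(scores, key=scores.get))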

In two dimensions the normal distributions might look like the figure below, with points scattered around the centre of two clusters we have labelled A and B.

Naive Bayes in Scikit-Learn

To demonstrate how to use Naive Bayes in Scikit-Learn we can import and split the Iris data in exactly the same way as in previous examples:

from sklearn.datasets import load_iris

iris = load_iris()

X = iris.data
y = iris.target

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y)

Then to set up the model we just need to import and instantiate the GaussianNB() class:

from sklearn.naive_bayes import GaussianNB
model = GaussianNB()

Then we can train and evaluate the model as follows:

model.fit(X_train, y_train)

from sklearn.metrics import accuracy_score
y_predicted = model.predict(X_test)
score = accuracy_score(y_test, y_predicted)
print(score)
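As an aside, after fitting you can inspect what the model has learned; this is a minimal sketch assuming the model fitted above. In scikit-learn the per-class feature means are stored in the theta_ attribute, and the per-class variances in var_ (exposed as sigma_ in older versions of the library).

# Per-class Gaussian parameters learned from the training data
print(model.theta_)   # shape: (number of classes, number of features)
print(model.var_)     # older scikit-learn versions expose this as sigma_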

Optional

If you have time, try this code out, and try plotting the misclassified examples as we did with the K-nearest neighbour example in the practical in Week 1.
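If you would like a starting point, here is one possible sketch (not the course’s official solution). It assumes the variables from the code above (X_test, y_test, y_predicted), that matplotlib is installed, and it plots only the first two features so the result fits on a 2D scatter plot.

import matplotlib.pyplot as plt

# Boolean mask picking out the test points the model got wrong
misclassified = y_predicted != y_test

# Plot all test points coloured by their true class, using the first two features
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test)

# Circle the misclassified points in red
plt.scatter(X_test[misclassified, 0], X_test[misclassified, 1],
            facecolors="none", edgecolors="red", s=120, label="misclassified")

plt.xlabel(iris.feature_names[0])
plt.ylabel(iris.feature_names[1])
plt.legend()
plt.show()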

This article is from the free online course Machine Learning for Image Data.
