Visualisation in Python

A brief guide to data and model visualisation in Python.

As we’ve seen, good data visualisation is an important tool in your analysis, and for communicating your results.

This article gives a few practical tips for data and model visualisation in Python.


Python has lots of useful tools for this, not least Matplotlib, which we met in week 1. Matplotlib has many more options and plot types than we have time to discuss here. For a good overview, see the gallery of examples on the Matplotlib website.

Confusion matrices

We’ve seen in the preceding videos that confusion matrices are useful ways to visualise the performance of machine learning classifiers.

But how do we make them in Python? One way is to use scikit-learn and Matplotlib. Scikit-learn has two useful functions:

  • confusion_matrix, which makes the matrix itself
  • ConfusionMatrixDisplay which uses Matplotlib to plot and display it.

To demonstrate we can go back to our random forest example using the Iris data:

# Load the data
from sklearn.datasets import load_iris

iris = load_iris()

X = iris.data    # the four flower measurements (features)
y = iris.target  # the species labels, encoded as 0, 1 and 2

# Split into training and test sets
from sklearn.model_selection import train_test_split
Xtrain, Xtest, y_train, y_test = train_test_split(X, y)

# Train the random forest and make predictions on the test set
from sklearn.ensemble import RandomForestClassifier

model = RandomForestClassifier()
model.fit(Xtrain, y_train)
y_predicted = model.predict(Xtest)

Now that we have a set of predictions for the test data, we can compare them with the true labels and make a confusion matrix:

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test, y_predicted)
print(cm)  # exact counts vary between runs because the split is random
[[15  0  0]
 [ 0 12  1]
 [ 0  0 10]]

As the printout shows, the function confusion_matrix only computes the numbers in the matrix; it doesn’t make a plot.

To make the plot we need to use ConfusionMatrixDisplay as follows:

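from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
# display_labels is optional; here we assume the species names are wanted on the axes
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=iris.target_names)
disp.plot()
plt.show()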
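
The counts in a confusion matrix can also be summarised numerically. As a minimal sketch, assuming the matrix printed earlier in this section (rows are true classes and columns are predicted classes, which is scikit-learn's convention), the overall accuracy and the fraction of each true class that was predicted correctly can be read off with NumPy. The variable names are illustrative only.

```python
import numpy as np

# The confusion matrix printed earlier: rows are true classes,
# columns are predicted classes (scikit-learn's convention).
cm = np.array([[15, 0, 0],
               [0, 12, 1],
               [0, 0, 10]])

# Overall accuracy: correct predictions (the diagonal) over all predictions.
accuracy = np.trace(cm) / cm.sum()

# Per-class recall: the fraction of each true class predicted correctly.
recall = np.diag(cm) / cm.sum(axis=1)

print(f"accuracy = {accuracy:.3f}")
print("per-class recall =", np.round(recall, 3))
```

In this example 37 of the 38 test flowers are classified correctly, and the single error is a flower of the second class that was predicted as the third.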

Decision boundary

Another nice way to visualise the output of a random forest classifier (and other types of classifiers) is to plot the decision boundary. We can do this in scikit-learn but we first need to cut the features in the dataset down to two dimensions so we can plot it.

After training the model on the reduced dataset, the next step is to evaluate its predictions on a grid of values covering the range of the data. These can then be plotted by colouring the background according to the predicted class at each point in the 2D space.

The code below demonstrates how to do this. Try it out for yourself and use the comments to work out what each step is doing.

import numpy as np
import matplotlib.pyplot as plt

# load the data and select two features only
from sklearn.datasets import load_iris
iris = load_iris()
sdata = iris.data[:, [0, 3]]  # we have picked sepal length (0) and petal width (3)

# make a grid of values for predicting and plotting based on the data
xmax = 8.2
xmin = 3.8
ymin = -0.2
ymax = 2.8

feature_1, feature_2 = np.meshgrid(
np.linspace(xmin, xmax),
np.linspace(ymin, ymax))

grid = np.vstack([feature_1.ravel(), feature_2.ravel()]).T

# set up the model, train it, then make predictions on the grid of values
from sklearn.ensemble import RandomForestClassifier

forest_model = RandomForestClassifier()
forest_model.fit(sdata, iris.target)
y_pred = np.reshape(forest_model.predict(grid), feature_1.shape)

# make and display the plot
from sklearn.inspection import DecisionBoundaryDisplay  # NOTE: new in scikit-learn version 1.1
display = DecisionBoundaryDisplay(xx0=feature_1, xx1=feature_2, response=y_pred)
display.plot()

# overlay the original data points, coloured by their true species
display.ax_.scatter(sdata[:, 0], sdata[:, 1], c=iris.target, edgecolor="black")

plt.xlabel('sepal length (cm)')
plt.ylabel('petal width (cm)')
plt.show()

A scatter plot of Iris petal width in cm against sepal length in cm, colour coded according to species, purple, green, and yellow. The data appears to be roughly in three clusters, with the purple species positioned bottom left, green species in the centre, and yellow species top right. The background of the axes has also been coloured to show the decision boundary found using the random forest classifier, which indicates which class the algorithm would predict for every point within the axes. Most, but not all of the data points fall within the correct classification boundary.

The plot above shows the output. The original data points and the background have been coloured in the same way so you can see clearly where the model would misclassify data points.


Seaborn

Though you can achieve a lot with Matplotlib, some find the default plots a bit basic in appearance, and customising them can involve tedious extra coding. This is especially likely if you are used to the statistical programming language R, or RStudio.

Seaborn is an additional package that works on top of Matplotlib and aims both to simplify the user interface and to improve the appearance of standard Matplotlib plots. A particular strength of Seaborn over Matplotlib is that it can easily handle data stored in pandas DataFrames, in a similar way to R.
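
As a rough illustration of the DataFrame support, the sketch below loads the Iris data as a pandas DataFrame and passes column names straight to Seaborn. It assumes Seaborn, pandas and scikit-learn are installed; the species column is a name introduced here for the example, not part of the dataset.

```python
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris

# Load Iris as a pandas DataFrame rather than as NumPy arrays
iris = load_iris(as_frame=True)
df = iris.frame.copy()

# Add a readable species column (a name chosen here for illustration)
df["species"] = iris.target_names[df["target"].to_numpy()]

# Seaborn takes column names directly and adds the legend and axis labels itself
fig, ax = plt.subplots()
sns.scatterplot(data=df, x="sepal length (cm)", y="petal width (cm)",
                hue="species", ax=ax)
```

The equivalent plot in plain Matplotlib would need the colours, legend and axis labels to be set up by hand.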

You can think of Matplotlib as a plugin to Python, and Seaborn as a plugin to Matplotlib. You might already have it installed along with Anaconda, but as usual installing it is easy:

 pip install seaborn

You can also look up more detailed installation instructions from the library webpages themselves.

Since Seaborn is based on Matplotlib, the code to produce a simple plot can be identical, but the resulting plot appears different. For example using Matplotlib:

import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(-5, 5, 1000)
y1 = np.cos(x)
y2 = np.sin(x)

plt.plot(x, y1, label='cos')
plt.plot(x, y2, label='sin')
plt.legend()  # show the labels given above
plt.show()

A set of axes plotting the cosine function in blue, and the sine function in orange. The background of the axes is plain white.

Using Seaborn, we just need to import the library and then call its set function, and the default plotting behaviour will change.

import seaborn
seaborn.set()  # apply Seaborn's default style (an alias of seaborn.set_theme() in recent versions)

x = np.linspace(-5, 5, 1000)
y1 = np.cos(x)
y2 = np.sin(x)

plt.plot(x, y1, label='cos')
plt.plot(x, y2, label='sin')
plt.legend()
plt.show()

A set of axes plotting the cosine function in blue, and the sine function in orange. The axes have white gridlines on a grey background.

Of course, it’s a matter of opinion whether this default style change is an improvement, but Seaborn has other styles you can use, and if you want you can even define your own plot styles.
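
To compare the built-in styles side by side, something like the following sketch could be used. It assumes Seaborn is installed and uses its axes_style context manager, so that each style applies only to the axes created inside the with block; the figure size is an arbitrary choice.

```python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

x = np.linspace(-5, 5, 1000)

# The five style names built in to Seaborn
styles = ["darkgrid", "whitegrid", "dark", "white", "ticks"]

fig = plt.figure(figsize=(15, 3))
for i, style in enumerate(styles, start=1):
    # axes_style applies the style only to axes created inside the with block
    with sns.axes_style(style):
        ax = fig.add_subplot(1, len(styles), i)
    ax.plot(x, np.cos(x), label="cos")
    ax.plot(x, np.sin(x), label="sin")
    ax.set_title(style)

fig.tight_layout()
```

To change the style for every later plot instead, seaborn.set_style("whitegrid") can be called once near the top of a script.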

As with Matplotlib, there’s lots more you can do with Seaborn than we have time to go into here, so to learn more see the Seaborn example gallery.

This article is from the free online

Machine Learning for Image Data

Created by
FutureLearn - Learning For Life
