import seaborn as sns

flights = sns.load_dataset("flights")

# Plot each year's time series in its own facet
g = sns.relplot(
    data=flights,
    x="month", y="passengers", col="year", hue="year",
    kind="line", palette="crest", linewidth=4, zorder=5,
    col_wrap=5, height=2, aspect=1.5, legend=False,
)

# Iterate over each subplot to customize further (the year key is not needed)
for ax in g.axes_dict.values():
    # Plot every year's time series in the background: units="year" together
    # with estimator=None draws one thin grey line per year instead of an
    # aggregated line.
    sns.lineplot(
        data=flights, x="month", y="passengers", units="year",
        estimator=None, color=".7", linewidth=1, ax=ax,
    )

g.tight_layout()
In this tutorial:
- how to loop over axes and data
- how to display axis labels only in some axes
- how to store styles in dictionaries
- how to adjust tick frequency
- how to create a custom legend
- how to pass properties inside a charting function
Small multiples are an extremely helpful technique in data visualization. The seaborn example above shows how a small-multiples chart can be made with `relplot`. How could we build such a chart from scratch in Matplotlib (MPL), tailored to our needs?
Step 1: read and pre-transform the data
Let’s load the sample `flights` dataset and convert the month names into month numbers (1–12), so that we can easily plot the data:
from seaborn import load_dataset
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

df = load_dataset('flights')
# 'month' is a categorical column whose category codes are 0-based; add 1 to
# get month numbers 1..12 (assumes the categories are ordered Jan..Dec).
# Vectorised instead of a per-element apply(lambda x: x + 1).
df['m'] = df['month'].cat.codes.astype('int64') + 1
df = df.sort_values(by=['year', 'm'])
df
Step 2: create a small-multiples skeleton
We want to create a small multiples chart, one chart per year.
Let’s start by defining the number of rows and columns in the grid manually. We can use `plt.subplots()` and then iterate over the axes and the sub-DataFrame for each axis. For this, we’re going to `zip` the axes and the keys used to sub-sample the DataFrame.
f, axes = plt.subplots(ncols=5, nrows=3, sharex=True, sharey=True)
# One axis per year; zip() stops as soon as the shorter iterable runs out.
for chosen, ax in zip(df['year'].unique(), axes.ravel()):
    tmp_df = df[df['year'] == chosen]
    ax.plot(tmp_df['m'], tmp_df['passengers'])
Bummer: `zip` stops the moment the shorter iterable is exhausted, so the surplus axes (12 years on a 15-axis grid) are left empty in the figure. Instead, we can use the `zip_longest` function from `itertools`, which keeps iterating until the longest iterable is exhausted. If we don’t provide a `fillvalue`, it yields `None` in place of values from any iterable that has already been exhausted. For example:
from itertools import zip_longest

for first, second in zip_longest(range(3), range(5)):
    # Once range(3) is exhausted, zip_longest fills its slot with None.
    print(first, second)
0 0
1 1
2 2
None 3
None 4
We need to create at least as many axes as the number of charts we actually expect (in our example, the number of available years). Then, whenever we get `None` for the year, we can remove that axis from the figure:
f, axes = plt.subplots(ncols=5, nrows=3, sharex=True, sharey=True)
for chosen, ax in zip_longest(df['year'].unique(), axes.ravel()):
    if chosen is not None:
        tmp_df = df[df['year'] == chosen]
        ax.plot(tmp_df['m'], tmp_df['passengers'])
    else:
        # More axes than years: drop the surplus axis from the figure.
        # (ax.axis('off') would only hide it and keep its slot.)
        ax.remove()
Looks good!
One last adjustment: until now we have defined the grid size by hand. Let’s define only the number of columns and let the script compute the required number of rows. Also, let’s not hard-code the column name, but store it in a variable instead:
col = 'year'
num_cols = 5

# Ceiling division: just enough rows to hold every chart
# (n // num_cols + 1 would add an empty row whenever n is a multiple of num_cols).
num_rows = (df[col].nunique() + num_cols - 1) // num_cols

f, axes = plt.subplots(ncols=num_cols, nrows=num_rows, sharex=True, sharey=True)
for chosen, ax in zip_longest(df[col].unique(), axes.ravel()):
    if chosen is not None:
        tmp_df = df[df[col] == chosen]
        ax.plot(tmp_df['m'], tmp_df['passengers'])
    else:
        # Surplus slot in the last row: remove it from the figure.
        ax.remove()
Step 3: add missing x-labels
Another problem has arisen. Because we use `sharex=True` and `sharey=True`, Matplotlib shows x tick labels only on the bottom row of the grid, so the charts sitting directly above the removed axes have no x labels. Let’s find out which charts are at the bottom of their column. We will create a boolean table with 1 for every chart where we want to see the x labels:
num_charts = df[col].nunique()

# 1 marks a chart with no chart directly below it in the grid, i.e. flat
# index i with i + num_cols >= num_charts: those are the charts whose x tick
# labels we want to see. Surplus slots (i >= num_charts) stay 0.
# Unlike slicing the last two rows, this is also correct when the last row
# is full and when the grid has a single row.
flat_idx = np.arange(num_rows * num_cols).reshape(num_rows, num_cols)
is_xlabeled = ((flat_idx < num_charts)
               & (flat_idx + num_cols >= num_charts)).astype(float)
is_xlabeled
array([[0., 0., 0., 0., 0.],
[0., 0., 1., 1., 1.],
[1., 1., 0., 0., 0.]])
Now I can integrate it into my code by adding it to the `zip_longest` call:
col = 'year'
num_cols = 5

num_charts = df[col].nunique()
# Ceiling division: enough rows for every chart, without a trailing empty row.
num_rows = (num_charts + num_cols - 1) // num_cols

# 1 for every chart with no chart directly below it: those need x tick labels.
flat_idx = np.arange(num_rows * num_cols).reshape(num_rows, num_cols)
is_xlabeled = ((flat_idx < num_charts)
               & (flat_idx + num_cols >= num_charts)).astype(float)

f, axes = plt.subplots(ncols=num_cols, nrows=num_rows, sharex=True, sharey=True)
for chosen, ax, xlab in zip_longest(df[col].unique(), axes.ravel(),
                                    is_xlabeled.ravel()):
    if chosen is not None:
        tmp_df = df[df[col] == chosen]
        ax.plot(tmp_df['m'], tmp_df['passengers'])
        if xlab:
            # sharex hides x tick labels outside the bottom row; turn them back on.
            ax.tick_params(axis='x', which='major', labelbottom=True)
    else:
        ax.remove()
However, I don’t like how this reads. Also, maybe I don’t want to have this option hard-coded — it will be easier to parametrize it if I add the labels in a second loop. Here I can use a simple `zip`, because `is_xlabeled` and `axes` have the same shape.
col = 'year'
num_cols = 5

num_charts = df[col].nunique()
# Ceiling division: enough rows for every chart, without a trailing empty row.
num_rows = (num_charts + num_cols - 1) // num_cols

# 1 for every chart with no chart directly below it: those need x tick labels.
flat_idx = np.arange(num_rows * num_cols).reshape(num_rows, num_cols)
is_xlabeled = ((flat_idx < num_charts)
               & (flat_idx + num_cols >= num_charts)).astype(float)

f, axes = plt.subplots(ncols=num_cols, nrows=num_rows, sharex=True, sharey=True)
for chosen, ax in zip_longest(df[col].unique(), axes.ravel()):
    if chosen is not None:
        tmp_df = df[df[col] == chosen]
        ax.plot(tmp_df['m'], tmp_df['passengers'])
    else:
        ax.remove()

# Second pass: is_xlabeled and axes have the same shape, so plain zip suffices.
for xlab, ax in zip(is_xlabeled.ravel(), axes.ravel()):
    if xlab:
        ax.tick_params(axis='x', which='major', labelbottom=True)
Looks good!
Step 4: add the data in the background
For this, let’s reshape our data. It’s going to be easier if we take advantage of the long-to-wide conversion and create a pivot table containing our data. Here we don’t need any aggregation, but this would be the first step in a data-prep pipeline.
# Long -> wide: one row per month number ('m'), one column per value of `col`.
df_pivot = df.pivot(index='m', columns=col, values='passengers')
df_pivot
year | 1949 | 1950 | 1951 | 1952 | 1953 | 1954 | 1955 | 1956 | 1957 | 1958 | 1959 | 1960 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
m | ||||||||||||
1 | 112 | 115 | 145 | 171 | 196 | 204 | 242 | 284 | 315 | 340 | 360 | 417 |
2 | 118 | 126 | 150 | 180 | 196 | 188 | 233 | 277 | 301 | 318 | 342 | 391 |
3 | 132 | 141 | 178 | 193 | 236 | 235 | 267 | 317 | 356 | 362 | 406 | 419 |
4 | 129 | 135 | 163 | 181 | 235 | 227 | 269 | 313 | 348 | 348 | 396 | 461 |
5 | 121 | 125 | 172 | 183 | 229 | 234 | 270 | 318 | 355 | 363 | 420 | 472 |
6 | 135 | 149 | 178 | 218 | 243 | 264 | 315 | 374 | 422 | 435 | 472 | 535 |
7 | 148 | 170 | 199 | 230 | 264 | 302 | 364 | 413 | 465 | 491 | 548 | 622 |
8 | 148 | 170 | 199 | 242 | 272 | 293 | 347 | 405 | 467 | 505 | 559 | 606 |
9 | 136 | 158 | 184 | 209 | 237 | 259 | 312 | 355 | 404 | 404 | 463 | 508 |
10 | 119 | 133 | 162 | 191 | 211 | 229 | 274 | 306 | 347 | 359 | 407 | 461 |
11 | 104 | 114 | 146 | 172 | 180 | 203 | 237 | 271 | 305 | 310 | 362 | 390 |
12 | 118 | 140 | 166 | 194 | 201 | 229 | 278 | 306 | 336 | 337 | 405 | 432 |
Now I can refactor the code: I can use `df_pivot` both for getting the selected curve and for getting the background curves:
col = 'year'
num_cols = 5

num_charts = df[col].nunique()
# Ceiling division: enough rows for every chart, without a trailing empty row.
num_rows = (num_charts + num_cols - 1) // num_cols

# 1 for every chart with no chart directly below it: those need x tick labels.
flat_idx = np.arange(num_rows * num_cols).reshape(num_rows, num_cols)
is_xlabeled = ((flat_idx < num_charts)
               & (flat_idx + num_cols >= num_charts)).astype(float)

f, axes = plt.subplots(ncols=num_cols, nrows=num_rows, sharex=True, sharey=True)
df_pivot = df.pivot(index='m', columns=col, values='passengers')

for chosen, ax in zip_longest(df[col].unique(), axes.ravel()):
    if chosen is not None:
        tmp_df = df_pivot[chosen]
        # Selected year drawn on top (zorder=99) of every year drawn in grey.
        ax.plot(tmp_df, color='orange', lw=2, zorder=99)
        ax.plot(df_pivot, color='gray', lw=1, alpha=0.5)
    else:
        ax.remove()

# Second pass: restore x tick labels on charts at the bottom of each column.
for xlab, ax in zip(is_xlabeled.ravel(), axes.ravel()):
    if xlab:
        ax.tick_params(axis='x', which='major', labelbottom=True)
Let’s move the styling outside of the loop where we plot the dashboard. Maybe in the future we will want to try out various colors and build a legend based on those colors. A dictionary is great for this:
col = 'year'
num_cols = 5

# Line styles live outside the plotting loop so they are easy to change.
style_selected = {'color': 'orange', 'lw': 2, 'zorder': 99}
style_bg = {'color': 'gray', 'lw': 1, 'alpha': 0.5}

num_charts = df[col].nunique()
# Ceiling division: enough rows for every chart, without a trailing empty row.
num_rows = (num_charts + num_cols - 1) // num_cols

# 1 for every chart with no chart directly below it: those need x tick labels.
flat_idx = np.arange(num_rows * num_cols).reshape(num_rows, num_cols)
is_xlabeled = ((flat_idx < num_charts)
               & (flat_idx + num_cols >= num_charts)).astype(float)

f, axes = plt.subplots(ncols=num_cols, nrows=num_rows, sharex=True, sharey=True)
df_pivot = df.pivot(index='m', columns=col, values='passengers')

for chosen, ax in zip_longest(df[col].unique(), axes.ravel()):
    if chosen is not None:
        tmp_df = df_pivot[chosen]
        ax.plot(tmp_df, **style_selected)
        ax.plot(df_pivot, **style_bg)
    else:
        ax.remove()

for xlab, ax in zip(is_xlabeled.ravel(), axes.ravel()):
    if xlab:
        ax.tick_params(axis='x', which='major', labelbottom=True)
Also, let’s not hard-code the x and y column names, and let’s use only `df_pivot` from now on.
col = 'year'
x_col = 'm'
y_col = 'passengers'
num_cols = 5

style_selected = {'color': 'orange', 'lw': 2, 'zorder': 99}
style_bg = {'color': 'gray', 'lw': 1, 'alpha': 0.5}

# Long -> wide: one row per x value, one column (= one chart) per `col` value.
df_pivot = df.pivot(index=x_col, columns=col, values=y_col)

num_charts = len(df_pivot.columns)
# Ceiling division: enough rows for every chart, without a trailing empty row.
num_rows = (num_charts + num_cols - 1) // num_cols

# 1 for every chart with no chart directly below it: those need x tick labels.
flat_idx = np.arange(num_rows * num_cols).reshape(num_rows, num_cols)
is_xlabeled = ((flat_idx < num_charts)
               & (flat_idx + num_cols >= num_charts)).astype(float)

f, axes = plt.subplots(ncols=num_cols, nrows=num_rows, sharex=True, sharey=True)

for chosen, ax in zip_longest(df_pivot.columns, axes.ravel()):
    if chosen is not None:
        tmp_df = df_pivot[chosen]
        ax.plot(tmp_df, **style_selected)
        ax.plot(df_pivot, **style_bg)
    else:
        ax.remove()

for xlab, ax in zip(is_xlabeled.ravel(), axes.ravel()):
    if xlab:
        ax.tick_params(axis='x', which='major', labelbottom=True)
Now it looks like we have separated the data and plotting parts, so we can wrap our code in a function. We define default styles inside the function and return the `f` and `axes` objects so that they are easy to modify later on.
def small_multiples(df, col, x_col, y_col, num_cols,
                    style_selected=None, style_bg=None):
    """Draw a small-multiples grid with one panel per unique value of ``col``.

    Every panel shows the ``y_col``-vs-``x_col`` curve of its own ``col``
    value in the "selected" style, on top of all curves drawn in the
    "background" style.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data. It is reshaped with
        ``df.pivot(index=x_col, columns=col, values=y_col)``, so each
        (x_col, col) pair must occur only once.
    col : str
        Column whose unique values get one panel each.
    x_col, y_col : str
        Columns plotted on the x and y axes.
    num_cols : int
        Number of panel columns; the number of rows is derived from it.
    style_selected, style_bg : dict, optional
        Keyword arguments forwarded to ``Axes.plot`` for the highlighted curve
        and for the background curves. Defaults are defined below.

    Returns
    -------
    f : matplotlib.figure.Figure
    axes : Axes or numpy.ndarray of Axes
        Exactly as returned by ``plt.subplots``; surplus axes have been
        removed from the figure.
    """
    if style_selected is None:
        style_selected = {'color': 'orange', 'lw': 2, 'zorder': 99}
    if style_bg is None:
        style_bg = {'color': 'gray', 'lw': 1, 'alpha': 0.5}

    df_pivot = df.pivot(index=x_col, columns=col, values=y_col)

    num_charts = len(df_pivot.columns)
    # Ceiling division (at least one row): no trailing empty row when
    # num_charts is a multiple of num_cols.
    num_rows = max(1, (num_charts + num_cols - 1) // num_cols)

    # True for every chart with no chart directly below it (flat index i with
    # i + num_cols >= num_charts); these need their x tick labels back.
    # Correct for a full last row and for a single row alike.
    flat_idx = np.arange(num_rows * num_cols).reshape(num_rows, num_cols)
    is_xlabeled = (flat_idx < num_charts) & (flat_idx + num_cols >= num_charts)

    f, axes = plt.subplots(ncols=num_cols, nrows=num_rows,
                           sharex=True, sharey=True)
    # plt.subplots returns a bare Axes for a 1x1 grid; iterate over a flat array.
    axes_flat = np.atleast_1d(axes).ravel()

    for chosen, ax in zip_longest(df_pivot.columns, axes_flat):
        if chosen is not None:
            ax.plot(df_pivot[chosen], **style_selected)
            ax.plot(df_pivot, **style_bg)
        else:
            # Surplus grid slot: drop it from the figure.
            ax.remove()

    # Second pass: sharex hides x tick labels outside the bottom row.
    for xlab, ax in zip(is_xlabeled.ravel(), axes_flat):
        if xlab:
            ax.tick_params(axis='x', which='major', labelbottom=True)

    return f, axes
With such a structure, we can easily modify the look and feel of our small-multiples dashboard:
f, axes = small_multiples(df=df, col='year', x_col='m', y_col='passengers',
                          num_cols=5,
                          style_selected={'color': 'limegreen', 'lw': 3})
f.suptitle('Flights over the years')
Text(0.5, 0.98, 'Flights over the years')
However, something is still missing: we cannot easily scale the size of the dashboard. For this, let’s add another parameter, `dashboard_props`, which we will pass on to `plt.subplots()`:
def small_multiples(df, col, x_col, y_col, num_cols,
                    style_selected=None, style_bg=None,
                    dashboard_props=None):
    """Draw a small-multiples grid with one panel per unique value of ``col``.

    Every panel shows the ``y_col``-vs-``x_col`` curve of its own ``col``
    value in the "selected" style, on top of all curves drawn in the
    "background" style.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data. It is reshaped with
        ``df.pivot(index=x_col, columns=col, values=y_col)``, so each
        (x_col, col) pair must occur only once.
    col : str
        Column whose unique values get one panel each.
    x_col, y_col : str
        Columns plotted on the x and y axes.
    num_cols : int
        Number of panel columns; the number of rows is derived from it.
    style_selected, style_bg : dict, optional
        Keyword arguments forwarded to ``Axes.plot`` for the highlighted curve
        and for the background curves. Defaults are defined below.
    dashboard_props : dict, optional
        Extra keyword arguments forwarded to ``plt.subplots`` (e.g.
        ``{'figsize': (10, 5)}``).

    Returns
    -------
    f : matplotlib.figure.Figure
    axes : Axes or numpy.ndarray of Axes
        Exactly as returned by ``plt.subplots``; surplus axes have been
        removed from the figure.
    """
    if style_selected is None:
        style_selected = {'color': 'orange', 'lw': 2, 'zorder': 99}
    if style_bg is None:
        style_bg = {'color': 'gray', 'lw': 1, 'alpha': 0.5}
    if dashboard_props is None:
        dashboard_props = {}

    df_pivot = df.pivot(index=x_col, columns=col, values=y_col)

    num_charts = len(df_pivot.columns)
    # Ceiling division (at least one row): no trailing empty row when
    # num_charts is a multiple of num_cols.
    num_rows = max(1, (num_charts + num_cols - 1) // num_cols)

    # True for every chart with no chart directly below it (flat index i with
    # i + num_cols >= num_charts); these need their x tick labels back.
    # Correct for a full last row and for a single row alike.
    flat_idx = np.arange(num_rows * num_cols).reshape(num_rows, num_cols)
    is_xlabeled = (flat_idx < num_charts) & (flat_idx + num_cols >= num_charts)

    f, axes = plt.subplots(ncols=num_cols, nrows=num_rows,
                           sharex=True, sharey=True,
                           **dashboard_props)
    # plt.subplots returns a bare Axes for a 1x1 grid; iterate over a flat array.
    axes_flat = np.atleast_1d(axes).ravel()

    for chosen, ax in zip_longest(df_pivot.columns, axes_flat):
        if chosen is not None:
            ax.plot(df_pivot[chosen], **style_selected)
            ax.plot(df_pivot, **style_bg)
        else:
            # Surplus grid slot: drop it from the figure.
            ax.remove()

    # Second pass: sharex hides x tick labels outside the bottom row.
    for xlab, ax in zip(is_xlabeled.ravel(), axes_flat):
        if xlab:
            ax.tick_params(axis='x', which='major', labelbottom=True)

    return f, axes
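Now we can also make the x-ticks a bit more human-friendly. Passing the categorical `month` column as `x_col` puts month names on the x-axis, and a `MultipleLocator(3)` keeps only every third tick. With the current structure, this is easy to do without having to modify the base function: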
from matplotlib.ticker import MultipleLocator

f, axes = small_multiples(df=df, col='year', x_col='month', y_col='passengers',
                          num_cols=5,
                          style_selected={'color': 'limegreen', 'lw': 3},
                          dashboard_props={'figsize': (10, 5)})
f.suptitle('Flights over the years')

# Keep a tick every 3 positions (every third month). A plain loop, because
# a list comprehension used only for its side effects builds a throwaway list.
for ax in axes.ravel():
    ax.xaxis.set_major_locator(MultipleLocator(3))
Step 5: legend and annotations
Let’s add annotations: at this point, nothing tells us which year each chart shows!
def small_multiples(df, col, x_col, y_col, num_cols,
                    style_selected=None, style_bg=None,
                    dashboard_props=None, text_props=None):
    """Draw a labelled small-multiples grid, one panel per unique ``col`` value.

    Every panel shows the ``y_col``-vs-``x_col`` curve of its own ``col``
    value in the "selected" style, on top of all curves drawn in the
    "background" style, and is labelled with that ``col`` value.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data. It is reshaped with
        ``df.pivot(index=x_col, columns=col, values=y_col)``, so each
        (x_col, col) pair must occur only once.
    col : str
        Column whose unique values get one panel each.
    x_col, y_col : str
        Columns plotted on the x and y axes.
    num_cols : int
        Number of panel columns; the number of rows is derived from it.
    style_selected, style_bg : dict, optional
        Keyword arguments forwarded to ``Axes.plot`` for the highlighted curve
        and for the background curves. Defaults are defined below.
    dashboard_props : dict, optional
        Extra keyword arguments forwarded to ``plt.subplots``.
    text_props : dict, optional
        Keyword arguments for the panel label (``Axes.text``). They override
        the defaults: position ``x=0.95, y=0.95`` in axes coordinates,
        ``va='top'``, ``ha='right'``, ``zorder=1``. The dict passed in is
        not modified.

    Returns
    -------
    f : matplotlib.figure.Figure
    axes : Axes or numpy.ndarray of Axes
        Exactly as returned by ``plt.subplots``; surplus axes have been
        removed from the figure.
    """
    if style_selected is None:
        style_selected = {'color': 'orange', 'lw': 2, 'zorder': 99}
    if style_bg is None:
        style_bg = {'color': 'gray', 'lw': 1, 'alpha': 0.5}
    if dashboard_props is None:
        dashboard_props = {}
    # Merge into a new dict so the caller's text_props is never mutated.
    text_props = {'x': 0.95, 'y': 0.95, **(text_props or {})}

    df_pivot = df.pivot(index=x_col, columns=col, values=y_col)

    num_charts = len(df_pivot.columns)
    # Ceiling division (at least one row): no trailing empty row when
    # num_charts is a multiple of num_cols.
    num_rows = max(1, (num_charts + num_cols - 1) // num_cols)

    # True for every chart with no chart directly below it (flat index i with
    # i + num_cols >= num_charts); these need their x tick labels back.
    flat_idx = np.arange(num_rows * num_cols).reshape(num_rows, num_cols)
    is_xlabeled = (flat_idx < num_charts) & (flat_idx + num_cols >= num_charts)

    f, axes = plt.subplots(ncols=num_cols, nrows=num_rows,
                           sharex=True, sharey=True,
                           **dashboard_props)
    # plt.subplots returns a bare Axes for a 1x1 grid; iterate over a flat array.
    axes_flat = np.atleast_1d(axes).ravel()

    for chosen, ax in zip_longest(df_pivot.columns, axes_flat):
        if chosen is not None:
            ax.plot(df_pivot[chosen], **style_selected)
            ax.plot(df_pivot, **style_bg)
            # Panel label; user-supplied text_props may override any default
            # (previously va/ha/zorder/transform raised a duplicate-kwarg error).
            label_props = {'transform': ax.transAxes, 'va': 'top',
                           'ha': 'right', 'zorder': 1, **text_props}
            ax.text(s=chosen, **label_props)
        else:
            # Surplus grid slot: drop it from the figure.
            ax.remove()

    # Second pass: sharex hides x tick labels outside the bottom row.
    for xlab, ax in zip(is_xlabeled.ravel(), axes_flat):
        if xlab:
            ax.tick_params(axis='x', which='major', labelbottom=True)

    return f, axes
from matplotlib.ticker import MultipleLocator

f, axes = small_multiples(df=df, col='year', x_col='month', y_col='passengers',
                          num_cols=5,
                          style_selected={'color': 'limegreen', 'lw': 3},
                          dashboard_props={'figsize': (10, 5)},
                          text_props={'size': 8})
f.suptitle('Flights over the years')

# Keep a tick every 3 positions (every third month).
for ax in axes.ravel():
    ax.xaxis.set_major_locator(MultipleLocator(3))
And let’s add a legend:
def small_multiples(df, col, x_col, y_col, num_cols,
                    style_selected=None, style_bg=None,
                    dashboard_props=None, text_props=None):
    """Draw a labelled small-multiples grid with a legend.

    One panel per unique value of ``col``: each shows the
    ``y_col``-vs-``x_col`` curve of its own ``col`` value in the "selected"
    style on top of all curves drawn in the "background" style, and is
    labelled with that value. A legend explaining both styles is attached to
    the right of the last chart.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data. It is reshaped with
        ``df.pivot(index=x_col, columns=col, values=y_col)``, so each
        (x_col, col) pair must occur only once.
    col : str
        Column whose unique values get one panel each.
    x_col, y_col : str
        Columns plotted on the x and y axes.
    num_cols : int
        Number of panel columns; the number of rows is derived from it.
    style_selected, style_bg : dict, optional
        Keyword arguments forwarded to ``Axes.plot`` (and to the legend's
        ``Line2D`` proxies) for the highlighted and background curves.
        Defaults are defined below.
    dashboard_props : dict, optional
        Extra keyword arguments forwarded to ``plt.subplots``.
    text_props : dict, optional
        Keyword arguments for the panel label (``Axes.text``). They override
        the defaults: position ``x=0.95, y=0.95`` in axes coordinates,
        ``va='top'``, ``ha='right'``, ``zorder=1``. The dict passed in is
        not modified.

    Returns
    -------
    f : matplotlib.figure.Figure
    axes : Axes or numpy.ndarray of Axes
        Exactly as returned by ``plt.subplots``; surplus axes have been
        removed from the figure.
    """
    from matplotlib.lines import Line2D

    if style_selected is None:
        style_selected = {'color': 'orange', 'lw': 2, 'zorder': 99}
    if style_bg is None:
        style_bg = {'color': 'gray', 'lw': 1, 'alpha': 0.5}
    if dashboard_props is None:
        dashboard_props = {}
    # Merge into a new dict so the caller's text_props is never mutated.
    text_props = {'x': 0.95, 'y': 0.95, **(text_props or {})}

    df_pivot = df.pivot(index=x_col, columns=col, values=y_col)

    num_charts = len(df_pivot.columns)
    # Ceiling division (at least one row): no trailing empty row when
    # num_charts is a multiple of num_cols.
    num_rows = max(1, (num_charts + num_cols - 1) // num_cols)

    # True for every chart with no chart directly below it (flat index i with
    # i + num_cols >= num_charts); these need their x tick labels back.
    flat_idx = np.arange(num_rows * num_cols).reshape(num_rows, num_cols)
    is_xlabeled = (flat_idx < num_charts) & (flat_idx + num_cols >= num_charts)

    f, axes = plt.subplots(ncols=num_cols, nrows=num_rows,
                           sharex=True, sharey=True,
                           **dashboard_props)
    # plt.subplots returns a bare Axes for a 1x1 grid; iterate over a flat array.
    axes_flat = np.atleast_1d(axes).ravel()

    for chosen, ax in zip_longest(df_pivot.columns, axes_flat):
        if chosen is not None:
            ax.plot(df_pivot[chosen], **style_selected)
            ax.plot(df_pivot, **style_bg)
            # Panel label; user-supplied text_props may override any default.
            label_props = {'transform': ax.transAxes, 'va': 'top',
                           'ha': 'right', 'zorder': 1, **text_props}
            ax.text(s=chosen, **label_props)
        else:
            # Surplus grid slot: drop it from the figure.
            ax.remove()

    # Second pass: sharex hides x tick labels outside the bottom row.
    for xlab, ax in zip(is_xlabeled.ravel(), axes_flat):
        if xlab:
            ax.tick_params(axis='x', which='major', labelbottom=True)

    # Proxy artists so the legend shows both line styles.
    labels = ['selected year', 'other years']
    handles = [Line2D([0], [0], **style_selected), Line2D([0], [0], **style_bg)]
    if num_charts:
        # Attach the legend to the last chart, just outside its right edge.
        axes_flat[num_charts - 1].legend(
            labels=labels, handles=handles,
            bbox_to_anchor=[1.25, 0.5], loc='center left',
            edgecolor='white', facecolor='white',
            borderpad=0, borderaxespad=0,
        )

    return f, axes
from matplotlib.ticker import MultipleLocator

f, axes = small_multiples(df=df, col='year', x_col='month', y_col='passengers',
                          num_cols=5,
                          style_selected={'color': 'limegreen', 'lw': 3},
                          dashboard_props={'figsize': (10, 5)},
                          text_props={'size': 8})
f.suptitle('Flights over the years')

# Keep a tick every 3 positions (every third month).
for ax in axes.ravel():
    ax.xaxis.set_major_locator(MultipleLocator(3))