import seaborn as sns

flights = sns.load_dataset("flights")

# Plot each year's time series in its own facet
g = sns.relplot(
    data=flights,
    x="month", y="passengers", col="year", hue="year",
    kind="line", palette="crest", linewidth=4, zorder=5,
    col_wrap=5, height=2, aspect=1.5, legend=False,
)

# Iterate over each subplot to customize further
for year, ax in g.axes_dict.items():

    # Plot every year's time series in the background
    sns.lineplot(
        data=flights, x="month", y="passengers", units="year",
        estimator=None, color=".7", linewidth=1, ax=ax,
    )

g.tight_layout()
Step 1: read and pre-transform the data
Let’s load a sample dataset and turn the categorical month names into month numbers, so that we can plot them easily:
from seaborn import load_dataset
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

df = load_dataset('flights')
df['m'] = df['month'].cat.codes.apply(lambda x: x+1)
df = df.sort_values(by=['year','m'])
Step 2: create a small-multiples skeleton
We want to create a small multiples chart, one chart per year.
Let’s start by defining the rows and columns of the grid manually. We can use plt.subplots() and then iterate over the axes together with the sub-dataframe that belongs to each one. To do this, we zip the axes with the keys used to subset the dataframe.
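Here is a minimal sketch of that first attempt. To keep it self-contained it uses a small synthetic frame called toy (my stand-in for the flights data, with made-up passenger counts) and a hand-picked 2 x 5 grid; both are assumptions for illustration only.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Synthetic stand-in for the flights data (an assumption for illustration):
# 12 years x 12 months with made-up passenger counts.
toy = pd.DataFrame(
    [(y, m, 100 + 10 * (y - 1949) + m)
     for y in range(1949, 1961) for m in range(1, 13)],
    columns=["year", "m", "passengers"],
)

# Grid size chosen by hand: 2 x 5 = 10 axes for 12 years.
f, axes = plt.subplots(nrows=2, ncols=5, sharex=True, sharey=True)

plotted = []
for year, ax in zip(toy["year"].unique(), axes.ravel()):
    tmp_df = toy[toy["year"] == year]
    ax.plot(tmp_df["m"], tmp_df["passengers"])
    ax.set_title(str(year))
    plotted.append(int(year))

print(len(plotted))  # 10: the loop ended when the axes ran out
```

With 12 years and only 10 axes, the last two years never get a chart, and nothing warns us about it.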
Bummer: zip will stop evaluating the moment the shorter list is exhausted.
However, we can use the zip_longest function from itertools, which keeps iterating until the longer iterable is exhausted. If we don’t provide a fillvalue, it yields None once the shorter iterable has run out. For example:
from itertools import zip_longest
for first, second in zip_longest(range(3), range(5)):
    print(first, second)
0 0
1 1
2 2
None 3
None 4
We need to generate at least as many axes as the number of charts we expect (in our example, the number of available years). Then, whenever we get None for the year, we can remove the axis from the figure:
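A sketch of this idea, again on a synthetic toy frame (an assumption, standing in for the flights data) and with the 3 x 5 grid size still typed in by hand:

```python
from itertools import zip_longest

import matplotlib.pyplot as plt
import pandas as pd

# Synthetic stand-in for the flights data (an assumption for illustration).
toy = pd.DataFrame(
    [(y, m, 100 + 10 * (y - 1949) + m)
     for y in range(1949, 1961) for m in range(1, 13)],
    columns=["year", "m", "passengers"],
)

# 3 x 5 = 15 axes, more than the 12 years we need.
f, axes = plt.subplots(nrows=3, ncols=5, sharex=True, sharey=True)

for year, ax in zip_longest(toy["year"].unique(), axes.ravel()):
    if year is not None:
        tmp_df = toy[toy["year"] == year]
        ax.plot(tmp_df["m"], tmp_df["passengers"])
        ax.set_title(str(year))
    else:
        # zip_longest ran out of years: drop the empty axes
        ax.remove()

print(len(f.axes))  # 12 axes are left on the figure
```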
One last adjustment: so far we have set the number of charts by hand. Let’s define only the number of columns and let the script work out how many rows it needs. Also, rather than hardcoding the column name, let’s store it in a variable:
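col = 'year'
num_cols = 5
# Ceiling division: just enough rows for one chart per unique value,
# without an extra empty row when the count is a multiple of num_cols
num_rows = (df[col].nunique() + num_cols - 1) // num_cols

f, axes = plt.subplots(ncols=num_cols, nrows=num_rows, sharex=True, sharey=True)

for chosen, ax in zip_longest(df[col].unique(), axes.ravel()):
    if chosen is not None:
        tmp_df = df[df[col] == chosen]
        ax.plot(tmp_df['m'], tmp_df['passengers'])
    else:
        # No value left for this slot: remove the empty axes
        ax.remove()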
Step 3: add missing x-labels
Another problem has arisen. Because we are using sharex and sharey, only the last row of the grid gets tick labels, so the charts that sit directly above a removed axis end up without x-labels. Let’s find out which charts are at the bottom of their column. We will create a boolean table with a 1 for each chart where we want to see the x-labels:
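One way to build such a table, sketched under the assumption that the charts fill the grid row by row (which is what iterating over axes.ravel() does). The helper names idx, has_chart and nothing_below are mine, and the numbers match the flights example: 12 years on a 5-column grid.

```python
import numpy as np

num_charts, num_cols = 12, 5
num_rows = (num_charts + num_cols - 1) // num_cols  # ceiling division -> 3

# Row-major position of every grid cell, same shape as `axes`.
idx = np.arange(num_rows * num_cols).reshape(num_rows, num_cols)

has_chart = idx < num_charts                  # the cell holds a chart
nothing_below = idx + num_cols >= num_charts  # the cell below is empty or off the grid
is_xlabeled = has_chart & nothing_below

print(is_xlabeled.astype(int))
# [[0 0 0 0 0]
#  [0 0 1 1 1]
#  [1 1 0 0 0]]
```

One option is then to look up this table inside the plotting loop itself, right where each chart is drawn, and switch the labels on there.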
However, I don’t find this very readable. I also may not want this option hardcoded: it will be easier to parametrize if I add the labels in a second loop. Here I can use a simple zip, because is_xlabeled and axes have the same shape.
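A sketch of what that second loop could look like. It rebuilds a bare 3 x 5 grid for 12 charts so that it runs on its own; the tick_params(labelbottom=True) call is the standard matplotlib way to bring back tick labels that sharex has hidden, and the label text "m" is just my choice for the example.

```python
import matplotlib.pyplot as plt
import numpy as np

num_charts, num_cols, num_rows = 12, 5, 3
f, axes = plt.subplots(nrows=num_rows, ncols=num_cols, sharex=True, sharey=True)

# Boolean table: True for the lowest chart in each column.
idx = np.arange(axes.size).reshape(axes.shape)
is_xlabeled = (idx < num_charts) & (idx + num_cols >= num_charts)

# Drop the axes that hold no chart.
for ax in axes.ravel()[num_charts:]:
    ax.remove()

# Second loop: same shapes, so a plain zip is enough.
for ax, show_labels in zip(axes.ravel(), is_xlabeled.ravel()):
    if show_labels:
        ax.tick_params(axis="x", labelbottom=True)  # undo the hiding done by sharex
        ax.set_xlabel("m")
```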
For this, let’s reshape our data. It’s going to be easier if we take advantage of the long-to-wide conversion and create a pivot table containing our data. Here we don’t need to do any aggregation, but this would be the first step in a data prep pipeline.
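One plausible shape for this table, sketched on a synthetic toy frame (an assumption, standing in for the flights data): months as rows, years as columns. Since each (month, year) pair occurs exactly once, DataFrame.pivot is enough; pd.pivot_table would be the choice if an aggregation were needed.

```python
import pandas as pd

# Synthetic stand-in for the flights data (an assumption for illustration).
toy = pd.DataFrame(
    [(y, m, 100 + 10 * (y - 1949) + m)
     for y in range(1949, 1961) for m in range(1, 13)],
    columns=["year", "m", "passengers"],
)

# Long to wide: one row per month, one column per year.
wide = toy.pivot(index="m", columns="year", values="passengers")

print(wide.shape)  # (12, 12)
```

In this layout a single year is just wide[1955], and the whole table can be handed to ax.plot() in one go.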
Let’s move the styling out of the loop where we plot the dashboard. In the future we may want to try out various colors and build a legend based on them. A dictionary is great for this:
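A small sketch of the idea. The name style_selected is the one used in the call further down; style_background and all the values are my own placeholders. Each dictionary is unpacked straight into ax.plot() as keyword arguments.

```python
import matplotlib.pyplot as plt

# Hypothetical style dictionaries; keys are ordinary Line2D keyword arguments.
style_background = {"color": "0.8", "lw": 1}
style_selected = {"color": "crimson", "lw": 3}

f, ax = plt.subplots()
x = [1, 2, 3]
(bg_line,) = ax.plot(x, [3, 2, 4], **style_background)  # muted context line
(fg_line,) = ax.plot(x, [1, 3, 2], **style_selected)    # highlighted line
```

Changing the look is now a matter of editing one dictionary, and the same dictionary could later feed a legend.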
Now it looks like we have separated the data and plotting parts. We can wrap our code in a function. We define default styles inside the function, and return f and axes objects for easy modification further on.
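Below is a sketch of how such a function might look. Only the call signature (df, col, x_col, y_col, num_cols, style_selected) and the returned f and axes come from this post; the body, the style_background parameter, the default values and the synthetic toy frame used to try it out are my assumptions.

```python
from itertools import zip_longest

import matplotlib.pyplot as plt
import pandas as pd


def small_multiples(df, col, x_col, y_col, num_cols=5,
                    style_selected=None, style_background=None):
    """Sketch: one chart per unique value of `col`, all groups drawn in the background."""
    # Default styles live inside the function; caller-supplied keys override them.
    style_selected = {"color": "crimson", "lw": 2, **(style_selected or {})}
    style_background = {"color": "0.8", "lw": 1, **(style_background or {})}

    keys = df[col].unique()
    num_rows = (len(keys) + num_cols - 1) // num_cols  # ceiling division
    # squeeze=False keeps `axes` two-dimensional even for a single row
    f, axes = plt.subplots(nrows=num_rows, ncols=num_cols,
                           sharex=True, sharey=True, squeeze=False)

    # Data part: wide table with one column per group.
    wide = df.pivot(index=x_col, columns=col, values=y_col)

    # Plotting part.
    for chosen, ax in zip_longest(keys, axes.ravel()):
        if chosen is None:
            ax.remove()  # no group left for this slot
            continue
        ax.plot(wide.index, wide.to_numpy(), **style_background)  # every group, muted
        ax.plot(wide.index, wide[chosen], **style_selected)       # this chart's group
        ax.set_title(str(chosen))

    return f, axes


# Try it on a synthetic stand-in for the flights data (an assumption).
toy = pd.DataFrame(
    [(y, m, 100 + 10 * (y - 1949) + m)
     for y in range(1949, 1961) for m in range(1, 13)],
    columns=["year", "m", "passengers"],
)
f, axes = small_multiples(toy, col="year", x_col="m", y_col="passengers",
                          num_cols=5, style_selected={"color": "limegreen", "lw": 3})
```

Returning f and axes rather than calling plt.show() inside the function leaves titles, labels and sizes open for later tweaks.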
With such a structure, we can easily modify the look and feel of our small-multiples dashboard:
f, axes = small_multiples(
    df=df,
    col='year',
    x_col='m',
    y_col='passengers',
    num_cols=5,
    style_selected={'color': 'limegreen', 'lw': 3},
)
f.suptitle('Flights over the years')
Text(0.5, 0.98, 'Flights over the years')
However, there is something we are missing. We cannot easily scale the size of the dashboard. For this, let’s declare another variable dashboard_props, which we will pass to plt.subplots():
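A minimal sketch of the mechanism, with placeholder values of my own: dashboard_props is an ordinary dictionary of keyword arguments, and plt.subplots() forwards the ones it does not use itself (such as figsize) to the figure it creates.

```python
import matplotlib.pyplot as plt

# Hypothetical contents: any keyword accepted by plt.subplots() or plt.figure().
dashboard_props = {"figsize": (12, 6), "sharex": True, "sharey": True}

f, axes = plt.subplots(nrows=3, ncols=5, **dashboard_props)

width, height = f.get_size_inches()  # figure size in inches
```

Inside a plotting function, the same dictionary could be accepted as a parameter and unpacked in the plt.subplots() call, so the caller controls the size without touching the plotting code.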