[PYTHON] Draw hierarchical axis labels with matplotlib + pandas

If you plot a pandas DataFrame that has a hierarchical index (MultiIndex) with Matplotlib as-is, the x-axis tick labels are displayed as tuples, which is a bit disappointing.

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

a = pd.DataFrame(np.random.random([6, 2]),
                 index=pd.MultiIndex.from_product([['group1', 'group2'],['item1', 'item2', 'item3']]),
                 columns=['data1', 'data2'])

a.plot.bar()
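The tuples come from the index itself. Iterating over a `MultiIndex` yields one tuple per row, `levels` holds the unique values of each level, and `codes` holds, for each level, the integer position of every row's value within `levels`. The function further down relies on exactly these attributes (`codes` is the current name; pandas versions before 0.24 called it `labels`). A quick inspection, assuming the same index as above:

```python
import pandas as pd

# Same two-level index as in the example above
idx = pd.MultiIndex.from_product([['group1', 'group2'], ['item1', 'item2', 'item3']])

rows = list(idx)                   # one tuple per row, e.g. ('group1', 'item1')
innermost = [s for *_, s in idx]   # last element of each tuple: the innermost labels

print(rows[:2])         # [('group1', 'item1'), ('group1', 'item2')]
print(innermost)        # ['item1', 'item2', 'item3', 'item1', 'item2', 'item3']
print(list(idx.levels[0]), idx.codes[0].tolist())  # ['group1', 'group2'] [0, 0, 0, 1, 1, 1]
print(list(idx.levels[1]), idx.codes[1].tolist())  # ['item1', 'item2', 'item3'] [0, 1, 2, 0, 1, 2]
```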

[Figure: img1.png, the default bar plot with tuple tick labels]

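I wanted axis labels that make the hierarchical structure of the data easy to see. For now, the following function does the job: it shows only the innermost level as tick labels, and for each outer level it draws a horizontal line spanning each group with the group name centred beneath it.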

def set_hierarchical_xlabels(index, ax=None,
                             bar_xmargin=0.1,    # gap left at each end of a group line, in x data units
                             bar_yinterval=0.1,  # vertical spacing between label rows, as a fraction of the axes height
                            ):
    from itertools import groupby
    from matplotlib.lines import Line2D

    ax = ax or plt.gca()

    assert isinstance(index, pd.MultiIndex)
    # Tick labels: only the innermost level of each index tuple
    labels = ax.set_xticklabels([s for *_, s in index])
    for lb in labels:
        lb.set_rotation(0)

    # Blended transform: x in data coordinates, y in axes coordinates
    transform = ax.get_xaxis_transform()

    # One row of group labels per outer level, working outward from the innermost
    for i in range(1, len(index.codes)):
        xpos0 = -0.5  # left edge of the current group (bars are centred at x = 0, 1, 2, ...)
        for (*_, code), codes_iter in groupby(zip(*index.codes[:-i])):
            xpos1 = xpos0 + sum(1 for _ in codes_iter)  # right edge of the current group
            # Group name, centred below the group line
            ax.text((xpos0+xpos1)/2, (bar_yinterval * (-i-0.1)),
                    index.levels[-i-1][code],
                    transform=transform,
                    ha="center", va="top")
            # Horizontal line spanning the group; clip_on=False lets it sit outside the axes
            ax.add_line(Line2D([xpos0+bar_xmargin, xpos1-bar_xmargin],
                               [bar_yinterval * -i]*2,
                               transform=transform,
                               color="k", clip_on=False))
            xpos0 = xpos1
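`ax.get_xaxis_transform()` returns a blended transform: x is interpreted in data coordinates (bar k is centred at x = k, which is why the first group starts at -0.5) and y in axes coordinates (0 is the bottom of the axes, 1 the top). That is why `bar_yinterval` is a fraction of the axes height and negative y values land below the plot area. A small check of this behaviour, using the non-interactive Agg backend so it runs headless; the axis limits here are arbitrary:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no window is opened
import numpy as np
from matplotlib import pyplot as plt

fig, ax = plt.subplots()
ax.set_xlim(-0.5, 5.5)  # arbitrary x range covering six bars centred at 0..5
ax.set_ylim(0, 1)

blend = ax.get_xaxis_transform()  # x: data coords, y: axes coords

# x behaves like data coordinates...
x_blend = blend.transform((2.0, 0.5))[0]
x_data = ax.transData.transform((2.0, 0.5))[0]

# ...while y behaves like axes coordinates (0 = bottom edge of the axes)
y_bottom = blend.transform((2.0, 0.0))[1]
y_axes_bottom = ax.transAxes.transform((0.0, 0.0))[1]
y_below = blend.transform((2.0, -0.1))[1]  # one default bar_yinterval below the axes

print(np.isclose(x_blend, x_data), np.isclose(y_bottom, y_axes_bottom), y_below < y_bottom)
```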

a.plot.bar()
set_hierarchical_xlabels(a.index)

[Figure: img2.png, the same bar plot with hierarchical x-axis labels]

The details of the layout will depend on your purpose and taste, so feel free to adapt the function above as needed. It also works with a MultiIndex of three or more levels: the outer loop draws one additional row of group labels for each extra level.
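With more levels, the outer loop runs once per outer level (i = 1 to nlevels - 1), and `itertools.groupby` collapses consecutive rows sharing the same prefix of codes into one labelled span. The hypothetical helper `group_spans` below isolates that run-length step without any plotting, so you can check which spans would be drawn for a three-level index. Like `set_hierarchical_xlabels`, it assumes bars sit at x = 0, 1, 2, and so on.

```python
from itertools import groupby
import pandas as pd

def group_spans(index: pd.MultiIndex, depth: int):
    """Hypothetical helper mirroring the inner loop of set_hierarchical_xlabels.

    Returns (label, left, right) for each run of consecutive rows that share
    the same codes in the first (nlevels - depth) levels. left/right are the
    outer edges of the run in data coordinates, assuming bars at x = 0, 1, 2, ...
    """
    spans = []
    left = -0.5
    for (*_, code), run in groupby(zip(*index.codes[:-depth])):
        right = left + sum(1 for _ in run)
        spans.append((index.levels[-depth - 1][code], left, right))
        left = right
    return spans

idx3 = pd.MultiIndex.from_product([['A', 'B'], ['x', 'y'], ['1', '2', '3']])

for depth in range(1, idx3.nlevels):
    print(depth, group_spans(idx3, depth))
# 1 [('x', -0.5, 2.5), ('y', 2.5, 5.5), ('x', 5.5, 8.5), ('y', 8.5, 11.5)]
# 2 [('A', -0.5, 5.5), ('B', 5.5, 11.5)]
```

Because `groupby` only merges consecutive rows, an index whose groups are not contiguous produces repeated labels; calling `sort_index()` on the DataFrame before plotting avoids that.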
