[PYTHON] Create an Axes generator and draw an endless stream of graphs

Introduction

When working on image recognition in Jupyter, you often want to lay out a large number of images and graphs. For example, suppose you want to display every MNIST training image. (If you actually try this, it will never come back, and I suspect the kernel will crash.)

show_mnist_NG.py


import matplotlib.cm as cm
import matplotlib.pyplot as plt
import keras

mnist = keras.datasets.mnist.load_data()
(X_train, t_train), (X_test, t_test) = mnist

# 3000 rows x 20 columns = 60,000 Axes in a single figure (far too heavy)
f, axs = plt.subplots(3000, 20, figsize=(20, 3000))
axes = axs.reshape((-1,))
for x, ax in zip(X_train, axes):
    ax.imshow(x, cmap=cm.gray)
    ax.axis('equal')
    ax.axis('off')
plt.show()

However, calling matplotlib by hand every time I drew a graph or image became tedious, so I built a mechanism that displays them automatically.

matplotlib.pyplot.subplots()

Use subplots() to arrange images and graphs in a grid.

subplots_rowcol.py


import matplotlib.pyplot as plt

nrows = 3
ncols = 5
f, axes = plt.subplots(nrows, ncols, squeeze=False, figsize=(ncols*2.0, nrows*2.0))
for i in range(13):
    r = i // ncols
    c = i % ncols
    axes[r,c].plot(range(10), range(r+c,10+r+c))  # arbitrary sample graph
plt.show()

However, subplots() has the following drawbacks:

- Nothing is displayed until you call matplotlib.pyplot.show(), so with many graphs you are kept waiting.
- The number of graphs and images must be known in advance, which is inconvenient if that number changes while the program runs.
- The overall layout (number of rows and columns, and the figure size) must be decided in advance.
- You cannot display too many at once.

This makes it awkward to use, especially when you just want to take a quick look at images in Jupyter.

Why not just call plt.show() one row at a time?

My first idea was to call subplots() and show() one row at a time.

subplots_1row.py


import matplotlib.pyplot as plt

ncols = 5
f, axes = plt.subplots(1, ncols, figsize=(ncols*2.0, 2.0))
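for i in range(13):
    axes[i % ncols].plot(range(10), range(i,10+i))  # arbitrary sample graph
    if i % ncols == ncols - 1:
        plt.show()  # show the finished row, then start a new one
        f, axes = plt.subplots(1, ncols, figsize=(ncols*2.0, 2.0))
plt.show()  # show the last, partially filled row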

When displaying a large number of images, they now appear one row at a time, so the wait feels much shorter. You also don't have to plan the overall layout: you only decide how many graphs go in one row and the width and height of that row.

This is fine inside a simple loop. However, the bookkeeping around the Axes (which slot to use next, and when to show the row and create a new one) is cumbersome.

Wouldn't it be nice to have a generator?

I didn't want to write that bookkeeping over and over, so I turned it into a generator.

axes_generator.py


import matplotlib.pyplot as plt

def axes_generator(ncols):
    while True:
        f, axes = plt.subplots(1, ncols, figsize=(ncols*2.0, 2.0))
        for c in range(ncols):
            yield axes[c]
        plt.show()

ag = axes_generator(5)
for i, ax in zip(range(13), ag):
    ax.plot(range(10), range(i,10+i))  # arbitrary sample graph
plt.show()

Thanks to the generator, the loop body only has to draw the graph. The loop is clean and easy to read.
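The generator body is lazy: the code after each yield runs only when the next Axes is requested. The sketch below is not the author's code; it replaces Axes with (row, column) tuples and plt.subplots()/plt.show() with logged events, and row_generator and its event strings are hypothetical names chosen for illustration. It shows two things about the loop above: a row is "shown" only when the first Axes of the following row is requested, which is why the last row needs an explicit plt.show(); and the finite range() must come first in zip(), otherwise zip() pulls one extra item from the generator and starts an empty row.

```python
def row_generator(ncols, events):
    """Hypothetical stand-in for axes_generator: records events instead of drawing."""
    row = 0
    while True:
        events.append(f"new row {row}")   # plays the role of plt.subplots()
        for c in range(ncols):
            yield (row, c)                 # plays the role of yielding an Axes
        events.append(f"show row {row}")  # plays the role of plt.show()
        row += 1

# Finite iterable first, as in the article: exactly 6 cells are requested,
# and the second row is never "shown" by the generator itself.
good = []
for i, cell in zip(range(6), row_generator(3, good)):
    good.append(f"draw {cell}")

# Generator first: zip() asks the generator for a 7th cell before it notices
# that range(6) is exhausted, which shows row 1 and starts an empty row 2.
bad = []
for cell, i in zip(row_generator(3, bad), range(6)):
    bad.append(f"draw {cell}")

print(good[-1])   # draw (1, 2)
print(bad[-2:])   # ['show row 1', 'new row 2']
```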

If you want to add an extra graph, call __next__() on the generator (or, equivalently, next(ag)) and it returns the next Axes.

ag = axes_generator(5)
for i, ax in zip(range(13), ag):
    ax.plot(range(10), range(i,10+i))  # arbitrary sample graph
    if i % 3 == 2:
        ax = next(ag)  # same as ag.__next__(): take one extra Axes
        ax.bar(range(5), range(i,i+5))  # arbitrary sample bar chart
plt.show()
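Taking an item with next() inside the loop advances the very same generator that the for loop is consuming, so the loop simply continues from the item after the one you took. A minimal sketch, with a hypothetical number_generator standing in for axes_generator so the order of items is easy to check:

```python
def number_generator():
    """Hypothetical stand-in for axes_generator: yields 0, 1, 2, ... forever."""
    n = 0
    while True:
        yield n
        n += 1

gen = number_generator()
loop_values = []
extra_values = []
for i, value in zip(range(6), gen):
    loop_values.append(value)
    if i % 3 == 2:
        # Take one extra value from the same generator, just as the
        # loop above takes one extra Axes for the bar chart.
        extra_values.append(next(gen))

print(loop_values)   # [0, 1, 2, 4, 5, 6]
print(extra_values)  # [3, 7]
```

Applied to the loop above, 13 iterations plus 4 extra calls (at i = 2, 5, 8, 11) consume 17 Axes: three full rows of 5, and a last row with 2 Axes used and 3 left empty.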

However, the unused Axes left over in the last row (empty frames with tick labels) are an eyesore.

Wouldn't it be nice to have a generator class?

I wanted to hide the leftover Axes, so I turned the generator into a class:

- It can be used as a generator.
- The Axes created by subplots() are kept in a member variable.
- A get() method returns the next Axes, so you can add extra graphs.
- Before the final plt.show(), flush() turns off the axis of every unused Axes so that they are invisible.

AxesGenerator.py


import matplotlib.pyplot as plt

class AxesGenerator:
    def __init__(self, ncols:int=6, figsize:tuple=None, *args, **kwargs):
        self._ncols = ncols
        self._figsize = figsize
        self._axes = []

    def __iter__(self):
        while True:
            yield self.get()

    def get(self):
        if len(self._axes) == 0:
            plt.show()
            f, axes = plt.subplots(nrows=1, ncols=self._ncols, figsize=self._figsize)
            self._axes = list(axes) if self._ncols > 1 else [axes,]
        ax = self._axes.pop(0)
        return ax

    def flush(self):
        for ax in self._axes:
            ax.axis('off')
        plt.show()
        self._axes = []

ncols = 5
ag = AxesGenerator(ncols, figsize=(ncols*2.0, 2.0))
for i, ax in zip(range(13), ag):
    ax.plot(range(10), range(i,10+i))  # arbitrary sample graph
    if i % 3 == 2:
        ax = ag.get()  # take one extra Axes for an additional graph
        ax.bar(range(5), range(i,i+5))  # arbitrary sample bar chart
ag.flush()
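The class special-cases ncols == 1 because plt.subplots() returns a single Axes, not an array, when there is only one subplot. An alternative that the original does not use is squeeze=False, which makes subplots() always return a 2-D array of Axes. A minimal sketch; the helper row_of_axes is hypothetical, and the Agg backend is selected only so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

def row_of_axes(ncols, **fig_kw):
    """Hypothetical helper: return the Axes of a new 1 x ncols figure as a list.

    squeeze=False makes subplots() always return a 2-D array of shape
    (1, ncols), so no special case is needed for ncols == 1.
    """
    fig, axes = plt.subplots(nrows=1, ncols=ncols, squeeze=False, **fig_kw)
    return list(axes[0])

print(len(row_of_axes(1)), len(row_of_axes(5)))  # 1 5
```

With squeeze=False passed to subplots(), the line `self._axes = list(axes) if self._ncols > 1 else [axes,]` could be reduced to `self._axes = list(axes[0])`.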

This is getting pretty good. However, cleanup depends on calling flush(), and it is easy to call plt.show() by mistake instead, to forget it entirely, and so on.

If cleanup is a hassle, why not leave it to the with statement?

A with statement should be able to take care of the cleanup, so I added __enter__() and __exit__() to support it.
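Whenever the with block is left, whether normally or through an exception, Python calls __exit__(), so flush() is guaranteed to run. The return value of __exit__() also matters: a truthy value suppresses any exception raised inside the block, while False or None lets it propagate. A small sketch, using a hypothetical Recorder class in place of AxesGenerator:

```python
class Recorder:
    """Hypothetical minimal context manager that mimics AxesGenerator's cleanup."""

    def __init__(self, suppress):
        self.suppress = suppress
        self.flushed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.flushed = True   # cleanup runs whether or not an exception occurred
        return self.suppress  # truthy swallows the exception, falsy re-raises it

# suppress=True: the ZeroDivisionError disappears silently.
r1 = Recorder(suppress=True)
with r1:
    1 / 0
print(r1.flushed)  # True

# suppress=False: cleanup still runs, and the error reaches the caller.
r2 = Recorder(suppress=False)
try:
    with r2:
        1 / 0
except ZeroDivisionError:
    print("error propagated, flushed =", r2.flushed)  # error propagated, flushed = True
```

Returning a truthy value therefore also hides mistakes such as a typo in the plotting code, so returning False is usually the safer choice.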

I also changed the constructor to accept most of the arguments of subplots().

AxesGenerator.py


import matplotlib.pyplot as plt

class AxesGenerator:
    def __init__(self, ncols:int=6, sharey=False, subplot_kw=None, gridspec_kw=None, **fig_kw):
        self._ncols = ncols
        self._sharey = sharey
        self._subplot_kw = subplot_kw
        self._gridspec_kw = gridspec_kw
        self._fig_kw = fig_kw
        self._axes = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.flush()
        return False  # do not suppress exceptions raised inside the with block

    def __iter__(self):
        while True:
            yield self.get()

    def get(self):
        if len(self._axes) == 0:
            plt.show()
            f, axes = plt.subplots(nrows=1, ncols=self._ncols, sharey=self._sharey, subplot_kw=self._subplot_kw, gridspec_kw=self._gridspec_kw, **self._fig_kw)
            self._axes = list(axes) if self._ncols > 1 else [axes,]
        ax = self._axes.pop(0)
        return ax

    def flush(self):
        for ax in self._axes:
            ax.axis('off')
        plt.show()
        self._axes = []

ncols = 5
with AxesGenerator(ncols, figsize=(ncols*2.0, 2.0)) as ag:
    for i, ax in zip(range(13), ag):
        ax.plot(range(10), range(i,10+i))  # arbitrary sample graph
        if i % 3 == 2:
            ax = ag.get()  # take one extra Axes for an additional graph
            ax.bar(range(5), range(i,i+5))  # arbitrary sample bar chart

In conclusion

This achieves the goal of quickly viewing a large number of graphs and images. You can now even display all 60,000 MNIST training images. (If you actually try it, though, I suspect it will still crash partway through.)

show_mnist_OK.py


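import matplotlib.cm as cm
import matplotlib.pyplot as plt
import keras

# The class above, saved as AxesGenerator.py (skip if already defined in the notebook)
from AxesGenerator import AxesGenerator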

mnist = keras.datasets.mnist.load_data()
(X_train, t_train), (X_test, t_test) = mnist

with AxesGenerator(ncols=20, figsize=(20,1)) as ag:
    for x, ax in zip(X_train, ag):
        ax.imshow(x, cmap=cm.gray)
        ax.axis('equal')
        ax.axis('off')

I feel the code could be made a little smarter, but that will have to wait for another day.
