[Python] Managing overlap when drawing scatter plots with large amounts of data (Matplotlib, Pandas, Datashader)

When a scatter plot contains a very large number of points, it becomes so crowded that you can no longer tell how much data lies in any given region.

As an example, consider the following data, obtained by reducing the MNIST handwritten-digit image dataset to two dimensions with UMAP.

```python
import pandas as pd

df = pd.read_csv('./mnist_embedding.csv', index_col=0)
display(df)
```

```
               x          y  class
0       1.273394   1.008444      5
1      12.570375   0.472456      0
2      -2.197421   8.652475      4
3      -5.642218  -4.971571      1
4      -3.874749   5.150311      9
...          ...        ...    ...
69995  -0.502520  -7.309745      2
69996   3.264405  -0.887491      3
69997  -4.995078   8.153721      4
69998  -0.226225  -0.188836      5
69999   8.405535  -2.277809      6

70000 rows × 3 columns
```

x and y are the coordinates in the embedding, and class is the label (which digit, 0 to 9, the image shows).

First, draw an ordinary scatter plot with matplotlib. As an aside, the legend_elements method of the object returned by scatter (added in Matplotlib 3.1) makes it easy to build a legend for a scatter plot with several categories without writing a for loop over them.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(12, 12))

sc = ax.scatter(df['x'], df['y'], c=df['class'], cmap='Paired', s=6, alpha=1.0)

# legend_elements() returns (handles, labels) with one entry per class
ax.legend(*sc.legend_elements(), loc="upper right", title="Classes")
plt.axis('off')
plt.show()
```

[Figure: matplotlib scatter plot of all 70,000 points, colored by class]

All 70,000 points are plotted. Each digit forms its own cluster, which is nice, but with this much data the dots overlap so heavily that the clusters fill in solid and the structure within each class is almost invisible. I want to do something about this.

Solution 1: Tune the marker size and alpha by hand

To reduce overlap, shrink the points or make them more transparent so that density shows up as darker areas. This takes trial and error, and the result is not always easy to read.
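Because this approach is mostly trial and error, one small convenience is to render several settings side by side and pick one by eye. The helper compare_marker_settings below is a hypothetical name for illustration, and two synthetic Gaussian blobs stand in for the MNIST embedding; with the real data you would call compare_marker_settings(df['x'], df['y'], df['class']).

```python
import matplotlib
matplotlib.use("Agg")          # off-screen backend for a script; omit in Jupyter
import matplotlib.pyplot as plt
import numpy as np

def compare_marker_settings(x, y, c, sizes=(6, 3, 1), alphas=(1.0, 0.3, 0.1)):
    """Draw one small scatter panel per (size, alpha) pair for side-by-side comparison."""
    fig, axes = plt.subplots(len(sizes), len(alphas),
                             figsize=(3 * len(alphas), 3 * len(sizes)),
                             squeeze=False)   # always a 2-D array of Axes
    for row, s in enumerate(sizes):
        for col, alpha in enumerate(alphas):
            ax = axes[row, col]
            ax.scatter(x, y, c=c, cmap='Paired', s=s, alpha=alpha)
            ax.set_title(f's={s}, alpha={alpha}', fontsize=9)
            ax.axis('off')
    fig.tight_layout()
    return fig, axes

# Synthetic stand-in for the embedding: two overlapping Gaussian blobs
rng = np.random.default_rng(0)
n = 5000
x = np.concatenate([rng.normal(0, 1, n), rng.normal(2, 1, n)])
y = np.concatenate([rng.normal(0, 1, n), rng.normal(1, 1, n)])
c = np.repeat([0, 1], n)

fig, axes = compare_marker_settings(x, y, c)
```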

```python
fig, ax = plt.subplots(figsize=(12, 12))

sc = ax.scatter(df['x'], df['y'], c=df['class'], cmap='Paired', s=3, alpha=0.1)

ax.legend(*sc.legend_elements(), loc="upper right", title="Classes")
plt.axis('off')
plt.show()
```

[Figure: the same scatter plot with s=3 and alpha=0.1]

Solution 2: Hexagonal Binning

This is another good option. The plane is tiled with a hexagonal grid, the data points falling into each hexagon are counted, and the count is shown as color intensity. Pandas' plotting functions make it easy to use.

```python
fig, ax = plt.subplots(figsize=(12, 12))

df.plot.hexbin(x='x', y='y', gridsize=100, ax=ax)

plt.axis('off')
plt.show()
```

[Figure: hexagonal binning of the embedding with gridsize=100]
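Dense clusters can still saturate a linear color scale. A minimal sketch, calling Matplotlib's Axes.hexbin directly on synthetic data (an assumption made so the snippet runs on its own), puts the counts on a logarithmic color scale with bins='log' and leaves empty hexagons undrawn with mincnt=1:

```python
import matplotlib
matplotlib.use("Agg")          # off-screen backend for a script; omit in Jupyter
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-in for the embedding: one dense, correlated cloud
rng = np.random.default_rng(0)
x = rng.normal(size=50_000)
y = x + rng.normal(scale=0.5, size=50_000)

fig, ax = plt.subplots(figsize=(6, 6))
# bins='log' puts the color scale on a log axis; mincnt=1 skips hexagons with no points
hb = ax.hexbin(x, y, gridsize=100, bins='log', mincnt=1, cmap='viridis')
fig.colorbar(hb, ax=ax, label='count (log scale)')
ax.axis('off')
```

Since df.plot.hexbin forwards extra keyword arguments to Matplotlib's hexbin, the same options should also work there, e.g. df.plot.hexbin(x='x', y='y', gridsize=100, bins='log', mincnt=1, ax=ax).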

Solution 3: Use Datashader

It is versatile and easy to use, once you get used to it.

Datashader is a library that quickly generates "rasterized plots" for large datasets.

Drawing takes three steps: first fix the resolution (number of pixels) of the output figure, then aggregate the data falling into each pixel, and finally output the aggregate as an image. Each step can be tuned in detail, so the approach is very flexible.
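To make the three steps concrete, here is a conceptual sketch in plain NumPy. It is not Datashader's implementation (Datashader compiles its aggregations with Numba, supports Dask, and shades with histogram equalization by default rather than the log scaling used here); rasterize_counts and shade_log are hypothetical names.

```python
import numpy as np

def rasterize_counts(x, y, width, height, x_range, y_range):
    """Steps 1 and 2: fix the pixel grid, then count the points in each pixel."""
    counts, _, _ = np.histogram2d(
        x, y, bins=(width, height), range=(x_range, y_range))
    # histogram2d returns (x bins, y bins); transpose so rows correspond to y
    return counts.T

def shade_log(counts):
    """Step 3 (simplified): map counts to 0-255 grey levels on a log scale."""
    scaled = np.log1p(counts)
    if scaled.max() > 0:
        scaled = scaled / scaled.max()
    return (255 * scaled).astype(np.uint8)

rng = np.random.default_rng(0)
x = rng.normal(size=100_000)
y = rng.normal(size=100_000)

counts = rasterize_counts(x, y, width=300, height=200,
                          x_range=(-5, 5), y_range=(-5, 5))
image = shade_log(counts)
print(counts.shape, image.dtype)
```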

Each step is described below; with every setting left at its default, the whole pipeline is a single line:

```python
import datashader as ds
from datashader import transfer_functions as tf

tf.shade(ds.Canvas().points(df, 'x', 'y'))
```

[Figure: Datashader rendering with all settings at their defaults]

Settings for each step

Datashader makes a plot in three steps:

  1. Set up the canvas

  2. Configure and compute the aggregation

  3. Convert to an image

Each is explained below.

1. Set up the canvas

Configure the canvas with datashader.Canvas: the horizontal and vertical resolution in pixels, whether each axis is linear or logarithmic, the range of data values to cover (the equivalent of xlim and ylim in matplotlib), and so on.

```python
canvas = ds.Canvas(plot_width=600, plot_height=600,  # 600 x 600 pixels
                   x_axis_type='linear', y_axis_type='linear',  # 'linear' or 'log'
                   x_range=(-10, 15), y_range=(-15, 10))  # like xlim / ylim
```

2. Configure and compute the aggregation

The canvas above is 600 × 600 pixels. Next, specify how the data in each of these pixels is aggregated: for example, color intensity can follow the number of data points that fall into a pixel, or each pixel can be reduced to a binary value recording whether it contains at least one point.

To run the computation, pass the data frame, the column names for the x and y coordinates, and the aggregation function to the points method of the canvas defined above. datashader.reductions.count (used here as ds.count) counts the data points that fall into each pixel.

```python
canvas.points(df, 'x', 'y', agg=ds.count())
```

```
<xarray.DataArray (y: 600, x: 600)>
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int32)
Coordinates:
  * x        (x) float64 -9.979 -9.938 -9.896 -9.854 ... 14.85 14.9 14.94 14.98
  * y        (y) float64 -14.98 -14.94 -14.9 -14.85 ... 9.854 9.896 9.938 9.979
```

This produces the drawing data: a 600 × 600 array (an xarray.DataArray) holding the number of data points in each pixel.

To aggregate by a binary value recording whether any data point falls into each pixel, instead of counting, use datashader.reductions.any as follows.

```python
canvas.points(df, 'x', 'y', agg=ds.any())
```

```
<xarray.DataArray (y: 600, x: 600)>
array([[False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       ...,
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False]])
Coordinates:
  * x        (x) float64 -9.979 -9.938 -9.896 -9.854 ... 14.85 14.9 14.94 14.98
  * y        (y) float64 -14.98 -14.94 -14.9 -14.85 ... 9.854 9.896 9.938 9.979
```
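Conceptually, both aggregations come down to mapping each point to a pixel index and then reducing per pixel. The sketch below is plain NumPy under stated assumptions (each range is split into equal-width pixels and points outside the range are dropped); Datashader's own edge handling may differ in details. count_per_pixel is a hypothetical helper, and the boolean any-style result is simply count > 0.

```python
import numpy as np

def count_per_pixel(x, y, width, height, x_range, y_range):
    """Count points per pixel; rows index y and columns index x, like the DataArray above."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Keep only points inside the canvas range
    inside = ((x >= x_range[0]) & (x < x_range[1]) &
              (y >= y_range[0]) & (y < y_range[1]))
    x, y = x[inside], y[inside]
    # Map each coordinate to an integer pixel index
    col = ((x - x_range[0]) / (x_range[1] - x_range[0]) * width).astype(int)
    row = ((y - y_range[0]) / (y_range[1] - y_range[0]) * height).astype(int)
    # Guard against float rounding pushing an index to width or height
    col = np.minimum(col, width - 1)
    row = np.minimum(row, height - 1)
    counts = np.zeros((height, width), dtype=np.int64)
    np.add.at(counts, (row, col), 1)   # unbuffered: repeated indices all count
    return counts

# Same canvas geometry as above; the point at x=20.0 lies outside x_range
counts = count_per_pixel([0.1, 0.12, 5.0, 20.0], [0.3, 0.31, -3.0, 0.0],
                         width=600, height=600,
                         x_range=(-10, 15), y_range=(-15, 10))
occupied = counts > 0   # the "any" aggregate
```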

3. Convert to an image

To convert the aggregate to an image, use the shade function of datashader.transfer_functions, passing it the aggregated array computed above. datashader.transfer_functions also provides many other functions for fine-tuning the output image; here, set_background puts the shaded count aggregate on a white background.

```python
tf.set_background(tf.shade(canvas.points(df, 'x', 'y', agg=ds.count())), 'white')
```

[Figure: Datashader count aggregate on a white background]

The shading follows the density of the data points, which makes the structure much easier to see.
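Part of why this works is the color mapping: by default tf.shade uses how='eq_hist' (histogram equalization), with 'log', 'cbrt' and 'linear' among the alternatives. The idea behind equalization can be sketched in NumPy: each non-empty pixel is shaded by its rank among all non-empty pixels rather than by its raw count, so a few extremely dense pixels cannot wash out everything else. eq_hist_shade is a hypothetical, simplified function, not Datashader's exact algorithm.

```python
import numpy as np

def eq_hist_shade(counts):
    """Map each non-empty pixel to the fraction of non-empty pixels whose count is <= its own."""
    counts = np.asarray(counts, dtype=float)
    shade = np.zeros_like(counts)          # empty pixels stay at 0 (background)
    mask = counts > 0
    values = counts[mask]
    if values.size:
        ranks = np.searchsorted(np.sort(values), values, side='right')
        shade[mask] = ranks / values.size
    return shade

counts = np.array([[0, 1, 1],
                   [2, 100, 0]])
eq = eq_hist_shade(counts)
linear = counts / counts.max()   # for comparison: plain linear scaling
```

With linear scaling the pixel holding 2 points gets 0.02 and is nearly invisible next to the pixel holding 100; after equalization it gets 0.75.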

In the same way, here is the result of aggregating by whether each pixel contains any data point.

```python
tf.set_background(tf.shade(canvas.points(df, 'x', 'y', agg=ds.any())), 'white')
```

[Figure: Datashader any aggregate on a white background]

Aggregating with auxiliary data

So far the data has been aggregated using only the coordinates of each point, but each data point often also carries a category label or a continuous value.

Simply counting the points in each pixel ignores that information, so Datashader provides dedicated aggregation functions for each case.

Aggregation when the auxiliary data is categorical

For MNIST each point has its true class label, so I want to color the plot by class. The aggregation function for this is datashader.reductions.count_cat, which counts the data points in each pixel separately for each label. For MNIST this produces ten 600 × 600 aggregate arrays, one per class.
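In array terms, count_cat adds a category axis to the per-pixel count. A minimal NumPy sketch, assuming each point's pixel row and column have already been computed; count_per_category is a hypothetical helper, and the tiny input is made up for illustration.

```python
import numpy as np

def count_per_category(rows, cols, labels, height, width, n_classes):
    """Per-pixel, per-label counts with shape (height, width, n_classes)."""
    counts = np.zeros((height, width, n_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same index appears many times
    np.add.at(counts, (rows, cols, labels), 1)
    return counts

# Four points: two of class 1 and one of class 4 share pixel (0, 0); one of class 9 is at (2, 1)
rows = np.array([0, 0, 0, 2])
cols = np.array([0, 0, 0, 1])
labels = np.array([1, 1, 4, 9])
cat_counts = count_per_category(rows, cols, labels, height=3, width=3, n_classes=10)

# Summing over the category axis recovers the plain per-pixel count
total = cat_counts.sum(axis=-1)
```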

count_cat requires the label column to have the pandas category dtype (an int column will not work), so first convert the class column of the data frame to category.

```python
df['class'] = df['class'].astype('category')
```

Now aggregate with count_cat. Unlike the count and any aggregation functions, you must specify which column of the data frame holds the labels.

```python
agg = canvas.points(df, 'x', 'y', ds.count_cat('class'))
```

The color for each label is defined in a dictionary keyed by label. To match the colors of the figure drawn at the beginning, take the "Paired" colors from matplotlib; a dictionary comprehension makes this easy.

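```python
import matplotlib

# Resample the 12-color "Paired" map to 10 colors, as scatter(cmap='Paired') does for 10 classes.
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; colormaps[...].resampled needs 3.6+.
paired10 = matplotlib.colormaps['Paired'].resampled(10)
color_key = {i: matplotlib.colors.rgb2hex(c[:3]) for i, c in enumerate(paired10.colors)}
print(color_key)
```

```
{0: '#a6cee3', 1: '#1f78b4', 2: '#b2df8a', 3: '#fb9a99', 4: '#e31a1c', 5: '#fdbf6f', 6: '#cab2d6', 7: '#6a3d9a', 8: '#ffff99', 9: '#b15928'}
```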

Now render it as an image. Each pixel's color appears to be a blend of the class colors, weighted by how many points of each label fall into that pixel.

```python
tf.set_background(tf.shade(agg, color_key=color_key), 'white')
```

[Figure: count_cat aggregate shaded with the Paired class colors]
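The blending described above can be sketched as a count-weighted average of the class colors. This is a simplification: it ignores opacity, whereas Datashader also makes pixels with fewer points more transparent. mix_colors is a hypothetical helper, with RGB colors given as floats in [0, 1] (hex strings such as those in color_key can be converted with matplotlib.colors.to_rgb).

```python
import numpy as np

def mix_colors(cat_counts, colors, background=(1.0, 1.0, 1.0)):
    """Count-weighted average of category colors per pixel (opacity ignored)."""
    cat_counts = np.asarray(cat_counts, dtype=float)   # (H, W, K)
    colors = np.asarray(colors, dtype=float)           # (K, 3), RGB in [0, 1]
    totals = cat_counts.sum(axis=-1, keepdims=True)    # (H, W, 1)
    weighted = cat_counts @ colors                     # (H, W, 3)
    rgb = np.empty_like(weighted)
    rgb[...] = background                              # empty pixels show the background
    np.divide(weighted, totals, out=rgb, where=totals > 0)
    return rgb

# Two pixels, two classes: pixel 0 has 3 red points and 1 blue point; pixel 1 is empty
cat_counts = np.array([[[3, 1], [0, 0]]])
colors = np.array([[1.0, 0.0, 0.0],    # class 0: red
                   [0.0, 0.0, 1.0]])   # class 1: blue
rgb = mix_colors(cat_counts, colors)
```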

Aggregation when the auxiliary data is continuous

Each data point may also carry some continuous value. In single-cell analysis, for example, a dimensionality-reduced plot of tens of thousands of cells is often colored by the expression level of some gene in each cell.

Because one pixel can contain several data points, a single representative value has to be chosen somehow. For this, Datashader provides aggregation functions for simple statistics such as max, min, and mean.
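What max and mean compute per pixel can be sketched in NumPy with unbuffered ufunc accumulation. This is an illustration assuming the pixel indices are already known, not Datashader's code; max_and_mean_per_pixel is a hypothetical helper, and empty pixels are left as NaN so they can be drawn as background.

```python
import numpy as np

def max_and_mean_per_pixel(rows, cols, values, height, width):
    """Per-pixel max and mean of `values`; pixels with no points are NaN."""
    values = np.asarray(values, dtype=float)
    vmax = np.full((height, width), np.nan)
    np.fmax.at(vmax, (rows, cols), values)     # fmax ignores the initial NaN
    sums = np.zeros((height, width))
    counts = np.zeros((height, width))
    np.add.at(sums, (rows, cols), values)
    np.add.at(counts, (rows, cols), 1)
    vmean = np.full((height, width), np.nan)
    np.divide(sums, counts, out=vmean, where=counts > 0)
    return vmax, vmean

# Two points share pixel (0, 0); one point sits alone in pixel (1, 1)
vmax, vmean = max_and_mean_per_pixel(rows=[0, 0, 1], cols=[0, 0, 1],
                                     values=[10, 200, 50], height=2, width=2)
```

Under max, a single bright point lights up its whole pixel; under mean, it is averaged with the other points in that pixel.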

MNIST has no continuous auxiliary data, so let's make some. As an easy-to-interpret quantity, compute the average brightness of the central area of each image. Images of 0 should be dark there (their stroke rarely passes through the middle of the image), and images of 1 should be bright.

```python
data = pd.read_csv('./mnist.csv').values[:, :784]
data.shape
```

```
(70000, 784)
```

```python
# Each row is a flattened 28 x 28 image; flat index = 28 * row + column.
# These four pixels are rows 13-14, columns 14-15.
upper_left = 28 * 13 + 14
upper_right = 28 * 13 + 15
bottom_left = 28 * 14 + 14
bottom_right = 28 * 14 + 15

average_center_area = data[:, [upper_left, upper_right,
                               bottom_left, bottom_right]].mean(axis=1)
```
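The same 2 × 2 patch can be selected by reshaping each row back into a 28 × 28 image and slicing, which makes the row and column choice easier to read and to change. center_brightness is a hypothetical helper; random pixel values stand in for the MNIST array in the check.

```python
import numpy as np

def center_brightness(flat_images, rows=slice(13, 15), cols=slice(14, 16)):
    """Mean pixel value of a rectangular patch in each flattened 28 x 28 image."""
    images = np.asarray(flat_images).reshape(-1, 28, 28)   # (n, row, column)
    return images[:, rows, cols].mean(axis=(1, 2))

# Check against the flat-index version on random stand-in data
rng = np.random.default_rng(0)
fake = rng.integers(0, 256, size=(5, 784))
flat_idx = [28 * 13 + 14, 28 * 13 + 15, 28 * 14 + 14, 28 * 14 + 15]
print(np.allclose(center_brightness(fake), fake[:, flat_idx].mean(axis=1)))
```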

First, draw it with matplotlib as usual.

```python
fig, ax = plt.subplots(figsize=(12, 12))

sc = ax.scatter(df['x'], df['y'], c=average_center_area, cmap='viridis',
                vmin=0, vmax=255, s=6, alpha=1.0)

plt.colorbar(sc)
plt.axis('off')
plt.show()
```

[Figure: matplotlib scatter plot colored by center-area brightness]

As expected, the points are crushed together and the pattern is hard to make out.

Now pass the data to Datashader and color each pixel by the maximum value among the data points it contains. This aggregation uses datashader.reductions.max.

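```python
df['value'] = average_center_area
agg = canvas.points(df, 'x', 'y', agg=ds.max('value'))
# matplotlib.colormaps replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9
tf.set_background(tf.shade(agg, cmap=matplotlib.colormaps['viridis']), 'white')
```

[Figure: Datashader rendering of the per-pixel maximum center brightness]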

This is easier to read. The result may not differ much from shrinking the markers in a matplotlib scatter plot, but it is convenient to get a clean picture without fine-grained trial and error.

Datashader is also fast even on very large datasets, so it is painless to try variations, such as aggregating by the mean instead:

```python
agg = canvas.points(df, 'x', 'y', agg=ds.mean('value'))
tf.set_background(tf.shade(agg, cmap=matplotlib.colormaps['viridis']), 'white')
```

[Figure: Datashader rendering of the per-pixel mean center brightness]
