[PYTHON] 3D plot with matplotlib

A brief summary of 3D plotting with matplotlib. As an example, let's plot the probability density function of the bivariate normal distribution in 3D.

See the official tutorial for details.

Setup

Import the required modules, and set the number of dimensions and the parameters (mean vector and covariance matrix) of the normal distribution.

import matplotlib
print(matplotlib.__version__)
# 1.5.1

import numpy as np
from scipy.stats import multivariate_normal

# for plotting
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection

m = 2  # dimension
mean = np.zeros(m)  # mean vector (the origin)
sigma = np.eye(m)  # covariance matrix (identity)
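With mean = 0 and sigma = I, the bivariate normal density reduces to p(x1, x2) = exp(-(x1^2 + x2^2) / 2) / (2π), so its peak at the origin is 1/(2π). A minimal check that scipy's multivariate_normal.pdf agrees with this closed form at a few points; the helper name std_bivariate_pdf is hypothetical and only used for illustration.

```python
import numpy as np
from scipy.stats import multivariate_normal

m = 2  # dimension
mean = np.zeros(m)
sigma = np.eye(m)

def std_bivariate_pdf(x):
    """Hypothetical helper: closed-form density of N(0, I) in 2D.

    x has shape (..., 2); the last axis holds (x1, x2).
    """
    sq_norm = np.sum(x ** 2, axis=-1)
    return np.exp(-0.5 * sq_norm) / (2 * np.pi)

points = np.array([[0.0, 0.0], [1.0, -1.0], [2.5, 0.5]])
manual = std_bivariate_pdf(points)
scipy_pdf = multivariate_normal.pdf(points, mean=mean, cov=sigma)

print(np.allclose(manual, scipy_pdf))  # True
print(manual[0])  # peak value 1/(2*pi), about 0.159
```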

Added 2018/4/17:

There seem to be no major changes in the latest stable version (2.2.2). See "The mplot3d Toolkit" in the official documentation for details.

Various plots

Surface Plot

Let's try a surface plot. Note that the data passed to plot_surface are two-dimensional arrays.

N = 1000
x1 = np.linspace(-5, 5, N)
x2 = np.linspace(-5, 5, N)

X1, X2 = np.meshgrid(x1, x2)  # two (N, N) coordinate grids
X = np.c_[np.ravel(X1), np.ravel(X2)]  # (N*N, 2) array of points

Y_plot = multivariate_normal.pdf(x=X, mean=mean, cov=sigma)
Y_plot = Y_plot.reshape(X1.shape)  # back to the (N, N) grid shape

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
surf = ax.plot_surface(X1, X2, Y_plot, cmap='bwr', linewidth=0)
fig.colorbar(surf)
ax.set_title("Surface Plot")
plt.show()

# X1.shape : (1000, 1000)
# X2.shape : (1000, 1000)
# Y_plot.shape : (1000, 1000)

(Figure: surface plot)
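The ravel/reshape round trip can also be skipped: scipy's multivariate_normal.pdf treats the last axis of its input as the vector components, so stacking the two grids with np.dstack gives an (N, N, 2) array and the density comes back directly in (N, N) grid shape. A minimal sketch comparing both approaches; N = 200 is smaller than in the article only to keep it quick.

```python
import numpy as np
from scipy.stats import multivariate_normal

mean = np.zeros(2)
sigma = np.eye(2)

N = 200  # smaller grid than in the article, just for speed
x1 = np.linspace(-5, 5, N)
x2 = np.linspace(-5, 5, N)
X1, X2 = np.meshgrid(x1, x2)

# Approach from the article: flatten to (N*N, 2), evaluate, reshape back.
X = np.c_[np.ravel(X1), np.ravel(X2)]
Y_flat = multivariate_normal.pdf(x=X, mean=mean, cov=sigma).reshape(X1.shape)

# Alternative: stack into shape (N, N, 2); the last axis holds (x1, x2).
pos = np.dstack((X1, X2))
Y_grid = multivariate_normal.pdf(pos, mean=mean, cov=sigma)

print(pos.shape, Y_grid.shape)  # (200, 200, 2) (200, 200)
print(np.allclose(Y_flat, Y_grid))  # True
```

Either array can be passed to plot_surface as the height.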

Contour Plot

A contour plot can be drawn in the same way as the surface plot.

N = 1000
x1 = np.linspace(-5, 5, N)
x2 = np.linspace(-5, 5, N)

X1, X2 = np.meshgrid(x1, x2)
X = np.c_[np.ravel(X1), np.ravel(X2)]
Y_plot = multivariate_normal.pdf(x=X, mean=mean, cov=sigma)
Y_plot = Y_plot.reshape(X1.shape)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.contour(X1, X2, Y_plot)  # same 2D arrays as for plot_surface
ax.set_title("Contour Plot")
plt.show()

# X1.shape : (1000, 1000)
# X2.shape : (1000, 1000)
# Y_plot.shape : (1000, 1000)

(Figure: contour plot)
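The two plots can also be combined: the 3D ax.contour accepts zdir and offset, which draw the contour lines on a plane at a fixed height, for example on the floor under the surface. The sketch also passes rcount/ccount to plot_surface; in matplotlib 2.0 and later these cap the number of samples drawn per direction (default 50), so a large grid is downsampled anyway. The floor height -0.05 is an arbitrary choice, and the Agg backend plus savefig to an in-memory buffer are only there so the sketch runs without a display.

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (needed on old matplotlib)
from scipy.stats import multivariate_normal

mean = np.zeros(2)
sigma = np.eye(2)

N = 200
x1 = np.linspace(-5, 5, N)
x2 = np.linspace(-5, 5, N)
X1, X2 = np.meshgrid(x1, x2)
Y_plot = multivariate_normal.pdf(np.dstack((X1, X2)), mean=mean, cov=sigma)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# rcount/ccount: maximum number of samples drawn per direction (matplotlib >= 2.0)
surf = ax.plot_surface(X1, X2, Y_plot, rcount=100, ccount=100,
                       cmap='bwr', linewidth=0)
# Draw the contour lines on the plane z = floor, below the surface
floor = -0.05
cset = ax.contour(X1, X2, Y_plot, zdir='z', offset=floor, cmap='bwr')
ax.set_zlim(floor, Y_plot.max())
ax.set_title("Surface Plot + Projected Contours")

buf = io.BytesIO()
fig.savefig(buf, format='png')  # use plt.show() instead for an interactive window
```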

Scatter Plot

Unlike the two plots above, the data passed to scatter3D are one-dimensional arrays.

N = 100
x1 = np.linspace(-5, 5, N)
x2 = np.linspace(-5, 5, N)

X1, X2 = np.meshgrid(x1, x2)
X_plot = np.c_[np.ravel(X1), np.ravel(X2)]

y = multivariate_normal.pdf(X_plot, mean=mean, cov=sigma)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter3D(np.ravel(X1), np.ravel(X2), y)
ax.set_title("Scatter Plot")
plt.show()

# np.ravel(X1).shape : (10000,)
# np.ravel(X2).shape : (10000,)
# y.shape : (10000,)

(Figure: scatter plot)

A scatter plot is not meant for showing the shape of a function like this, so it is unavoidably hard to read.
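Where a scatter plot does fit is scattered data, such as random samples from the distribution. A minimal sketch, assuming 500 samples drawn with multivariate_normal.rvs (sample count and seed are arbitrary): each sample is plotted at the height of its density, and the c and cmap arguments color the points by that height so the shape is easier to see. As before, Agg and savefig only keep the sketch runnable without a display.

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (needed on old matplotlib)
from scipy.stats import multivariate_normal

mean = np.zeros(2)
sigma = np.eye(2)

# Draw 500 samples, shape (500, 2); random_state fixes the seed.
samples = multivariate_normal.rvs(mean=mean, cov=sigma, size=500, random_state=0)
heights = multivariate_normal.pdf(samples, mean=mean, cov=sigma)  # shape (500,)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# One-dimensional arrays again; color each point by its density value
sc = ax.scatter3D(samples[:, 0], samples[:, 1], heights,
                  c=heights, cmap='viridis', s=5)
fig.colorbar(sc)
ax.set_title("Scatter Plot of Samples")

buf = io.BytesIO()
fig.savefig(buf, format='png')  # use plt.show() instead for an interactive window
```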

About Axes3D

Other articles often create the 3D ax object like this:

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()
ax = Axes3D(fig)  # since matplotlib 3.4, Axes3D adding itself to fig this way is deprecated
#<class 'mpl_toolkits.mplot3d.axes3d.Axes3D'>

In recent versions, the tutorial recommends the following instead.

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
#<class 'matplotlib.axes._subplots.Axes3DSubplot'>

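Note that Axes3D is never used explicitly here, but without the import you get KeyError: '3d' (on matplotlib 1.5.1). Since matplotlib 3.2, the '3d' projection is registered automatically, so the import is no longer required.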

You can also create an ax object of the same class as follows.

fig = plt.figure()
ax = fig.gca(projection='3d')  # keyword arguments to gca() were deprecated in matplotlib 3.4 and later removed
#<class 'matplotlib.axes._subplots.Axes3DSubplot'>
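One more way to get a 3D ax object is plt.subplots with subplot_kw, which forwards the projection argument to add_subplot. A minimal sketch; the check on ax.name relies on Axes3D being registered under the projection name '3d', and Agg is only selected so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# subplot_kw is passed on to add_subplot, so this is equivalent to
# fig = plt.figure(); ax = fig.add_subplot(111, projection='3d')
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})

print(isinstance(ax, Axes3D))  # True
print(ax.name)  # 3d
```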

Other

plotly also seems to produce nice 3D plots in Python, so I'd like to look into it soon.
