Vous souhaiterez peut-être ajouter des informations textuelles autour de la figure. Vous pouvez ajouter du texte avec `plt.text`, mais c'est fastidieux, car vous devez spécifier l'emplacement de chaque texte un par un. Le code suivant ajoute plutôt ce texte à l'aide de `legend`, sans avoir à spécifier son emplacement.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
# Jupyter/IPython magic: render figures inline in the notebook output.
# NOTE: this line is not valid syntax in a plain Python script.
%matplotlib inline
# Default font size for all text in the figures drawn below.
plt.rcParams['font.size'] = 15
def r2(y1, y2):
    """Return the squared Pearson correlation between two sequences, as a string.

    The value is r**2, where r is the Pearson correlation coefficient of
    ``y1`` and ``y2``. It is rounded to 3 decimals and converted with ``str``
    so it can be concatenated directly into a plot label.

    Args:
        y1: First sequence of numbers.
        y2: Second sequence of numbers, same length as ``y1``.

    Returns:
        The rounded r**2 value formatted with ``str``, e.g. ``'0.64'``.
    """
    # np.corrcoef returns the 2x2 correlation matrix; the off-diagonal entry
    # is Pearson's r. Callers label this value "R^2", so it must be squared
    # (returning r itself would show e.g. -1.0 for perfectly anticorrelated data).
    r = np.corrcoef(y1, y2)[0, 1]
    return str(np.round(r ** 2, 3))
def _draw_parity_panel(ax, x_train, y_train, x_test, y_test, title, show_ylabel):
    """Draw a measured-vs-predicted scatter panel on ``ax``.

    Train points are black, test points red, and a thin gray y = x reference
    line spans the fixed 0-10 range used on both axes.

    Args:
        ax: Matplotlib Axes to draw on.
        x_train, y_train: Training points (measured, predicted).
        x_test, y_test: Test points (measured, predicted).
        title: Panel title.
        show_ylabel: If True, label the y axis; otherwise hide the y ticks
            and tick labels (used for the right-hand panel).
    """
    ticks = [0, 2, 4, 6, 8, 10]
    ax.scatter(x_train, y_train, color='k', label='train')
    ax.scatter(x_test, y_test, color='r', label='test')
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.plot([0, 10], [0, 10], color='gray', lw=0.5)
    ax.grid()
    ax.set_title(title)
    ax.set_xlabel('measured')
    if show_ylabel:
        ax.set_ylabel('predicted')
    else:
        # Both panels use the same 0-10 y range, so the right one drops
        # its duplicate y tick marks and labels.
        ax.tick_params(left=False, labelleft=False)


def _draw_r2_panel(ax, x_train, y_train, x_test, y_test):
    """Show the train/test R^2 values as text in a bare, axis-less panel.

    The text is attached as the label of an invisible (alpha=0) dummy
    scatter point and rendered through ``legend``. This lets matplotlib
    position it without an explicit text location.
    """
    ax.tick_params(left=False, labelleft=False, bottom=False, labelbottom=False)
    for side in ['top', 'bottom', 'left', 'right']:
        ax.spines[side].set_visible(False)
    label = ('$R^2_{train}=$' + r2(x_train, y_train)
             + '\n$R^2_{test}=$' + r2(x_test, y_test))
    ax.scatter(0, 0, label=label, alpha=0)
    ax.legend(frameon=False, loc='upper left')


# Two synthetic datasets (a and b): predictions are the measured values plus
# standard-normal noise.
xa_train = [1, 3, 5, 7]
xa_test = [2, 4, 6, 9]
ya_train = xa_train + np.random.randn(4)
ya_test = xa_test + np.random.randn(4)
xb_train = [1, 3, 5, 7]
xb_test = [2, 4, 6, 9]
yb_train = xb_train + np.random.randn(4)
yb_test = xb_test + np.random.randn(4)

fig = plt.figure()
fig.subplots_adjust(wspace=0.2, hspace=0.4)
# Top row: two tall scatter panels; bottom row: two short text-only panels.
gs = gridspec.GridSpec(2, 2, width_ratios=[1, 1], height_ratios=[4, 1])

ax_a = fig.add_subplot(gs[0])
_draw_parity_panel(ax_a, xa_train, ya_train, xa_test, ya_test,
                   'train', show_ylabel=True)

ax_b = fig.add_subplot(gs[1])
_draw_parity_panel(ax_b, xb_train, yb_train, xb_test, yb_test,
                   'test', show_ylabel=False)
# One legend for both panels, anchored just outside the right-hand panel.
ax_b.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0)

_draw_r2_panel(fig.add_subplot(gs[2]), xa_train, ya_train, xa_test, ya_test)
_draw_r2_panel(fig.add_subplot(gs[3]), xb_train, yb_train, xb_test, yb_test)

plt.show()
Recommended Posts