Je n'étais pas sûr de la divergence KL qui apparaît dans l'algorithme EM, donc j'aimerais obtenir une image en trouvant la divergence KL entre les distributions normales.
La divergence Kullback-Leibler (divergence KL, information KL) est une mesure de la similitude des deux distributions de probabilité. La définition est la suivante.
KL(p||q) = \int_{-\infty}^{\infty}p(x)\ln \frac{p(x)}{q(x)}dx
Il y a deux caractéristiques importantes. Le premier est qu'il sera égal à 0 pour la même distribution de probabilité.
KL(p||p) = \int_{-\infty}^{\infty}p(x)\ln \frac{p(x)}{p(x)}dx
= \int_{-\infty}^{\infty}p(x)\ln(1)dx
= 0
La seconde est que ce sera toujours une valeur positive comprenant 0, et plus les distributions de probabilité sont dissemblables, plus la valeur est élevée. Regardons ces caractéristiques en utilisant un exemple de distribution normale.
Les fonctions de densité de probabilité p (x) et q (x) de la distribution normale sont définies comme suit.
p(x) = N(\mu_1,\sigma_1^2) = \frac{1}{\sqrt{2\pi\sigma_1^2}} \exp\left(-\frac{(x-\mu_1)^2}{2\sigma_1^2}\right) \\
q(x) = N(\mu_2,\sigma_2^2) = \frac{1}{\sqrt{2\pi\sigma_2^2}} \exp\left(-\frac{(x-\mu_2)^2}{2\sigma_2^2}\right)
Trouvez la divergence KL entre les deux distributions normales ci-dessus. Le calcul est omis.
\begin{eqnarray}
KL(p||q)&=& \int_{-\infty}^{\infty}p(x)\ln \frac{p(x)}{q(x)}dx \\
&=& \cdots \\
&=& \ln\left(\frac{\sigma_2}{\sigma_1}\right) + \frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2}
\end{eqnarray}
Puisqu'il est difficile de comprendre s'il y a quatre variables, soit $ p (x) $ la distribution normale standard $ N (0,1) $ avec une moyenne de 0 et une variance de 1.
p(x) =N(0,1)= \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{x^2}{2}\right)
Tout d'abord, définissez l'écart type de $ q (x) $ sur $ \ sigma_2 $ et ne définissez que la moyenne $ \ mu_2 $ comme variables.
q(x) =N(\mu_2,1)= \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{(x-\mu_2)^2}{2}\right)
La divergence KL à ce moment est
\begin{eqnarray}
KL(p||q) &=& \ln\left(\frac{\sigma_2}{\sigma_1}\right) + \frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2} \\
&=& \ln\left(\frac{1}{1}\right) + \frac{1^2+(\mu_1-0)^2}{2*1^2} - \frac{1}{2} \\
&=& \frac{\mu_2^2}{2}
\end{eqnarray}
Ce sera.
La ligne orange à gauche est $ q (x) $ lorsque la moyenne $ \ mu_2 $ est modifiée. La figure de droite est celle où la moyenne $ \ mu_2 $ est prise sur l'axe des x. La ligne bleue est la solution analytique et le point orange est la valeur de divergence KL actuelle. Il a été confirmé que la divergence KL devient 0 lorsque $ p (x) $ et $ q (x) $ correspondent exactement, et augmente à mesure que la distance augmente.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
#distribution normale
def gaussian1d(x,μ,σ):
y = 1 / ( np.sqrt(2*np.pi* σ**2 ) ) * np.exp( - ( x - μ )**2 / ( 2 * σ ** 2 ) )
return y
#Distribution normale divergence KL
def gaussian1d_KLdivergence(μ1,σ1,μ2,σ2):
A = np.log(σ2/σ1)
B = ( σ1**2 + (μ1 - μ2)**2 ) / (2*σ2**2)
C = -1/2
y = A + B + C
return y
# KL divergence
def KLdivergence(p,q,dx):
KL=np.sum(p * np.log(p/q)) * dx
return KL
#Encoche x
dx = 0.01
#Gamme de x
xlm = [-6,6]
#coordonnée x
x = np.arange(xlm[0],xlm[1]+dx,dx)
#Nombre de x
x_n = len(x)
# Case 1
# p(x) = N(0,1)
# q(x) = N(μ,1)
# p(x)Moyenne μ1
μ1 = 0
# p(x)Écart type σ1
σ1 = 1
# p(x)
px = gaussian1d(x,μ1,σ1)
# q(x)Écart type σ2
σ2 = 1
# q(x)Moyenne μ2
U2 = np.arange(-4,5,1)
U2_n = len(U2)
# q(x)
Qx = np.zeros([x_n,U2_n])
#Divergence KL
KL_U2 = np.zeros(U2_n)
for i,μ2 in enumerate(U2):
qx = gaussian1d(x,μ2,σ2)
Qx[:,i] = qx
KL_U2[i] = KLdivergence(px,qx,dx)
#Portée de la solution analytique
U2_exc = np.arange(-4,4.1,0.1)
#Solution analytique
KL_U2_exc = gaussian1d_KLdivergence(μ1,σ1,U2_exc,σ2)
#Solution analytique 2
KL_U2_exc2 = U2_exc**2 / 2
#
# plot
#
# figure
fig = plt.figure(figsize=(8,4))
#Couleur par défaut
clr=plt.rcParams['axes.prop_cycle'].by_key()['color']
# axis 1
#-----------------------
#Diagramme de distribution normale
ax = plt.subplot(1,2,1)
# p(x)
plt.plot(x,px,label='$p(x)$')
# q(x)
line,=plt.plot(x,Qx[:,i],color=clr[1],label='$q(x)$')
#Guide d'utilisation
plt.legend(loc=1,prop={'size': 13})
plt.xticks(np.arange(xlm[0],xlm[1]+1,2))
plt.xlabel('$x$')
# axis 2
#-----------------------
#Divergence KL
ax2 = plt.subplot(1,2,2)
#Solution analytique
plt.plot(U2_exc,KL_U2_exc,label='Analytical')
#Calcul
point, = ax2.plot([],'o',label='Numerical')
#Guide d'utilisation
# plt.legend(loc=1,prop={'size': 15})
plt.xlim([U2[0],U2[-1]])
plt.xlabel('$\mu$')
plt.ylabel('$KL(p||q)$')
plt.tight_layout()
#Paramètres communs pour les axes
for a in [ax,ax2]:
plt.axes(a)
plt.grid()
#Dans un carré
plt.gca().set_aspect(1/plt.gca().get_data_ratio())
#mise à jour
def update(i):
#ligne
line.set_data(x,Qx[:,i])
#point
point.set_data(U2[i],KL_U2[i])
#Titre
ax.set_title("$\mu_2=%.1f$" % U2[i],fontsize=15)
ax2.set_title('$KL(p||q)=%.1f$' % KL_U2[i],fontsize=15)
#animation
ani = animation.FuncAnimation(fig, update, interval=1000,frames=U2_n)
# plt.show()
# ani.save("KL_μ.gif", writer="imagemagick")
Ensuite, la moyenne $ \ mu_2 $ de $ q (x) $ est mise à 0, et seul l'écart type $ \ sigma_2 $ est défini comme une variable.
q(x) =N(0,\sigma^2_2)= \frac{1}{\sqrt{2\pi\sigma_2^2}} \exp\left(-\frac{x^2}{2\sigma_2^2}\right)
La divergence KL à ce moment est
\begin{eqnarray}
KL(p||q) &=& \ln\left(\frac{\sigma_2}{\sigma_1}\right) + \frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2} \\
&=& \ln\left(\frac{\sigma_2}{1}\right) + \frac{1^2}{2\sigma_2^2} - \frac{1}{2} \\
&=& \ln\left(\sigma_2\right) + \frac{1}{2\sigma_2^2} - \frac{1}{2} \\
\end{eqnarray}
Ce sera.
Comme précédemment, le changement de divergence KL est devenu 0 lorsque les distributions de probabilité correspondaient, et a augmenté à mesure que la forme changeait.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
#distribution normale
def gaussian1d(x,μ,σ):
y = 1 / ( np.sqrt(2*np.pi* σ**2 ) ) * np.exp( - ( x - μ )**2 / ( 2 * σ ** 2 ) )
return y
#Distribution normale divergence KL
def gaussian1d_KLdivergence(μ1,σ1,μ2,σ2):
A = np.log(σ2/σ1)
B = ( σ1**2 + (μ1 - μ2)**2 ) / (2*σ2**2)
C = -1/2
y = A + B + C
return y
# KL divergence
def KLdivergence(p,q,dx):
KL=np.sum(p * np.log(p/q)) * dx
return KL
#Encoche x
dx = 0.01
#Gamme de x
xlm = [-6,6]
#coordonnée x
x = np.arange(xlm[0],xlm[1]+dx,dx)
#Nombre de x
x_n = len(x)
# Case 2
# p(x) = N(0,1)
# q(x) = N(0,σ**2)
# p(x)Moyenne μ1
μ1 = 0
# p(x)Écart type σ1
σ1 = 1
# p(x)
px = gaussian1d(x,μ1,σ1)
# q(x)Moyenne μ2
μ2 = 0
# q(x)Écart type σ2
S2 = np.hstack([ np.arange(0.5,1,0.1),np.arange(1,2,0.2),np.arange(2,4.5,0.5) ])
S2_n = len(S2)
# q(x)
Qx = np.zeros([x_n,S2_n])
#Divergence KL
KL_S2 = np.zeros(S2_n)
for i,σ2 in enumerate(S2):
qx = gaussian1d(x,μ2,σ2)
Qx[:,i] = qx
KL_S2[i] = KLdivergence(px,qx,dx)
#Portée de la solution analytique
S2_exc = np.arange(0.5,4+0.05,0.05)
#Solution analytique
KL_S2_exc = gaussian1d_KLdivergence(μ1,σ1,μ2,S2_exc)
#Solution analytique 2
KL_S2_exc2 = np.log(S2_exc) + 1/(2*S2_exc**2) - 1 / 2
#
# plot
#
# figure
fig = plt.figure(figsize=(8,4))
#Couleur par défaut
clr=plt.rcParams['axes.prop_cycle'].by_key()['color']
# axis 1
#-----------------------
#Diagramme de distribution normale
ax = plt.subplot(1,2,1)
# p(x)
plt.plot(x,px,label='$p(x)$')
# q(x)
line,=plt.plot(x,Qx[:,i],color=clr[1],label='$q(x)$')
#Guide d'utilisation
plt.legend(loc=1,prop={'size': 13})
plt.ylim([0,0.8])
plt.xticks(np.arange(xlm[0],xlm[1]+1,2))
plt.xlabel('$x$')
# axis 2
#-----------------------
#Divergence KL
ax2 = plt.subplot(1,2,2)
#Solution analytique
plt.plot(S2_exc,KL_S2_exc,label='Analytical')
#Calcul
point, = ax2.plot([],'o',label='Numerical')
#Guide d'utilisation
# plt.legend(loc=1,prop={'size': 15})
plt.xlim([S2[0],S2[-1]])
plt.xlabel('$\sigma$')
plt.ylabel('$KL(p||q)$')
plt.tight_layout()
#Paramètres communs pour les axes
for a in [ax,ax2]:
plt.axes(a)
plt.grid()
#Dans un carré
plt.gca().set_aspect(1/plt.gca().get_data_ratio())
#mise à jour
def update(i):
#ligne
line.set_data(x,Qx[:,i])
#point
point.set_data(S2[i],KL_S2[i])
#Titre
ax.set_title("$\sigma_2=%.1f$" % S2[i],fontsize=15)
ax2.set_title('$KL(p||q)=%.1f$' % KL_S2[i],fontsize=15)
#animation
ani = animation.FuncAnimation(fig, update, interval=1000,frames=S2_n)
plt.show()
# ani.save("KL_σ.gif", writer="imagemagick")
Voici un graphique des valeurs de divergence KL lorsque la moyenne $ \ mu_2 $ et l'écart type $ \ sigma_2 $ sont modifiés.
import numpy as np
import matplotlib.pyplot as plt
#distribution normale
def gaussian1d(x,μ,σ):
y = 1 / ( np.sqrt(2*np.pi* σ**2 ) ) * np.exp( - ( x - μ )**2 / ( 2 * σ ** 2 ) )
return y
#Distribution normale divergence KL
def gaussian1d_KLdivergence(μ1,σ1,μ2,σ2):
A = np.log(σ2/σ1)
B = ( σ1**2 + (μ1 - μ2)**2 ) / (2*σ2**2)
C = -1/2
y = A + B + C
return y
# KL divergence
def KLdivergence(p,q,dx):
KL=np.sum(p * np.log(p/q)) * dx
return KL
def Motion(event):
global cx,cy,cxid,cyid
xp = event.xdata
yp = event.ydata
if (xp is not None) and (yp is not None):
gca = event.inaxes
if gca is axs[0]:
cxid,cx = find_nearest(x,xp)
cyid,cy = find_nearest(y,yp)
lns[0].set_data(G_x,Qx[:,cxid,cyid])
lns[1].set_data(x,Z[:,cyid])
lns[2].set_data(y,Z[cxid,:])
lnhs[0].set_ydata([cy,cy])
lnvs[0].set_xdata([cx,cx])
lnvs[1].set_xdata([cx,cx])
lnvs[2].set_xdata([cy,cy])
if gca is axs[2]:
cxid,cx = find_nearest(x,xp)
lns[0].set_data(G_x,Qx[:,cxid,cyid])
lns[2].set_data(y,Z[cxid,:])
lnvs[0].set_xdata([cx,cx])
lnvs[1].set_xdata([cx,cx])
if gca is axs[3]:
cyid,cy = find_nearest(y,xp)
lns[0].set_data(G_x,Qx[:,cxid,cyid])
lns[1].set_data(x,Z[:,cyid])
lnhs[0].set_ydata([cy,cy])
lnvs[2].set_xdata([cy,cy])
axs[1].set_title("$\mu_2=%5.2f, \sigma_2=$%5.2f" % (cx,cy),fontsize=15)
axs[0].set_title('$KL(p||q)=$%.3f' % Z[cxid,cyid],fontsize=15)
plt.draw()
def find_nearest(array, values):
id = np.abs(array-values).argmin()
return id,array[id]
#Encoche x
G_dx = 0.01
#Gamme de x
G_xlm = [-4,4]
#coordonnée x
G_x = np.arange(G_xlm[0],G_xlm[1]+G_dx,G_dx)
#Nombre de x
G_n = len(G_x)
# p(x)Moyenne μ1
μ1 = 0
# p(x)Écart type σ1
σ1 = 1
# p(x)
px = gaussian1d(G_x,μ1,σ1)
# q(x)Moyenne μ2
μ_lim = [-2,2]
μ_dx = 0.1
μ_x = np.arange(μ_lim[0],μ_lim[1]+μ_dx,μ_dx)
μ_n = len(μ_x)
# q(x)Écart type σ2
σ_lim = [0.5,4]
σ_dx = 0.05
σ_x = np.arange(σ_lim[0],σ_lim[1]+σ_dx,σ_dx)
σ_n = len(σ_x)
#Divergence KL
KL = np.zeros([μ_n,σ_n])
# q(x)
Qx = np.zeros([G_n,μ_n,σ_n])
for i,μ2 in enumerate(μ_x):
for j,σ2 in enumerate(σ_x):
KL[i,j] = gaussian1d_KLdivergence(μ1,σ1,μ2,σ2)
Qx[:,i,j] = gaussian1d(G_x,μ2,σ2)
x = μ_x
y = σ_x
X,Y = np.meshgrid(x,y)
Z = KL
cxid = 0
cyid = 0
cx = x[cxid]
cy = y[cyid]
xlm = [ x[0], x[-1] ]
ylm = [ y[0], y[-1] ]
axs = []
ims = []
lns = []
lnvs = []
lnhs = []
# figure
#----------------
plt.close('all')
plt.figure(figsize=(8,8))
#Couleur par défaut
clr=plt.rcParams['axes.prop_cycle'].by_key()['color']
#taille de police
plt.rcParams["font.size"] = 16
#Largeur de ligne
plt.rcParams['lines.linewidth'] = 2
#Faire du style de ligne de la grille une ligne pointillée
plt.rcParams["grid.linestyle"] = '--'
#Éliminez les marges de plage lors du traçage
plt.rcParams['axes.xmargin'] = 0.
# ax1
#----------------
ax = plt.subplot(2,2,1)
Interval = np.arange(0,8,0.1)
plt.plot(μ1,σ1,'rx',label='$(μ_1,σ_1)=(0,1)$')
im = plt.contourf(X,Y,Z.T,Interval,cmap='hot')
lnv= plt.axvline(x=cx,color='w',linestyle='--',linewidth=1)
lnh= plt.axhline(y=cy,color='w',linestyle='--',linewidth=1)
ax.set_title('$KL(p||q)=$%.3f' % Z[cxid,cyid],fontsize=15)
plt.xlabel('μ')
plt.ylabel('σ')
axs.append(ax)
lnhs.append(lnh)
lnvs.append(lnv)
ims.append(im)
# ax2
#----------------
ax = plt.subplot(2,2,2)
plt.plot(G_x,px,label='$p(x)$')
ln, = plt.plot(G_x,Qx[:,cxid,cyid],color=clr[1],label='$q(x)$')
plt.legend(prop={'size': 10})
ax.set_title("$\mu_2=%5.2f, \sigma_2=$%5.2f" % (cx,cy),fontsize=15)
axs.append(ax)
lns.append(ln)
plt.grid()
# ax3
#----------------
ax = plt.subplot(2,2,3)
ln,=plt.plot(x,Z[:,cyid])
lnv= plt.axvline(x=cx,color='k',linestyle='--',linewidth=1)
plt.ylim([0,np.max(Z)])
plt.grid()
plt.xlabel('μ')
plt.ylabel('KL(p||q)')
lnvs.append(lnv)
axs.append(ax)
lns.append(ln)
# ax4
#----------------
ax = plt.subplot(2,2,4)
ln,=plt.plot(y,Z[cxid,:])
lnv= plt.axvline(x=cy,color='k',linestyle='--',linewidth=1)
plt.ylim([0,np.max(Z)])
plt.xlim([ylm[0],ylm[1]])
plt.grid()
plt.xlabel('σ')
plt.ylabel('KL(p||q)')
lnvs.append(lnv)
axs.append(ax)
lns.append(ln)
plt.tight_layout()
for ax in axs:
plt.axes(ax)
ax.set_aspect(1/ax.get_data_ratio())
plt.connect('motion_notify_event', Motion)
plt.show()