Ich war mir nicht sicher über die KL-Divergenz, die im EM-Algorithmus auftritt, daher möchte ich ein Bild erhalten, indem ich die KL-Divergenz zwischen Normalverteilungen finde.
Die Kullback-Leibler-Divergenz (KL-Divergenz, KL-Information) ist ein Maß dafür, wie ähnlich die beiden Wahrscheinlichkeitsverteilungen sind. Die Definition ist wie folgt.
KL(p||q) = \int_{-\infty}^{\infty}p(x)\ln \frac{p(x)}{q(x)}dx
Es gibt zwei wichtige Merkmale. Das erste ist, dass es für dieselbe Wahrscheinlichkeitsverteilung 0 ist.
KL(p||p) = \int_{-\infty}^{\infty}p(x)\ln \frac{p(x)}{p(x)}dx
= \int_{-\infty}^{\infty}p(x)\ln(1)dx
= 0
Das zweite ist, dass es immer ein positiver Wert einschließlich 0 ist, und je unterschiedlicher die Wahrscheinlichkeitsverteilungen sind, desto größer ist der Wert. Betrachten wir diese Eigenschaften anhand eines Beispiels für eine Normalverteilung.
Die Wahrscheinlichkeitsdichtefunktionen p (x) und q (x) der Normalverteilung sind wie folgt definiert.
p(x) = N(\mu_1,\sigma_1^2) = \frac{1}{\sqrt{2\pi\sigma_1^2}} \exp\left(-\frac{(x-\mu_1)^2}{2\sigma_1^2}\right) \\
q(x) = N(\mu_2,\sigma_2^2) = \frac{1}{\sqrt{2\pi\sigma_2^2}} \exp\left(-\frac{(x-\mu_2)^2}{2\sigma_2^2}\right)
Finden Sie die KL-Divergenz zwischen den beiden oben genannten Normalverteilungen. Die Berechnung entfällt.
\begin{eqnarray}
KL(p||q)&=& \int_{-\infty}^{\infty}p(x)\ln \frac{p(x)}{q(x)}dx \\
&=& \cdots \\
&=& \ln\left(\frac{\sigma_2}{\sigma_1}\right) + \frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2}
\end{eqnarray}
Da es schwer zu verstehen ist, ob es vier Variablen gibt, sei $ p (x) $ die Standardnormalverteilung $ N (0,1) $ mit einem Mittelwert von 0 und einer Varianz von 1.
p(x) =N(0,1)= \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{x^2}{2}\right)
Setzen Sie zunächst die Standardabweichung von $ q (x) $ auf $ \ sigma_2 $ und setzen Sie nur den Durchschnitt $ \ mu_2 $ als Variablen.
q(x) =N(\mu_2,1)= \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{(x-\mu_2)^2}{2}\right)
Die KL-Divergenz zu diesem Zeitpunkt ist
\begin{eqnarray}
KL(p||q) &=& \ln\left(\frac{\sigma_2}{\sigma_1}\right) + \frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2} \\
&=& \ln\left(\frac{1}{1}\right) + \frac{1^2+(\mu_1-0)^2}{2*1^2} - \frac{1}{2} \\
&=& \frac{\mu_2^2}{2}
\end{eqnarray}
Es wird sein.
Die orange Linie links ist $ q (x) $, wenn der Durchschnitt $ \ mu_2 $ geändert wird. Die Abbildung rechts ist die Abbildung, wenn der Durchschnitt $ \ mu_2 $ auf der x-Achse genommen wird. Die blaue Linie ist die analytische Lösung und der orangefarbene Punkt ist der aktuelle KL-Divergenzwert. Es wurde bestätigt, dass die KL-Divergenz 0 wird, wenn $ p (x) $ und $ q (x) $ genau übereinstimmen, und mit zunehmender Entfernung zunimmt.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
#Normalverteilung
def gaussian1d(x,μ,σ):
y = 1 / ( np.sqrt(2*np.pi* σ**2 ) ) * np.exp( - ( x - μ )**2 / ( 2 * σ ** 2 ) )
return y
#Normalverteilung KL Divergenz
def gaussian1d_KLdivergence(μ1,σ1,μ2,σ2):
A = np.log(σ2/σ1)
B = ( σ1**2 + (μ1 - μ2)**2 ) / (2*σ2**2)
C = -1/2
y = A + B + C
return y
# KL divergence
def KLdivergence(p,q,dx):
KL=np.sum(p * np.log(p/q)) * dx
return KL
#Kerbe x
dx = 0.01
#Bereich von x
xlm = [-6,6]
#x-Koordinate
x = np.arange(xlm[0],xlm[1]+dx,dx)
#Anzahl von x
x_n = len(x)
# Case 1
# p(x) = N(0,1)
# q(x) = N(μ,1)
# p(x)Durchschnitt μ1
μ1 = 0
# p(x)Standardabweichung σ1
σ1 = 1
# p(x)
px = gaussian1d(x,μ1,σ1)
# q(x)Standardabweichung σ2
σ2 = 1
# q(x)Durchschnitt μ2
U2 = np.arange(-4,5,1)
U2_n = len(U2)
# q(x)
Qx = np.zeros([x_n,U2_n])
#KL-Divergenz
KL_U2 = np.zeros(U2_n)
for i,μ2 in enumerate(U2):
qx = gaussian1d(x,μ2,σ2)
Qx[:,i] = qx
KL_U2[i] = KLdivergence(px,qx,dx)
#Umfang der analytischen Lösung
U2_exc = np.arange(-4,4.1,0.1)
#Analytische Lösung
KL_U2_exc = gaussian1d_KLdivergence(μ1,σ1,U2_exc,σ2)
#Analytische Lösung 2
KL_U2_exc2 = U2_exc**2 / 2
#
# plot
#
# figure
fig = plt.figure(figsize=(8,4))
#Standardfarbe
clr=plt.rcParams['axes.prop_cycle'].by_key()['color']
# axis 1
#-----------------------
#Normalverteilungsdiagramm
ax = plt.subplot(1,2,1)
# p(x)
plt.plot(x,px,label='$p(x)$')
# q(x)
line,=plt.plot(x,Qx[:,i],color=clr[1],label='$q(x)$')
#Gebrauchsanweisung
plt.legend(loc=1,prop={'size': 13})
plt.xticks(np.arange(xlm[0],xlm[1]+1,2))
plt.xlabel('$x$')
# axis 2
#-----------------------
#KL-Divergenz
ax2 = plt.subplot(1,2,2)
#Analytische Lösung
plt.plot(U2_exc,KL_U2_exc,label='Analytical')
#Berechnung
point, = ax2.plot([],'o',label='Numerical')
#Gebrauchsanweisung
# plt.legend(loc=1,prop={'size': 15})
plt.xlim([U2[0],U2[-1]])
plt.xlabel('$\mu$')
plt.ylabel('$KL(p||q)$')
plt.tight_layout()
#Allgemeine Einstellungen für Achsen
for a in [ax,ax2]:
plt.axes(a)
plt.grid()
#Auf einem Platz
plt.gca().set_aspect(1/plt.gca().get_data_ratio())
#aktualisieren
def update(i):
#Linie
line.set_data(x,Qx[:,i])
#Punkt
point.set_data(U2[i],KL_U2[i])
#Titel
ax.set_title("$\mu_2=%.1f$" % U2[i],fontsize=15)
ax2.set_title('$KL(p||q)=%.1f$' % KL_U2[i],fontsize=15)
#Animation
ani = animation.FuncAnimation(fig, update, interval=1000,frames=U2_n)
# plt.show()
# ani.save("KL_μ.gif", writer="imagemagick")
Als nächstes wird der Durchschnitt $ \ mu_2 $ von $ q (x) $ auf 0 gesetzt, und nur die Standardabweichung $ \ sigma_2 $ wird als Variable gesetzt.
q(x) =N(0,\sigma^2_2)= \frac{1}{\sqrt{2\pi\sigma_2^2}} \exp\left(-\frac{x^2}{2\sigma_2^2}\right)
Die KL-Divergenz zu diesem Zeitpunkt ist
\begin{eqnarray}
KL(p||q) &=& \ln\left(\frac{\sigma_2}{\sigma_1}\right) + \frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2} \\
&=& \ln\left(\frac{\sigma_2}{1}\right) + \frac{1^2}{2\sigma_2^2} - \frac{1}{2} \\
&=& \ln\left(\sigma_2\right) + \frac{1}{2\sigma_2^2} - \frac{1}{2} \\
\end{eqnarray}
Es wird sein.
Wie zuvor wurde die Änderung der KL-Divergenz 0, wenn die Wahrscheinlichkeitsverteilungen übereinstimmten, und nahm zu, wenn sich die Form änderte.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
#Normalverteilung
def gaussian1d(x,μ,σ):
y = 1 / ( np.sqrt(2*np.pi* σ**2 ) ) * np.exp( - ( x - μ )**2 / ( 2 * σ ** 2 ) )
return y
#Normalverteilung KL Divergenz
def gaussian1d_KLdivergence(μ1,σ1,μ2,σ2):
A = np.log(σ2/σ1)
B = ( σ1**2 + (μ1 - μ2)**2 ) / (2*σ2**2)
C = -1/2
y = A + B + C
return y
# KL divergence
def KLdivergence(p,q,dx):
KL=np.sum(p * np.log(p/q)) * dx
return KL
#Kerbe x
dx = 0.01
#Bereich von x
xlm = [-6,6]
#x-Koordinate
x = np.arange(xlm[0],xlm[1]+dx,dx)
#Anzahl von x
x_n = len(x)
# Case 2
# p(x) = N(0,1)
# q(x) = N(0,σ**2)
# p(x)Durchschnitt μ1
μ1 = 0
# p(x)Standardabweichung σ1
σ1 = 1
# p(x)
px = gaussian1d(x,μ1,σ1)
# q(x)Durchschnitt μ2
μ2 = 0
# q(x)Standardabweichung σ2
S2 = np.hstack([ np.arange(0.5,1,0.1),np.arange(1,2,0.2),np.arange(2,4.5,0.5) ])
S2_n = len(S2)
# q(x)
Qx = np.zeros([x_n,S2_n])
#KL-Divergenz
KL_S2 = np.zeros(S2_n)
for i,σ2 in enumerate(S2):
qx = gaussian1d(x,μ2,σ2)
Qx[:,i] = qx
KL_S2[i] = KLdivergence(px,qx,dx)
#Umfang der analytischen Lösung
S2_exc = np.arange(0.5,4+0.05,0.05)
#Analytische Lösung
KL_S2_exc = gaussian1d_KLdivergence(μ1,σ1,μ2,S2_exc)
#Analytische Lösung 2
KL_S2_exc2 = np.log(S2_exc) + 1/(2*S2_exc**2) - 1 / 2
#
# plot
#
# figure
fig = plt.figure(figsize=(8,4))
#Standardfarbe
clr=plt.rcParams['axes.prop_cycle'].by_key()['color']
# axis 1
#-----------------------
#Normalverteilungsdiagramm
ax = plt.subplot(1,2,1)
# p(x)
plt.plot(x,px,label='$p(x)$')
# q(x)
line,=plt.plot(x,Qx[:,i],color=clr[1],label='$q(x)$')
#Gebrauchsanweisung
plt.legend(loc=1,prop={'size': 13})
plt.ylim([0,0.8])
plt.xticks(np.arange(xlm[0],xlm[1]+1,2))
plt.xlabel('$x$')
# axis 2
#-----------------------
#KL-Divergenz
ax2 = plt.subplot(1,2,2)
#Analytische Lösung
plt.plot(S2_exc,KL_S2_exc,label='Analytical')
#Berechnung
point, = ax2.plot([],'o',label='Numerical')
#Gebrauchsanweisung
# plt.legend(loc=1,prop={'size': 15})
plt.xlim([S2[0],S2[-1]])
plt.xlabel('$\sigma$')
plt.ylabel('$KL(p||q)$')
plt.tight_layout()
#Allgemeine Einstellungen für Achsen
for a in [ax,ax2]:
plt.axes(a)
plt.grid()
#Auf einem Platz
plt.gca().set_aspect(1/plt.gca().get_data_ratio())
#aktualisieren
def update(i):
#Linie
line.set_data(x,Qx[:,i])
#Punkt
point.set_data(S2[i],KL_S2[i])
#Titel
ax.set_title("$\sigma_2=%.1f$" % S2[i],fontsize=15)
ax2.set_title('$KL(p||q)=%.1f$' % KL_S2[i],fontsize=15)
#Animation
ani = animation.FuncAnimation(fig, update, interval=1000,frames=S2_n)
plt.show()
# ani.save("KL_σ.gif", writer="imagemagick")
Unten sehen Sie eine grafische Darstellung der KL-Divergenzwerte, wenn sowohl der Mittelwert $ \ mu_2 $ als auch die Standardabweichung $ \ sigma_2 $ geändert werden.
import numpy as np
import matplotlib.pyplot as plt
#Normalverteilung
def gaussian1d(x,μ,σ):
y = 1 / ( np.sqrt(2*np.pi* σ**2 ) ) * np.exp( - ( x - μ )**2 / ( 2 * σ ** 2 ) )
return y
#Normalverteilung KL Divergenz
def gaussian1d_KLdivergence(μ1,σ1,μ2,σ2):
A = np.log(σ2/σ1)
B = ( σ1**2 + (μ1 - μ2)**2 ) / (2*σ2**2)
C = -1/2
y = A + B + C
return y
# KL divergence
def KLdivergence(p,q,dx):
KL=np.sum(p * np.log(p/q)) * dx
return KL
def Motion(event):
global cx,cy,cxid,cyid
xp = event.xdata
yp = event.ydata
if (xp is not None) and (yp is not None):
gca = event.inaxes
if gca is axs[0]:
cxid,cx = find_nearest(x,xp)
cyid,cy = find_nearest(y,yp)
lns[0].set_data(G_x,Qx[:,cxid,cyid])
lns[1].set_data(x,Z[:,cyid])
lns[2].set_data(y,Z[cxid,:])
lnhs[0].set_ydata([cy,cy])
lnvs[0].set_xdata([cx,cx])
lnvs[1].set_xdata([cx,cx])
lnvs[2].set_xdata([cy,cy])
if gca is axs[2]:
cxid,cx = find_nearest(x,xp)
lns[0].set_data(G_x,Qx[:,cxid,cyid])
lns[2].set_data(y,Z[cxid,:])
lnvs[0].set_xdata([cx,cx])
lnvs[1].set_xdata([cx,cx])
if gca is axs[3]:
cyid,cy = find_nearest(y,xp)
lns[0].set_data(G_x,Qx[:,cxid,cyid])
lns[1].set_data(x,Z[:,cyid])
lnhs[0].set_ydata([cy,cy])
lnvs[2].set_xdata([cy,cy])
axs[1].set_title("$\mu_2=%5.2f, \sigma_2=$%5.2f" % (cx,cy),fontsize=15)
axs[0].set_title('$KL(p||q)=$%.3f' % Z[cxid,cyid],fontsize=15)
plt.draw()
def find_nearest(array, values):
id = np.abs(array-values).argmin()
return id,array[id]
#Kerbe x
G_dx = 0.01
#Bereich von x
G_xlm = [-4,4]
#x-Koordinate
G_x = np.arange(G_xlm[0],G_xlm[1]+G_dx,G_dx)
#Anzahl von x
G_n = len(G_x)
# p(x)Durchschnitt μ1
μ1 = 0
# p(x)Standardabweichung σ1
σ1 = 1
# p(x)
px = gaussian1d(G_x,μ1,σ1)
# q(x)Durchschnitt μ2
μ_lim = [-2,2]
μ_dx = 0.1
μ_x = np.arange(μ_lim[0],μ_lim[1]+μ_dx,μ_dx)
μ_n = len(μ_x)
# q(x)Standardabweichung σ2
σ_lim = [0.5,4]
σ_dx = 0.05
σ_x = np.arange(σ_lim[0],σ_lim[1]+σ_dx,σ_dx)
σ_n = len(σ_x)
#KL-Divergenz
KL = np.zeros([μ_n,σ_n])
# q(x)
Qx = np.zeros([G_n,μ_n,σ_n])
for i,μ2 in enumerate(μ_x):
for j,σ2 in enumerate(σ_x):
KL[i,j] = gaussian1d_KLdivergence(μ1,σ1,μ2,σ2)
Qx[:,i,j] = gaussian1d(G_x,μ2,σ2)
x = μ_x
y = σ_x
X,Y = np.meshgrid(x,y)
Z = KL
cxid = 0
cyid = 0
cx = x[cxid]
cy = y[cyid]
xlm = [ x[0], x[-1] ]
ylm = [ y[0], y[-1] ]
axs = []
ims = []
lns = []
lnvs = []
lnhs = []
# figure
#----------------
plt.close('all')
plt.figure(figsize=(8,8))
#Standardfarbe
clr=plt.rcParams['axes.prop_cycle'].by_key()['color']
#Schriftgröße
plt.rcParams["font.size"] = 16
#Linienbreite
plt.rcParams['lines.linewidth'] = 2
#Machen Sie den Gitterlinienstil zu einer gepunkteten Linie
plt.rcParams["grid.linestyle"] = '--'
#Beseitigen Sie beim Zeichnen die Bereichsränder
plt.rcParams['axes.xmargin'] = 0.
# ax1
#----------------
ax = plt.subplot(2,2,1)
Interval = np.arange(0,8,0.1)
plt.plot(μ1,σ1,'rx',label='$(μ_1,σ_1)=(0,1)$')
im = plt.contourf(X,Y,Z.T,Interval,cmap='hot')
lnv= plt.axvline(x=cx,color='w',linestyle='--',linewidth=1)
lnh= plt.axhline(y=cy,color='w',linestyle='--',linewidth=1)
ax.set_title('$KL(p||q)=$%.3f' % Z[cxid,cyid],fontsize=15)
plt.xlabel('μ')
plt.ylabel('σ')
axs.append(ax)
lnhs.append(lnh)
lnvs.append(lnv)
ims.append(im)
# ax2
#----------------
ax = plt.subplot(2,2,2)
plt.plot(G_x,px,label='$p(x)$')
ln, = plt.plot(G_x,Qx[:,cxid,cyid],color=clr[1],label='$q(x)$')
plt.legend(prop={'size': 10})
ax.set_title("$\mu_2=%5.2f, \sigma_2=$%5.2f" % (cx,cy),fontsize=15)
axs.append(ax)
lns.append(ln)
plt.grid()
# ax3
#----------------
ax = plt.subplot(2,2,3)
ln,=plt.plot(x,Z[:,cyid])
lnv= plt.axvline(x=cx,color='k',linestyle='--',linewidth=1)
plt.ylim([0,np.max(Z)])
plt.grid()
plt.xlabel('μ')
plt.ylabel('KL(p||q)')
lnvs.append(lnv)
axs.append(ax)
lns.append(ln)
# ax4
#----------------
ax = plt.subplot(2,2,4)
ln,=plt.plot(y,Z[cxid,:])
lnv= plt.axvline(x=cy,color='k',linestyle='--',linewidth=1)
plt.ylim([0,np.max(Z)])
plt.xlim([ylm[0],ylm[1]])
plt.grid()
plt.xlabel('σ')
plt.ylabel('KL(p||q)')
lnvs.append(lnv)
axs.append(ax)
lns.append(ln)
plt.tight_layout()
for ax in axs:
plt.axes(ax)
ax.set_aspect(1/ax.get_data_ratio())
plt.connect('motion_notify_event', Motion)
plt.show()