[PYTHON] Divergence KL entre les distributions normales

introduction

Je n'étais pas sûr de la divergence KL qui apparaît dans l'algorithme EM, donc j'aimerais obtenir une image en trouvant la divergence KL entre les distributions normales.

Divergence KL

La divergence Kullback-Leibler (divergence KL, information KL) est une mesure de la similitude des deux distributions de probabilité. La définition est la suivante.

KL(p||q) = \int_{-\infty}^{\infty}p(x)\ln \frac{p(x)}{q(x)}dx

Il y a deux caractéristiques importantes. Le premier est qu'il sera égal à 0 pour la même distribution de probabilité.

KL(p||p) = \int_{-\infty}^{\infty}p(x)\ln \frac{p(x)}{p(x)}dx
         = \int_{-\infty}^{\infty}p(x)\ln(1)dx
         = 0

La seconde est que ce sera toujours une valeur positive comprenant 0, et plus les distributions de probabilité sont dissemblables, plus la valeur est élevée. Regardons ces caractéristiques en utilisant un exemple de distribution normale.

distribution normale

Les fonctions de densité de probabilité p (x) et q (x) de la distribution normale sont définies comme suit.

p(x) = N(\mu_1,\sigma_1^2) = \frac{1}{\sqrt{2\pi\sigma_1^2}} \exp\left(-\frac{(x-\mu_1)^2}{2\sigma_1^2}\right) \\
q(x) = N(\mu_2,\sigma_2^2) = \frac{1}{\sqrt{2\pi\sigma_2^2}} \exp\left(-\frac{(x-\mu_2)^2}{2\sigma_2^2}\right)

Divergence KL entre les distributions normales

Trouvez la divergence KL entre les deux distributions normales ci-dessus. Le calcul est omis.

\begin{eqnarray}

KL(p||q)&=& \int_{-\infty}^{\infty}p(x)\ln \frac{p(x)}{q(x)}dx \\
        &=& \cdots \\
        &=& \ln\left(\frac{\sigma_2}{\sigma_1}\right) + \frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2}
\end{eqnarray}

Puisqu'il est difficile de comprendre s'il y a quatre variables, soit $ p (x) $ la distribution normale standard $ N (0,1) $ avec une moyenne de 0 et une variance de 1.

p(x) =N(0,1)= \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{x^2}{2}\right) 

Lorsque la moyenne est une variable

Tout d'abord, définissez l'écart type de $ q (x) $ sur $ \ sigma_2 $ et ne définissez que la moyenne $ \ mu_2 $ comme variables.

q(x) =N(\mu_2,1)= \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{(x-\mu_2)^2}{2}\right) 

La divergence KL à ce moment est

\begin{eqnarray}

KL(p||q) &=& \ln\left(\frac{\sigma_2}{\sigma_1}\right) + \frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2} \\
         &=& \ln\left(\frac{1}{1}\right) + \frac{1^2+(\mu_1-0)^2}{2*1^2} - \frac{1}{2} \\
         &=& \frac{\mu_2^2}{2}
\end{eqnarray}

Ce sera. \mu_2La valeur de-Distribution de probabilité lors de l'augmentation de 1 de 4 à 4$q(x)Et divergence KLKL(p||q)$La valeur de est la suivante.

KL_μ.gif

La ligne orange à gauche est $ q (x) $ lorsque la moyenne $ \ mu_2 $ est modifiée. La figure de droite est celle où la moyenne $ \ mu_2 $ est prise sur l'axe des x. La ligne bleue est la solution analytique et le point orange est la valeur de divergence KL actuelle. Il a été confirmé que la divergence KL devient 0 lorsque $ p (x) $ et $ q (x) $ correspondent exactement, et augmente à mesure que la distance augmente.

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

#distribution normale
def gaussian1d(x,μ,σ):
    y = 1 / ( np.sqrt(2*np.pi* σ**2 ) )  * np.exp( - ( x - μ )**2  / ( 2 * σ ** 2 ) )
    return y

#Distribution normale divergence KL
def gaussian1d_KLdivergence(μ1,σ1,μ2,σ2):
    A = np.log(σ2/σ1)
    B = ( σ1**2 + (μ1 - μ2)**2 ) / (2*σ2**2)
    C = -1/2
    y = A + B + C
    return y

# KL divergence
def KLdivergence(p,q,dx):
    KL=np.sum(p * np.log(p/q)) * dx
    return KL

#Encoche x
dx  = 0.01

#Gamme de x
xlm = [-6,6]

#coordonnée x
x   = np.arange(xlm[0],xlm[1]+dx,dx)

#Nombre de x
x_n   = len(x)

# Case 1
# p(x) = N(0,1)
# q(x) = N(μ,1)

# p(x)Moyenne μ1
μ1   = 0
# p(x)Écart type σ1
σ1   = 1  

# p(x)
px   = gaussian1d(x,μ1,σ1)

# q(x)Écart type σ2
σ2   = 1

# q(x)Moyenne μ2
U2   = np.arange(-4,5,1)

U2_n = len(U2)

# q(x)
Qx   = np.zeros([x_n,U2_n])

#Divergence KL
KL_U2  = np.zeros(U2_n)

for i,μ2 in enumerate(U2):
    qx        = gaussian1d(x,μ2,σ2)
    Qx[:,i]   = qx
    KL_U2[i]  = KLdivergence(px,qx,dx)


#Portée de la solution analytique
U2_exc    = np.arange(-4,4.1,0.1)

#Solution analytique
KL_U2_exc = gaussian1d_KLdivergence(μ1,σ1,U2_exc,σ2)

#Solution analytique 2
KL_U2_exc2 = U2_exc**2 / 2

#
# plot
#

# figure
fig = plt.figure(figsize=(8,4))
#Couleur par défaut
clr=plt.rcParams['axes.prop_cycle'].by_key()['color']

# axis 1 
#-----------------------
#Diagramme de distribution normale
ax = plt.subplot(1,2,1)
# p(x)
plt.plot(x,px,label='$p(x)$')       
# q(x)
line,=plt.plot(x,Qx[:,i],color=clr[1],label='$q(x)$')       
#Guide d'utilisation
plt.legend(loc=1,prop={'size': 13})

plt.xticks(np.arange(xlm[0],xlm[1]+1,2))
plt.xlabel('$x$')

# axis 2
#-----------------------
#Divergence KL
ax2 = plt.subplot(1,2,2)
#Solution analytique
plt.plot(U2_exc,KL_U2_exc,label='Analytical')
#Calcul
point, = ax2.plot([],'o',label='Numerical')

#Guide d'utilisation
# plt.legend(loc=1,prop={'size': 15})

plt.xlim([U2[0],U2[-1]])
plt.xlabel('$\mu$')
plt.ylabel('$KL(p||q)$')

plt.tight_layout()

#Paramètres communs pour les axes
for a in [ax,ax2]:
    plt.axes(a)
    plt.grid()
    #Dans un carré
    plt.gca().set_aspect(1/plt.gca().get_data_ratio())

#mise à jour
def update(i):
    #ligne
    line.set_data(x,Qx[:,i])
    #point
    point.set_data(U2[i],KL_U2[i])

    #Titre
    ax.set_title("$\mu_2=%.1f$" % U2[i],fontsize=15)
    ax2.set_title('$KL(p||q)=%.1f$' % KL_U2[i],fontsize=15)

#animation
ani = animation.FuncAnimation(fig, update, interval=1000,frames=U2_n)
# plt.show()
# ani.save("KL_μ.gif", writer="imagemagick")

Lorsque l'écart type est une variable

Ensuite, la moyenne $ \ mu_2 $ de $ q (x) $ est mise à 0, et seul l'écart type $ \ sigma_2 $ est défini comme une variable.

q(x) =N(0,\sigma^2_2)= \frac{1}{\sqrt{2\pi\sigma_2^2}} \exp\left(-\frac{x^2}{2\sigma_2^2}\right)

La divergence KL à ce moment est

\begin{eqnarray}

KL(p||q) &=& \ln\left(\frac{\sigma_2}{\sigma_1}\right) + \frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2} \\
         &=& \ln\left(\frac{\sigma_2}{1}\right) + \frac{1^2}{2\sigma_2^2} - \frac{1}{2} \\
         &=& \ln\left(\sigma_2\right) + \frac{1}{2\sigma_2^2} - \frac{1}{2} \\
\end{eqnarray}

Ce sera. \sigma_2Valeur de 0.Distribution de probabilité lors du passage de 5 à 4$q(x)Et divergence KLKL(p||q)$La valeur de est la suivante.

KL_σ.gif

Comme précédemment, le changement de divergence KL est devenu 0 lorsque les distributions de probabilité correspondaient, et a augmenté à mesure que la forme changeait.

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

#distribution normale
def gaussian1d(x,μ,σ):
    y = 1 / ( np.sqrt(2*np.pi* σ**2 ) )  * np.exp( - ( x - μ )**2  / ( 2 * σ ** 2 ) )
    return y

#Distribution normale divergence KL
def gaussian1d_KLdivergence(μ1,σ1,μ2,σ2):
    A = np.log(σ2/σ1)
    B = ( σ1**2 + (μ1 - μ2)**2 ) / (2*σ2**2)
    C = -1/2
    y = A + B + C
    return y

# KL divergence
def KLdivergence(p,q,dx):
    KL=np.sum(p * np.log(p/q)) * dx
    return KL

#Encoche x
dx  = 0.01

#Gamme de x
xlm = [-6,6]

#coordonnée x
x   = np.arange(xlm[0],xlm[1]+dx,dx)

#Nombre de x
x_n   = len(x)

# Case 2
# p(x) = N(0,1)
# q(x) = N(0,σ**2)

# p(x)Moyenne μ1
μ1   = 0
# p(x)Écart type σ1
σ1   = 1  

# p(x)
px   = gaussian1d(x,μ1,σ1)

# q(x)Moyenne μ2
μ2   = 0

# q(x)Écart type σ2
S2   = np.hstack([ np.arange(0.5,1,0.1),np.arange(1,2,0.2),np.arange(2,4.5,0.5) ])

S2_n = len(S2)

# q(x)
Qx   = np.zeros([x_n,S2_n])

#Divergence KL
KL_S2  = np.zeros(S2_n)

for i,σ2 in enumerate(S2):
    qx        = gaussian1d(x,μ2,σ2)
    Qx[:,i]   = qx
    KL_S2[i]  = KLdivergence(px,qx,dx)


#Portée de la solution analytique
S2_exc    = np.arange(0.5,4+0.05,0.05)

#Solution analytique
KL_S2_exc = gaussian1d_KLdivergence(μ1,σ1,μ2,S2_exc)

#Solution analytique 2
KL_S2_exc2 = np.log(S2_exc) + 1/(2*S2_exc**2) - 1 / 2

#
# plot
#

# figure
fig = plt.figure(figsize=(8,4))
#Couleur par défaut
clr=plt.rcParams['axes.prop_cycle'].by_key()['color']

# axis 1 
#-----------------------
#Diagramme de distribution normale
ax = plt.subplot(1,2,1)
# p(x)
plt.plot(x,px,label='$p(x)$')       
# q(x)
line,=plt.plot(x,Qx[:,i],color=clr[1],label='$q(x)$')       
#Guide d'utilisation
plt.legend(loc=1,prop={'size': 13})

plt.ylim([0,0.8])
plt.xticks(np.arange(xlm[0],xlm[1]+1,2))
plt.xlabel('$x$')

# axis 2
#-----------------------
#Divergence KL
ax2 = plt.subplot(1,2,2)
#Solution analytique
plt.plot(S2_exc,KL_S2_exc,label='Analytical')
#Calcul
point, = ax2.plot([],'o',label='Numerical')

#Guide d'utilisation
# plt.legend(loc=1,prop={'size': 15})

plt.xlim([S2[0],S2[-1]])
plt.xlabel('$\sigma$')
plt.ylabel('$KL(p||q)$')

plt.tight_layout()

#Paramètres communs pour les axes
for a in [ax,ax2]:
    plt.axes(a)
    plt.grid()
    #Dans un carré
    plt.gca().set_aspect(1/plt.gca().get_data_ratio())

#mise à jour
def update(i):
    #ligne
    line.set_data(x,Qx[:,i])
    #point
    point.set_data(S2[i],KL_S2[i])

    #Titre
    ax.set_title("$\sigma_2=%.1f$" % S2[i],fontsize=15)
    ax2.set_title('$KL(p||q)=%.1f$' % KL_S2[i],fontsize=15)

#animation
ani = animation.FuncAnimation(fig, update, interval=1000,frames=S2_n)
plt.show()
# ani.save("KL_σ.gif", writer="imagemagick")

Lorsque la moyenne et l'écart type sont des variables

Voici un graphique des valeurs de divergence KL lorsque la moyenne $ \ mu_2 $ et l'écart type $ \ sigma_2 $ sont modifiés.

## prime ![KL_motion2.gif](https://qiita-image-store.s3.amazonaws.com/0/35426/9b0197d8-1fd6-de81-673f-a8be1be727a4.gif)
import numpy as np
import matplotlib.pyplot as plt

#distribution normale
def gaussian1d(x,μ,σ):
    y = 1 / ( np.sqrt(2*np.pi* σ**2 ) )  * np.exp( - ( x - μ )**2  / ( 2 * σ ** 2 ) )
    return y

#Distribution normale divergence KL
def gaussian1d_KLdivergence(μ1,σ1,μ2,σ2):
    A = np.log(σ2/σ1)
    B = ( σ1**2 + (μ1 - μ2)**2 ) / (2*σ2**2)
    C = -1/2
    y = A + B + C
    return y

# KL divergence
def KLdivergence(p,q,dx):
    KL=np.sum(p * np.log(p/q)) * dx
    return KL

def Motion(event):
    global cx,cy,cxid,cyid
    
    xp = event.xdata
    yp = event.ydata

    if (xp is not None) and (yp is not None):
        gca = event.inaxes

        if gca is axs[0]:
            cxid,cx = find_nearest(x,xp)
            cyid,cy = find_nearest(y,yp)

            lns[0].set_data(G_x,Qx[:,cxid,cyid])
            lns[1].set_data(x,Z[:,cyid])
            lns[2].set_data(y,Z[cxid,:])            
            

            lnhs[0].set_ydata([cy,cy])
            lnvs[0].set_xdata([cx,cx])

            lnvs[1].set_xdata([cx,cx])
            lnvs[2].set_xdata([cy,cy])
            

        if gca is axs[2]:    
            cxid,cx = find_nearest(x,xp)

            lns[0].set_data(G_x,Qx[:,cxid,cyid])
            lns[2].set_data(y,Z[cxid,:])            
            lnvs[0].set_xdata([cx,cx])
            lnvs[1].set_xdata([cx,cx])

        if gca is axs[3]:    
            cyid,cy = find_nearest(y,xp)

            lns[0].set_data(G_x,Qx[:,cxid,cyid])
            lns[1].set_data(x,Z[:,cyid])
            lnhs[0].set_ydata([cy,cy])
            lnvs[2].set_xdata([cy,cy])
            
    axs[1].set_title("$\mu_2=%5.2f, \sigma_2=$%5.2f" % (cx,cy),fontsize=15)
    axs[0].set_title('$KL(p||q)=$%.3f' % Z[cxid,cyid],fontsize=15)

    plt.draw()

def find_nearest(array, values):
    id = np.abs(array-values).argmin()
    return id,array[id]

#Encoche x
G_dx  = 0.01
#Gamme de x
G_xlm = [-4,4]
#coordonnée x
G_x   = np.arange(G_xlm[0],G_xlm[1]+G_dx,G_dx)
#Nombre de x
G_n   = len(G_x)

# p(x)Moyenne μ1
μ1   = 0
# p(x)Écart type σ1
σ1   = 1  
# p(x)
px   = gaussian1d(G_x,μ1,σ1)

# q(x)Moyenne μ2
μ_lim = [-2,2]
μ_dx  = 0.1
μ_x   = np.arange(μ_lim[0],μ_lim[1]+μ_dx,μ_dx)
μ_n   = len(μ_x)

# q(x)Écart type σ2
σ_lim = [0.5,4]
σ_dx  = 0.05
σ_x   = np.arange(σ_lim[0],σ_lim[1]+σ_dx,σ_dx)
σ_n   = len(σ_x)

#Divergence KL
KL   = np.zeros([μ_n,σ_n])
# q(x)
Qx   = np.zeros([G_n,μ_n,σ_n])

for i,μ2 in enumerate(μ_x):
    for j,σ2 in enumerate(σ_x):
        KL[i,j]   = gaussian1d_KLdivergence(μ1,σ1,μ2,σ2)
        Qx[:,i,j] = gaussian1d(G_x,μ2,σ2)

x   = μ_x
y   = σ_x

X,Y = np.meshgrid(x,y)
Z   = KL

cxid  = 0
cyid  = 0

cx    = x[cxid]
cy    = y[cyid]

xlm   = [ x[0], x[-1] ]
ylm   = [ y[0], y[-1] ]

axs   = []
ims   = []
lns   = []
lnvs  = []
lnhs  = []

# figure
#----------------
plt.close('all')
plt.figure(figsize=(8,8))
#Couleur par défaut
clr=plt.rcParams['axes.prop_cycle'].by_key()['color']

#taille de police
plt.rcParams["font.size"] = 16
#Largeur de ligne
plt.rcParams['lines.linewidth'] = 2
#Faire du style de ligne de la grille une ligne pointillée
plt.rcParams["grid.linestyle"] = '--'

#Éliminez les marges de plage lors du traçage
plt.rcParams['axes.xmargin'] = 0.

# ax1
#----------------
ax = plt.subplot(2,2,1)

Interval = np.arange(0,8,0.1)
plt.plot(μ1,σ1,'rx',label='$(μ_1,σ_1)=(0,1)$')
im = plt.contourf(X,Y,Z.T,Interval,cmap='hot')
lnv= plt.axvline(x=cx,color='w',linestyle='--',linewidth=1)
lnh= plt.axhline(y=cy,color='w',linestyle='--',linewidth=1)

ax.set_title('$KL(p||q)=$%.3f' % Z[cxid,cyid],fontsize=15)
plt.xlabel('μ')
plt.ylabel('σ')

axs.append(ax)
lnhs.append(lnh)
lnvs.append(lnv)
ims.append(im)

# ax2
#----------------
ax = plt.subplot(2,2,2)
plt.plot(G_x,px,label='$p(x)$')
ln, = plt.plot(G_x,Qx[:,cxid,cyid],color=clr[1],label='$q(x)$')
plt.legend(prop={'size': 10})
ax.set_title("$\mu_2=%5.2f, \sigma_2=$%5.2f" % (cx,cy),fontsize=15)

axs.append(ax)
lns.append(ln)
plt.grid()

# ax3
#----------------
ax = plt.subplot(2,2,3)
ln,=plt.plot(x,Z[:,cyid])
lnv= plt.axvline(x=cx,color='k',linestyle='--',linewidth=1)

plt.ylim([0,np.max(Z)])
plt.grid()
plt.xlabel('μ')
plt.ylabel('KL(p||q)')

lnvs.append(lnv)
axs.append(ax)
lns.append(ln)

# ax4
#----------------
ax = plt.subplot(2,2,4)
ln,=plt.plot(y,Z[cxid,:])

lnv= plt.axvline(x=cy,color='k',linestyle='--',linewidth=1)

plt.ylim([0,np.max(Z)])
plt.xlim([ylm[0],ylm[1]])
plt.grid()

plt.xlabel('σ')
plt.ylabel('KL(p||q)')

lnvs.append(lnv)
axs.append(ax)
lns.append(ln)

plt.tight_layout()

for ax in axs:
    plt.axes(ax)
    ax.set_aspect(1/ax.get_data_ratio())
    
plt.connect('motion_notify_event', Motion)

plt.show()

Recommended Posts

Divergence KL entre les distributions normales
Remarques sur la divergence KL entre les distributions de Poisson