[PYTHON] [Statistik] Multiprocessing von MCMC-Sampling

Dies ist ein Artikel, der versucht, die Scratch-Implementierung von MCMC durch Multiprocessing zu beschleunigen. In dem Artikel des anderen Tages, "[Statistik] Ich werde das Sampling mit der Markov-Ketten-Monte-Carlo-Methode (MCMC) mit Animation erklären.", habe ich die Kette implementiert. Ich hatte es nicht, also hatte ich nur eine Kette, aber ich habe versucht, dies mit mehreren Ketten abzutasten und es als Multi-Prozess auszuführen. Da MCMC für jede Kette unabhängig ist, ist es in Ordnung, den Prozess einfach zu trennen, sodass die Beschleunigung einfach war.

Umgebung

⇒ Da es 2 Kerne hat, kann es nur bis zu 2 Prozesse effektiv beschleunigen ...

Code für diesen Artikel

Der Code wird auf GitHub veröffentlicht.  https://github.com/matsuken92/Qiita_Contents/blob/master/multiprocessing/parallel_MCMC.ipynb

Grundlagen der Mehrfachverarbeitung

Zunächst möchte ich die Bewegung von MultiProcessing mit einem einfachen Prozess sehen.

Erstens ist der Import der Bibliothek. Wir verwenden eine Klasse namens Pool, die mehrere Arbeitsprozesse verwaltet.

from multiprocessing import Pool

Im Moment scheint es ein schwerer Prozess zu sein. Lassen Sie uns also einen Prozess ins Visier nehmen, der viele Schleifen aufweist. Es summiert sich nur, aber es dauert einige Sekunden, um es ungefähr 100000000 Mal zu drehen.

def test_calc(num):
    """Schwere Verarbeitung"""
    _sum = 0
    for i in xrange(num):
        _sum += i
    return _sum

Lassen Sie uns die Geschwindigkeit messen, wenn dieser Prozess zweimal der Reihe nach ausgeführt wird.

#Messen Sie die Zeit, zu der es zweimal hintereinander ausgeführt wird
start = time.time()
_sum = 0
for _ in xrange(2):
    _sum += test_calc(100000000)
end = time.time()
print _sum
print "time: {}".format(end-start)

Es dauerte weniger als 12 Sekunden.

out


9999999900000000
time: 11.6906960011

Führen Sie als Nächstes dieselbe Verarbeitung parallel für zwei Prozesse durch und messen Sie.

#Messen Sie die Zeit, wenn Sie in 2 Prozessen ausgeführt werden
n_worker = 2

pool = Pool(processes=n_worker)

#Liste der Argumente, die an die von den beiden Prozessen ausgeführten Funktionen übergeben werden sollen
args = [100000000] * n_worker

start = time.time() #Messung
result = pool.map(test_calc, args)
end = time.time()   #Messung

print  np.sum(result)
print "time: {}".format(end-start)
pool.close()

Es ist etwas mehr als 6 Sekunden, also fast die Hälfte der Zeit. Ich konnte durch 2 Prozesse beschleunigen: Lachen:

out


9999999900000000
time: 6.28346395493

Anwenden von MultiProcessing auf MCMC-Sampling

Wenden wir dies nun auf die parallele Verarbeitung jeder Kette von MCMC-Abtastungen an. Wie immer müssen Sie zuerst die Bibliothek importieren.

import numpy as np
import numpy.random as rd
import scipy.stats as st
import copy, time, os
from datetime import datetime as dt

from multiprocessing import Pool

%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="whitegrid", palette="muted", color_codes=True)

Die Funktion "P (・)" ist der Ziel-Posterior-Verteilungskern. Hier verwenden wir einen zweidimensionalen Normalverteilungskern.

#Wahrscheinlichkeitsfunktion minus Normalisierungskonstante
def P(x1, x2, b):
    assert np.abs(b) < 1
    return np.exp(-0.5*(x1**2 - 2*b*x1*x2 + x2**2))

Die Parameter der vorgeschlagenen Verteilung sind als global definiert. Es definiert auch eine "Jetzt (・)" - Funktion zur Zeitmessung. Eine Funktion, die eine Zeichenfolge der aktuellen Zeit anzeigt.

# global parameters
b = 0.5
delta = 1

def now():
    return  dt.strftime(dt.now(), '%H:%M:%S')

Das Folgende ist die Funktion, die die Abtastung durchführt. Probenahme bis die angegebene Anzahl von Proben erreicht ist. Es ist fast dasselbe wie Letztes Mal, außer dass es funktionalisiert ist und der Zeitmesscode hinzugefügt wird. Der Schlüssel ist, diese Funktion parallel auszuführen. Da es für jede Kette abgetastet wird, kann es sich unabhängig zwischen Prozessen bewegen, sodass keine Kommunikation zwischen Prozessen erforderlich ist und es sich einfach anfühlt.

def exec_sampling(n_samples):
    global b, delta
    rd.seed(int(time.time())+os.getpid())
    pid = os.getpid()
    start = time.time()
    start_time = now()
    
    #initial state
    sampling_result = []
    current = np.array([5, 5])
    sampling_result.append(current)
    cnt = 1
    while cnt < n_samples:
        # rv from proposal distribution(Normal Dist: N(0, delta) )
        next = current + rd.normal(0, delta, size=2)
        r = P(next[0], next[1], b)/P(current[0], current[1], b)

        if r > 1 or r > rd.uniform(0, 1):
            # 0-Wenn die einheitliche Zufallszahl 1 größer als r ist, wird der Zustand aktualisiert.
            current = copy.copy(next)
            sampling_result.append(current)
            cnt += 1
            
    end = time.time()    
    end_time = now()

    #Anzeige der erforderlichen Zeit für jede Kette
    print "PID:{}, exec time: {}, {}-{}".format(pid, end-start, start_time, end_time)
    return sampling_result

Die folgenden drei Funktionen "draw_scatter ()", "draw_traceplot ()" und "remove_burn_in_samples ()" sind Funktionen, die das Stichprobenergebnis verarbeiten.

def draw_scatter(sample, alpha=0.3):
    """Zeichnen Sie ein Streudiagramm der Stichprobenergebnisse"""
    plt.figure(figsize=(9,9))
    plt.scatter(sample[:,0], sample[:,1], alpha=alpha)
    plt.title("Scatter plot of 2-dim normal random variable with MCMC. sample size:{}".format(len(sample)))
    plt.show()
    
def draw_traceplot(sample):
    """Zeichnen Sie ein Trace-Diagramm des Stichprobenergebnisses"""
    assert sample.shape[1] == 2
    
    plt.figure(figsize=(15, 6))
    
    for i in range(2):
        plt.subplot(2, 1, i+1)
        plt.xlim(0, len(sample[:,i]))
        plt.plot(sample[:,i], lw=0.05)
        if i == 0:
            order = "1st"
        else:
            order = "2nd"
        plt.title("Traceplot of {} parameter.".format(order))
    
    plt.show()

def remove_burn_in_samples(total_sampling_result, burn_in_rate=0.2):
    """Burn-Schließen Sie das Beispiel des in in angegebenen Abschnitts aus."""
    adjust_burn_in_result = []
    for i in xrange(len(total_sampling_result)):
        idx = int(len(total_sampling_result[i])*burn_in_rate)
        adjust_burn_in_result.extend(total_sampling_result[i][idx:])
    return np.array(adjust_burn_in_result)

Nachfolgend sind die Funktionen aufgeführt, die die Parallelverarbeitung durchführen. Wenn Sie genau hinschauen, können Sie sehen, dass es praktisch dasselbe ist wie das erste einfache Beispiel.

def parallel_exec(n_samples, n_chain, burn_in_rate=0.2):
    """Ausführung der Parallelverarbeitung"""

    #Berechnen Sie die Probengröße pro Kette
    n_samples_per_chain = n_samples / float(n_chain)
    print "Making {} samples per {} chain. Burn-in rate:{}".format(n_samples_per_chain, n_chain, burn_in_rate)

    #Erstellen eines Pool-Objekts
    pool = Pool(processes=n_chain)

    #Generierung von Argumenten für die Ausführung
    n_trials_per_process = [n_samples_per_chain] * n_chain

    #Ausführung der Parallelverarbeitung
    start = time.time() #Messung
    total_sampling_result = pool.map(exec_sampling, n_trials_per_process)
    end = time.time()   #Messung

    #Anzeige der insgesamt benötigten Zeit
    print "total exec time: {}".format(end-start)

    # Drawing scatter plot
    adjusted_samples = remove_burn_in_samples(total_sampling_result)
    draw_scatter(adjusted_samples, alpha=0.01)
    draw_traceplot(adjusted_samples)
    pool.close()

Nun sehen wir uns den tatsächlichen Effekt an. Die Anzahl der Proben beträgt 1.000.000, und die Anzahl der Ketten wird gemessen, wenn es 2 ist und wenn es 1 ist.

Anzahl der Proben: 1.000.000, Anzahl der Ketten: 2

#Parameter: n_samples = 1000000, n_chain = 2
parallel_exec(1000000, 2)

Die Probenahme dauert insgesamt weniger als 19 Sekunden, ungefähr 12 Sekunden pro Arbeitsprozess.

out


Making 500000.0 samples per 2 chain. Burn-in rate:0.2
total exec time: 18.6980280876
PID:2374, exec time: 12.0037689209, 20:53:41-20:53:53
PID:2373, exec time: 11.9927477837, 20:53:41-20:53:53

scatter_chain2.png

traceplot_chain2.png

Anzahl der Proben: 1.000.000, Anzahl der Ketten: 1

#Parameter: n_samples = 1000000, n_chain = 1
parallel_exec(1000000, 1)

Das Ausführen in einem Worker-Prozess dauerte weniger als 33 Sekunden. Wenn Sie es also in zwei Prozessen ausführen, können Sie sehen, dass es 1,7-mal schneller ausgeführt wird: zufrieden:

out


Making 1000000.0 samples per 1 chain. Burn-in rate:0.2
total exec time: 32.683218956
PID:2377, exec time: 24.7304420471, 20:54:07-20:54:31

scatter_chain1.png

traceplot_chain1.png

Referenz

Python-Dokumentation (2.7ja1) 16.6. Multiprocessing - Prozessbasierte Schnittstelle für die parallele Verarbeitung  http://docs.python.jp/2.7/library/multiprocessing.html

Hochleistungs-Python (O'Reilly)  https://www.oreilly.co.jp/books/9784873117409/

Recommended Posts

[Statistik] Multiprocessing von MCMC-Sampling
Üben Sie typische statistische Methoden (1)