[Statistics] Multiprocessing of MCMC sampling

This article speeds up my from-scratch MCMC implementation by running it with multiprocessing. The implementation in my earlier article, "[Statistics] Markov Chain Monte Carlo (MCMC) sampling with animation.", had no notion of chains, so it ran only a single chain. This time I sample with multiple chains and run each chain as a separate process. MCMC chains are independent of one another, so the work can simply be split across processes, and the speedup was easy to get.

environment

⇒ The machine has 2 cores, so parallelization can only help up to 2 processes.

Code for this article

The code is posted on GitHub.  https://github.com/matsuken92/Qiita_Contents/blob/master/multiprocessing/parallel_MCMC.ipynb

Basics of MultiProcessing

First, let's see how multiprocessing behaves on a simple task.

Start by importing the library. We use the Pool class, which manages a pool of worker processes.

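from multiprocessing import Pool
import time          # used for the timing measurements below
import numpy as np   # used to sum the results of the parallel run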

As a stand-in for heavy processing, let's use a function that loops many times. It only adds numbers, but about 100,000,000 iterations take several seconds.

def test_calc(num):
    """Heavy processing"""
    _sum = 0
    for i in xrange(num):
        _sum += i
    return _sum

Let's measure how long it takes to run this function twice in sequence.

#Measure the time when it is executed twice sequentially
start = time.time()
_sum = 0
for _ in xrange(2):
    _sum += test_calc(100000000)
end = time.time()
print _sum
print "time: {}".format(end-start)

It took less than 12 seconds.

out


9999999900000000
time: 11.6906960011
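The printed total can be checked by hand. test_calc(n) sums 0, 1, ..., n - 1, which equals n(n - 1)/2, and the loop above calls it twice. Here is a minimal Python 3 check of that arithmetic (closed_form_sum is our own illustrative name):

```python
def closed_form_sum(n):
    """Sum of 0 + 1 + ... + (n - 1), i.e. what test_calc(n) computes."""
    return n * (n - 1) // 2


n = 100000000
# Two sequential calls of test_calc(n) give twice the closed-form sum
total = 2 * closed_form_sum(n)
print(total)  # 9999999900000000
```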

Next, run the same computation in two parallel processes and measure the time.

#Measure the time when executed in 2 processes
n_worker = 2

pool = Pool(processes=n_worker)

#Argument list to pass to the function executed by the two processes
args = [100000000] * n_worker

start = time.time() #measurement
result = pool.map(test_calc, args)
end = time.time()   #measurement

print np.sum(result)
print "time: {}".format(end-start)
# Stop accepting work and wait for the worker processes to exit
pool.close()
pool.join()

It took a little over 6 seconds, almost half the sequential time. Using 2 processes gave a real speedup :laughing:

out


9999999900000000
time: 6.28346395493

Applying MultiProcessing to MCMC sampling

Now let's apply this to running each chain of MCMC sampling in parallel. As always, the first step is to import the libraries.

import numpy as np
import numpy.random as rd
import scipy.stats as st
import copy, time, os
from datetime import datetime as dt

from multiprocessing import Pool

%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="whitegrid", palette="muted", color_codes=True)

The function P(・) is the kernel of the target posterior distribution, that is, the density without its normalizing constant. Here we use the kernel of a two-dimensional normal distribution.

# Target density without its normalizing constant (the kernel);
# |b| < 1 is required for this to be a valid 2-dim normal distribution
def P(x1, x2, b):
    assert np.abs(b) < 1
    return np.exp(-0.5*(x1**2 - 2*b*x1*x2 + x2**2))
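Writing the exponent as a quadratic form shows which normal distribution this kernel describes. This derivation is ours, added for clarity. The assert on |b| < 1 is exactly the condition for the precision matrix Λ to be positive definite:

```latex
P(x_1, x_2) = \exp\!\left(-\tfrac{1}{2}\,\mathbf{x}^\top \Lambda\, \mathbf{x}\right),
\qquad
\mathbf{x} = \begin{pmatrix} x_1 \\ x_2 \end{pmatrix},
\qquad
\Lambda = \begin{pmatrix} 1 & -b \\ -b & 1 \end{pmatrix}

\Sigma = \Lambda^{-1} = \frac{1}{1 - b^2}\begin{pmatrix} 1 & b \\ b & 1 \end{pmatrix},
\qquad
\det \Lambda = 1 - b^2 > 0 \iff |b| < 1
```

With b = 0.5, each coordinate has variance 1/(1 - 0.25) = 4/3, and the correlation between x1 and x2 is b = 0.5.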

The parameters are defined as globals. b is the parameter of the target kernel P(・), and delta is the standard deviation of the proposal distribution. We also define now(), a helper for time measurement that returns the current time as a string.

# global parameters
b = 0.5
delta = 1

def now():
    """Return the current time as an HH:MM:SS string"""
    return dt.now().strftime('%H:%M:%S')

The following function performs the sampling for one chain, drawing samples until the specified number is reached. It is almost the same as last time, except that it is wrapped in a function and has timing code added. This function is what runs in parallel. Each call samples one chain independently of the others, so no interprocess communication is needed, which makes parallelization easy.

def exec_sampling(n_samples):
    global b, delta
    pid = os.getpid()
    # Forked worker processes inherit the parent's random state, so reseed
    # each process (time + PID) to make every chain different
    rd.seed(int(time.time()) + pid)
    start = time.time()
    start_time = now()
    
    #initial state
    sampling_result = []
    current = np.array([5, 5])
    sampling_result.append(current)
    cnt = 1
    while cnt < n_samples:
        # Candidate = current state + noise from the proposal distribution N(0, delta^2)
        proposal = current + rd.normal(0, delta, size=2)
        # Ratio of the target kernel at the candidate and at the current state
        r = P(proposal[0], proposal[1], b)/P(current[0], current[1], b)

        if r > 1 or r > rd.uniform(0, 1):
            # Accept when r exceeds a Uniform(0, 1) draw, and update the state.
            # Only accepted states are recorded here; the standard Metropolis
            # algorithm also records the current state again on rejection.
            current = copy.copy(proposal)
            sampling_result.append(current)
            cnt += 1

    end = time.time()
    end_time = now()

    #Display of required time for each chain
    print "PID:{}, exec time: {}, {}-{}".format(pid, end-start, start_time, end_time)
    return sampling_result
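The function above appends a state only when the proposal is accepted, so rejected steps are skipped. The textbook random-walk Metropolis algorithm instead records the current state again after a rejection, and that is what guarantees the chain's stationary distribution is the target. For comparison, here is a minimal Python 3 sketch of that variant. It assumes NumPy 1.17 or later for np.random.default_rng, and metropolis is our own name, not from the article.

```python
import numpy as np


def P(x1, x2, b):
    """Kernel of the target distribution (same form as P in the article)."""
    return np.exp(-0.5 * (x1**2 - 2 * b * x1 * x2 + x2**2))


def metropolis(n_samples, b=0.5, delta=1.0, seed=None):
    """Random-walk Metropolis that records a state on every step.

    Unlike exec_sampling, a rejected proposal repeats the current state,
    so exactly n_samples states are returned.
    """
    rng = np.random.default_rng(seed)
    current = np.array([5.0, 5.0])          # same initial state as the article
    samples = np.empty((n_samples, 2))
    samples[0] = current
    for t in range(1, n_samples):
        proposal = current + rng.normal(0.0, delta, size=2)
        r = P(proposal[0], proposal[1], b) / P(current[0], current[1], b)
        if rng.uniform() < r:               # accept with probability min(1, r)
            current = proposal
        samples[t] = current                # record the state whether or not we moved
    return samples


if __name__ == "__main__":
    s = metropolis(50_000, seed=0)[10_000:]     # drop the first 20% as burn-in
    print(s.mean(axis=0))                       # should be close to (0, 0)
    print(np.corrcoef(s[:, 0], s[:, 1])[0, 1])  # should be close to b = 0.5
```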

The following three functions, draw_scatter(), draw_traceplot(), and remove_burn_in_samples(), post-process the sampling results.

def draw_scatter(sample, alpha=0.3):
    """Draw a scatter plot of sampling results"""
    plt.figure(figsize=(9,9))
    plt.scatter(sample[:,0], sample[:,1], alpha=alpha)
    plt.title("Scatter plot of 2-dim normal random variable with MCMC. sample size:{}".format(len(sample)))
    plt.show()
    
def draw_traceplot(sample):
    """Draw a trace plot of sampling results"""
    assert sample.shape[1] == 2
    
    plt.figure(figsize=(15, 6))
    
    for i in range(2):
        plt.subplot(2, 1, i+1)
        plt.xlim(0, len(sample[:,i]))
        plt.plot(sample[:,i], lw=0.05)
        if i == 0:
            order = "1st"
        else:
            order = "2nd"
        plt.title("Traceplot of {} parameter.".format(order))
    
    plt.show()

def remove_burn_in_samples(total_sampling_result, burn_in_rate=0.2):
    """Exclude the burn-in samples (the first burn_in_rate fraction of each chain)."""
    adjust_burn_in_result = []
    for i in xrange(len(total_sampling_result)):
        idx = int(len(total_sampling_result[i])*burn_in_rate)
        adjust_burn_in_result.extend(total_sampling_result[i][idx:])
    return np.array(adjust_burn_in_result)
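Every chain starts from the same initial state (5, 5), so burn-in has to be cut from each chain separately before the chains are pooled. remove_burn_in_samples does exactly that. Here is a small Python 3 check on toy chains of unequal length. The toy data are made up for illustration.

```python
import numpy as np


def remove_burn_in_samples(total_sampling_result, burn_in_rate=0.2):
    """Drop the first burn_in_rate fraction of each chain, then pool the rest."""
    adjust_burn_in_result = []
    for chain in total_sampling_result:
        idx = int(len(chain) * burn_in_rate)
        adjust_burn_in_result.extend(chain[idx:])
    return np.array(adjust_burn_in_result)


# Toy chains: 10 states and 5 states, each state a 2-vector
chains = [
    [np.array([i, i]) for i in range(10)],
    [np.array([-i, -i]) for i in range(5)],
]
pooled = remove_burn_in_samples(chains)
print(pooled.shape)  # (12, 2): 8 states kept from the first chain, 4 from the second
```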

The following function runs the sampling in parallel. As you can see, it is essentially the same as the first simple example.

def parallel_exec(n_samples, n_chain, burn_in_rate=0.2):
    """Execution of parallel processing"""

    #Calculate sample size per chain
    n_samples_per_chain = n_samples / float(n_chain)
    print "Making {} samples per {} chain. Burn-in rate:{}".format(n_samples_per_chain, n_chain, burn_in_rate)

    #Creating a Pool object
    pool = Pool(processes=n_chain)

    #Generate arguments for execution
    n_trials_per_process = [n_samples_per_chain] * n_chain

    #Execution of parallel processing
    start = time.time() #measurement
    total_sampling_result = pool.map(exec_sampling, n_trials_per_process)
    end = time.time()   #measurement

    # All chains have finished; release the worker processes
    pool.close()
    pool.join()

    #Display of total required time
    print "total exec time: {}".format(end-start)

    # Remove burn-in from each chain (using the burn_in_rate argument), then plot
    adjusted_samples = remove_burn_in_samples(total_sampling_result, burn_in_rate)
    draw_scatter(adjusted_samples, alpha=0.01)
    draw_traceplot(adjusted_samples)
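Seeding each worker with time plus PID, as exec_sampling does, works in practice. NumPy 1.17 and later also provide np.random.SeedSequence for deriving independent random streams for parallel workers. Below is a minimal Python 3 sketch of a parallel_exec-style driver built on it, with these assumptions:

- Each chain is the standard Metropolis variant, so rejected steps are recorded.
- The seed 12345 is arbitrary.
- run_chain, log_kernel, and run_chains_parallel are hypothetical names.

The acceptance test is done in log space, which avoids overflow or underflow in the kernel ratio.

```python
import numpy as np
from multiprocessing import Pool

B, DELTA = 0.5, 1.0  # same values as the article's globals b and delta


def log_kernel(x):
    """Log of the article's kernel P(x1, x2, b)."""
    return -0.5 * (x[0]**2 - 2 * B * x[0] * x[1] + x[1]**2)


def run_chain(job):
    """Run one random-walk Metropolis chain; job = (n_samples, seed, chain_id)."""
    n_samples, seed, chain_id = job
    # A distinct spawn_key per chain gives each process its own independent stream
    rng = np.random.default_rng(np.random.SeedSequence(seed, spawn_key=(chain_id,)))
    current = np.array([5.0, 5.0])
    samples = np.empty((n_samples, 2))
    samples[0] = current
    for t in range(1, n_samples):
        proposal = current + rng.normal(0.0, DELTA, size=2)
        # Accept with probability min(1, P(proposal) / P(current)), compared in log space
        if np.log(rng.uniform()) < log_kernel(proposal) - log_kernel(current):
            current = proposal
        samples[t] = current  # rejected steps repeat the current state
    return samples


def run_chains_parallel(n_samples, n_chain, burn_in_rate=0.2, seed=12345):
    """Split n_samples over n_chain processes, drop burn-in per chain, and pool."""
    n_per_chain = n_samples // n_chain
    jobs = [(n_per_chain, seed, i) for i in range(n_chain)]
    with Pool(processes=n_chain) as pool:
        chains = pool.map(run_chain, jobs)
    idx = int(n_per_chain * burn_in_rate)
    return np.concatenate([c[idx:] for c in chains])


if __name__ == "__main__":
    samples = run_chains_parallel(40_000, 2)
    print(samples.shape)  # (32000, 2)
```

Because the seed is fixed, rerunning with the same seed reproduces the same samples. Seeding from time plus PID does not give that.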

Now let's see the actual effect. With the total number of samples fixed at 1,000,000, we measure the time with 2 chains and with 1 chain.

Sampling number: 1,000,000, chain number: 2

#parameter: n_samples = 1000000, n_chain = 2
parallel_exec(1000000, 2)

Sampling is completed in less than 19 seconds in total, about 12 seconds per worker process.

out


Making 500000.0 samples per 2 chain. Burn-in rate:0.2
total exec time: 18.6980280876
PID:2374, exec time: 12.0037689209, 20:53:41-20:53:53
PID:2373, exec time: 11.9927477837, 20:53:41-20:53:53

scatter_chain2.png

traceplot_chain2.png

Sampling number: 1,000,000, chain number: 1

#parameter: n_samples = 1000000, n_chain = 1
parallel_exec(1000000, 1)

With a single worker process it took just under 33 seconds, so running two processes is about 1.7 times faster :satisfied:

out


Making 1000000.0 samples per 1 chain. Burn-in rate:0.2
total exec time: 32.683218956
PID:2377, exec time: 24.7304420471, 20:54:07-20:54:31

scatter_chain1.png

traceplot_chain1.png
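The speedup figures quoted above can be reproduced from the printed timings. Speedup is the 1-process time divided by the 2-process time. Efficiency is the speedup divided by the number of processes, and it shows how close each run gets to the ideal 2x on this 2-core machine. Here is a small Python 3 calculation, with the timings copied from the outputs above:

```python
def speedup(t_one, t_two):
    """Speedup of the 2-process run relative to the 1-process run."""
    return t_one / t_two


n_processes = 2
# Wall-clock times (seconds) copied from the outputs in this article
timings = {
    "test_calc": (11.6906960011, 6.28346395493),   # sequential vs Pool with 2 workers
    "MCMC total": (32.683218956, 18.6980280876),   # 1 chain vs 2 chains
}

for label, (t_one, t_two) in timings.items():
    s = speedup(t_one, t_two)
    # Efficiency = speedup / number of processes; 100% would be a perfect 2x
    print("{}: speedup {:.2f}, efficiency {:.0%}".format(label, s, s / n_processes))
# test_calc: speedup 1.86, efficiency 93%
# MCMC total: speedup 1.75, efficiency 87%
```

The MCMC run falls further short of 2x than the simple loop. One plausible contributor is the time spent sending each worker's list of samples back to the parent process, which would show up as the gap between the total time and the per-worker times. This is our guess and was not measured in the article.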

