PRML Chapter 8 Sum-Product Algorithm Python Implementation

Chapter 8 of PRML covers graphical models. A graphical model expresses the relationships among random variables and model parameters as a graph, and models in PRML such as linear regression can generally be represented this way. I implemented the sum-product algorithm, which can be applied to models that can be represented as graphical models, and used it to remove noise from an image. The method appears to be a general form of the algorithm applied to the hidden Markov model in Chapter 13 of PRML.

Image noise removal

Consider a binary image in which every pixel takes the value 1 (white) or -1 (black). The task is to take an image in which the values of some pixels have been inverted, like the noisy image below, and restore the noise-free image below.
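To make the setup concrete, here is a minimal sketch of how a noisy image of this kind could be produced. The helper `flip_noise` and its noise model (each pixel inverted independently with probability `rate`) are assumptions for illustration; the post does not say how its noisy image was made.

```python
import numpy as np

def flip_noise(img, rate, rng):
    """Return a copy of a -1/+1 image in which each pixel is inverted with probability `rate`.

    Hypothetical helper: independent flips are an assumed noise model, not the post's.
    """
    flip = rng.random(img.shape) < rate  # True where the pixel gets inverted
    return np.where(flip, -img, img)

rng = np.random.default_rng(0)
clean = np.ones((64, 64), dtype=int)   # white background
clean[16:48, 16:48] = -1               # black square in the middle
noisy = flip_noise(clean, 0.1, rng)
print("fraction of inverted pixels:", np.mean(noisy != clean))
```

With `rate=0.1`, roughly one pixel in ten ends up with the wrong value, which is the kind of input the denoiser has to undo.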

Markov random field

As the model for removing image noise, we use a Markov random field (also called a Markov network or an undirected graphical model), which is one kind of graphical model.

The figure below is a schematic of the Markov random field used here. The $ y $ nodes are random variables representing the pixel values of the observed noisy image, and the $ x $ nodes are random variables representing the pixel values of the noise-free image we want to restore. In a noise-free image, adjacent pixels are strongly correlated, so in this model adjacent $ x $ nodes are connected by edges. In addition, because only a relatively small fraction of the pixels is corrupted, each pixel of the noise-free image is strongly correlated with the corresponding pixel of the noisy image, so each $ x_i $ is also connected to its $ y_i $. For example, if a pixel in the noise-free image has the value 1, its adjacent pixels are likely to be 1 as well, and the corresponding pixel in the noisy image is also likely to be 1. In other words, nodes connected by an edge tend to take the same pixel value.
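The graph structure described above can be written down as an explicit edge list. This is an illustrative sketch (`grid_edges` is a hypothetical name), assuming a 4-neighbor grid in which each pixel is linked to its right and lower neighbor, so every adjacent pair of $ x $ nodes appears exactly once. The link between each $ x_i $ and its own $ y_i $ is implicit and not listed.

```python
import itertools

def grid_edges(height, width):
    """List the x-x edges of a height-by-width 4-neighbor grid, each pair once."""
    edges = []
    for r, c in itertools.product(range(height), range(width)):
        if c + 1 < width:
            edges.append(((r, c), (r, c + 1)))   # horizontal neighbor
        if r + 1 < height:
            edges.append(((r, c), (r + 1, c)))   # vertical neighbor
    return edges

# 3 rows of 3 horizontal edges plus 2 rows of 4 vertical edges
print(len(grid_edges(3, 4)))  # 17
```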

Expressed as a formula, the joint distribution is

p({\bf x},{\bf y}) = {1\over Z}\exp\left\{-E({\bf x},{\bf y})\right\}

where, with $ i $ and $ j $ indexing the nodes,

E({\bf x},{\bf y}) = -\alpha\sum_i x_i y_i - \beta\sum_{\{i,j\}\,\text{adjacent}} x_i x_j

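and $ Z $ is a normalization constant. With $ \alpha, \beta > 0 $, configurations in which connected nodes take the same value have lower energy and therefore higher probability.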
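As a quick check of the formula, the energy can be evaluated directly with NumPy. This sketch assumes $ {\bf x} $ and $ {\bf y} $ are stored as 2-D arrays of $ \pm 1 $ and that "adjacent" means horizontal or vertical neighbors; the function name `energy` is illustrative.

```python
import numpy as np

def energy(x, y, alpha, beta):
    """E(x, y) = -alpha * sum_i x_i y_i - beta * sum_{adjacent i, j} x_i x_j."""
    data_term = np.sum(x * y)
    # horizontal pairs plus vertical pairs, each adjacent pair counted once
    smooth_term = np.sum(x[:, :-1] * x[:, 1:]) + np.sum(x[:-1, :] * x[1:, :])
    return -alpha * data_term - beta * smooth_term

y = np.ones((3, 3), dtype=int)
y[1, 1] = -1                                   # one black pixel in a white image
print(energy(y, y, 10., 5.))                   # keep the observation: -110.0
print(energy(np.ones_like(y), y, 10., 5.))     # flip it to white: -130.0
```

With $ \alpha = 10 $ and $ \beta = 5 $ (the values used in the code below), the all-white image has the lower energy, so the model prefers treating the isolated black pixel as noise.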

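Once a noisy image is given, each $ y_i $ is a concrete observed value, and substituting these values gives the conditional distribution $ p({\bf x}|{\bf y}) $. The $ {\bf x} $ that maximizes $ p({\bf x}|{\bf y}) $ is the restored image. However, $ {\bf x} $ can take $ 2^N $ states, where $ N $ is the number of nodes (pixels), so trying every pattern of $ {\bf x} $ is not realistic. Iterated conditional modes (ICM) is another algorithm for estimating the restored image, but here we use the sum-product algorithm, which seems to give a better solution. Note that the sum-product algorithm computes the marginal distribution $ p(x_i|{\bf y}) $ of each pixel rather than maximizing the joint distribution, so each pixel is set to its most probable value under its own marginal.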
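To see why enumeration breaks down, the sketch below computes the exact marginals $ p(x_i = 1|{\bf y}) $ by summing over all $ 2^N $ configurations. It is only an illustration for tiny images (the helper `exact_marginals` is hypothetical): already at 5 by 5 pixels there are $ 2^{25} \approx 3.4 \times 10^7 $ configurations.

```python
import itertools
import numpy as np

def exact_marginals(y, alpha, beta):
    """Exact p(x_i = +1 | y) for every pixel, by enumerating all 2**N configurations of x.

    Illustrative only: memory and time grow as 2**N, so keep N = y.size tiny.
    """
    h, w = y.shape
    # all 2**N configurations, shape (2**N, h, w)
    states = np.array(list(itertools.product((-1, 1), repeat=h * w))).reshape(-1, h, w)
    data = np.sum(states * y, axis=(1, 2))
    smooth = (np.sum(states[:, :, :-1] * states[:, :, 1:], axis=(1, 2))    # horizontal pairs
              + np.sum(states[:, :-1, :] * states[:, 1:, :], axis=(1, 2)))  # vertical pairs
    energies = -alpha * data - beta * smooth
    # shift by the minimum before exponentiating; the shift cancels on normalization
    p = np.exp(-(energies - energies.min()))
    p /= p.sum()
    # p(x_i = +1 | y) = total probability of the configurations with x_i = +1
    return np.tensordot(p, (states == 1).astype(float), axes=1)

y = np.array([[1, 1], [1, -1]])
print(exact_marginals(y, alpha=1.0, beta=1.0))
```

A handy sanity check: with $ \beta = 0 $ the pixels decouple and each marginal reduces to $ e^{\alpha y_i} / (e^{\alpha} + e^{-\alpha}) $.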

Sum-product algorithm

The sum-product algorithm computes, for each random-variable node of a graphical model, the probability of each value that node can take. It does so by passing messages, which are something like unnormalized probabilities, between nodes connected by an edge.

Message from $ y_i $ to $ x_i $

m_{y_i\to x_i} = \exp(\alpha x_i y_i)

Note that $ y_i $ here is not a random variable but the observed pixel value.
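In code, a message over a binary variable is just a length-2 array. The sketch below evaluates $ m_{y_i\to x_i} $ for both states, using index 0 for $ x_i = -1 $ and index 1 for $ x_i = +1 $ (the same ordering the `likelihood` method in the code below uses); the name `observation_message` is illustrative.

```python
import numpy as np

STATES = np.array([-1, 1])   # index 0 <-> x_i = -1, index 1 <-> x_i = +1

def observation_message(y_i, alpha):
    """m_{y_i -> x_i}(x_i) = exp(alpha * x_i * y_i), normalized to sum to 1."""
    m = np.exp(alpha * STATES * y_i)
    return m / m.sum()

print(observation_message(1, 1.0))   # favors x_i = +1
print(observation_message(-1, 1.0))  # favors x_i = -1
```

Normalizing does not change the final result, because messages only matter up to a constant factor; it simply keeps the numbers in a safe floating-point range.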

Message from $ x_j $ to $ x_i $

m_{x_j\to x_i} = \sum_{x_j}\exp(\beta x_ix_j)f(x_j)

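where $ f(x_j) $ is the product of the messages that node $ j $ receives from all of its neighbors except node $ i $ (including the message from $ y_j $). Computing $ f(x_j) $ therefore requires the messages sent to node $ j $ by other nodes. Because the graphical model here contains loops, the messages depend on each other cyclically, so we first initialize every message to some value and then repeatedly send the two kinds of messages above to estimate $ {\bf x} $. This iterative scheme is known as loopy belief propagation. After the messages have been exchanged, the marginal probability of each $ x_i $ is proportional to the product of all the messages it receives.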
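On a graph without loops (a tree), one sweep of messages in each direction already gives the exact marginals, which makes the two message formulas easy to check. The sketch below applies them to a short chain $ x_0 - x_1 - x_2 $, each $ x_k $ linked to its observation $ y_k $, and compares the result with brute-force enumeration. The chain, the parameter values, and the function names are illustrative assumptions; on the loopy image grid the same updates are only iterated toward an approximate solution.

```python
import itertools
import numpy as np

STATES = np.array([-1, 1])  # index 0 <-> x = -1, index 1 <-> x = +1

def pairwise_message(f_j, beta):
    """m_{x_j -> x_i}(x_i) = sum_{x_j} exp(beta * x_i * x_j) f(x_j), normalized to sum to 1."""
    psi = np.exp(beta * np.outer(STATES, STATES))  # psi[a, b] = exp(beta * x_i[a] * x_j[b])
    m = psi @ f_j                                 # sum over x_j (column index b)
    return m / m.sum()

def chain_marginals(y, alpha, beta):
    """Sum-product on a chain x_0 - x_1 - ... - x_{n-1}, each x_k also linked to y_k."""
    n = len(y)
    obs = [np.exp(alpha * STATES * y_k) for y_k in y]  # m_{y_k -> x_k}
    fwd = [np.ones(2) for _ in range(n)]   # fwd[k]: message into x_k from x_{k-1}
    bwd = [np.ones(2) for _ in range(n)]   # bwd[k]: message into x_k from x_{k+1}
    for k in range(1, n):                  # left-to-right sweep
        # f(x_{k-1}) = observation message times the message from x_{k-2}
        fwd[k] = pairwise_message(obs[k - 1] * fwd[k - 1], beta)
    for k in range(n - 2, -1, -1):         # right-to-left sweep
        bwd[k] = pairwise_message(obs[k + 1] * bwd[k + 1], beta)
    # marginal of x_k: product of all incoming messages, normalized
    marg = np.array([obs[k] * fwd[k] * bwd[k] for k in range(n)])
    return marg / marg.sum(axis=1, keepdims=True)

def brute_force_chain(y, alpha, beta):
    """Exact marginals of the same chain by enumerating all 2**n configurations."""
    n = len(y)
    probs = np.zeros((n, 2))
    for xs in itertools.product((-1, 1), repeat=n):
        e = (-alpha * sum(x * y_k for x, y_k in zip(xs, y))
             - beta * sum(xs[k] * xs[k + 1] for k in range(n - 1)))
        for k, x_k in enumerate(xs):
            probs[k, (x_k + 1) // 2] += np.exp(-e)
    return probs / probs.sum(axis=1, keepdims=True)

y = [1, -1, 1]
print(chain_marginals(y, 1.0, 0.8))
print(np.allclose(chain_marginals(y, 1.0, 0.8), brute_force_chain(y, 1.0, 0.8)))  # True: exact on a tree
```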


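import itertools
from functools import reduce  # reduce is no longer a builtin in Python 3

import numpy as np
from PIL import Image  # scipy.misc.imread/imsave were removed from SciPy, so Pillow is used instead

ORIGINAL_IMAGE = "qiita_binary.png"
NOISY_IMAGE = "qiita_noise.png"
DENOISED_IMAGE = "qiita_denoised.png"

class Node:

    def __init__(self):
        self.neighbors = []
        # incoming messages keyed by sender; each is an array of shape (2,)
        # where index 0 <-> x = -1 and index 1 <-> x = +1.
        # The message from the observed y node is stored under the key `self`.
        self.messages = {}
        self.prob = None

        # coupling strengths from E(x, y): alpha for x_i-y_i, beta for x_i-x_j
        self.alpha = 10.
        self.beta = 5.

    def add_neighbor(self, node):
        """
        add neighboring node

        Parameters
        ----------
        node : Node
            neighboring node
        """
        self.neighbors.append(node)

    def get_neighbors(self):
        """
        get neighbor nodes

        Returns
        -------
        neighbors : list
            list containing neighbor nodes
        """
        return self.neighbors

    def init_messeges(self):
        """
        initialize messages from neighbor nodes
        """
        for neighbor in self.neighbors:
            self.messages[neighbor] = np.ones(shape=(2,)) * 0.5

    def marginalize(self):
        """
        calculate probability
        """
        # p(x_i | y) is proportional to the product of all incoming messages
        prob = reduce(lambda x, y: x * y, self.messages.values())
        self.prob = prob / prob.sum()

    def send_message_to(self, node):
        """
        calculate message to be sent to the node and store it in node.messages

        Parameters
        ----------
        node : Node
            node to send computed message

        Returns
        -------
        message : np.ndarray (2,)
            message to be sent to the node
        """
        # f(x_j): product of all incoming messages except the one from `node`
        message_from_neighbors = reduce(lambda x, y: x * y, self.messages.values()) / self.messages[node]
        # F[a, b] = exp(beta * x_a * x_b): exp(beta) on the diagonal, exp(-beta) elsewhere
        F = np.exp(self.beta * (2 * np.eye(2) - 1))
        # sum over this node's states x_j
        message = F.dot(message_from_neighbors)
        node.messages[self] = message / message.sum()
        return node.messages[self]

    def likelihood(self, value):
        """
        calculate likelihood via observation, which is the message to this node

        Parameters
        ----------
        value : int
            observed value -1 or 1
        """
        assert (value == -1) or (value == 1), "{} is not 1 or -1".format(value)
        # m_{y_i -> x_i}(x_i) = exp(alpha * x_i * y_i) evaluated at x_i = -1, +1
        message = np.exp(self.alpha * np.array([-value, value]))
        self.messages[self] = message / message.sum()

class MarkovRandomField:

    def __init__(self):
        self.nodes = {}

    def add_node(self, location):
        """
        add a new node at the location

        Parameters
        ----------
        location : tuple
            key to access the node
        """
        self.nodes[location] = Node()

    def get_node(self, location):
        """
        get the node at the location

        Parameters
        ----------
        location : tuple
            key to access the corresponding node

        Returns
        -------
        node : Node
            the node at the location
        """
        return self.nodes[location]

    def add_edge(self, key1, key2):
        """
        add edge between nodes corresponding to key1 and key2

        Parameters
        ----------
        key1 : tuple
            The key to access one of the nodes
        key2 : tuple
            The key to access the other node.
        """
        # raises KeyError (before any change) if either key has no node
        node1, node2 = self.nodes[key1], self.nodes[key2]
        node1.add_neighbor(node2)
        node2.add_neighbor(node1)

    def sum_product_algorithm(self, iter_max=10):
        """
        Perform sum product algorithm
        1. initialize messages
        2. send messages from each node to neighboring nodes
        3. calculate probabilities using the messages

        Parameters
        ----------
        iter_max : int
            number of maximum iteration
        """
        # the observation message set by likelihood() is kept
        for node in self.nodes.values():
            node.init_messeges()

        for i in range(iter_max):
            print(i)
            for node in self.nodes.values():
                for neighbor in node.get_neighbors():
                    node.send_message_to(neighbor)

        for node in self.nodes.values():
            node.marginalize()

def denoise(img, n_iter=20):
    mrf = MarkovRandomField()
    len_x, len_y = img.shape
    X = range(len_x)
    Y = range(len_y)

    # one x node per pixel
    for location in itertools.product(X, Y):
        mrf.add_node(location)

    # connect each pixel to its neighbor at offset (dx, dy) in {(0, 1), (1, 0)}
    for x, y in itertools.product(X, Y):
        for dx, dy in itertools.permutations(range(2), 2):
            try:
                mrf.add_edge((x, y), (x + dx, y + dy))
            except KeyError:
                pass  # the neighbor lies outside the image

    # message from each observed y node to its x node
    for location in itertools.product(X, Y):
        node = mrf.get_node(location)
        node.likelihood(img[location])

    mrf.sum_product_algorithm(n_iter)

    denoised = np.zeros_like(img)
    for location in itertools.product(X, Y):
        node = mrf.get_node(location)
        # more probable state: index 0 -> -1, index 1 -> +1
        denoised[location] = 2 * np.argmax(node.prob) - 1

    return denoised

def main():
    # read as 8-bit grayscale and map pixel values 0 / 255 to -1 / +1
    img_original = 2 * (np.asarray(Image.open(ORIGINAL_IMAGE).convert("L")) // 255).astype(int) - 1
    img_noise = 2 * (np.asarray(Image.open(NOISY_IMAGE).convert("L")) // 255).astype(int) - 1

    img_denoised = denoise(img_noise, 10)

    print("error rate before")
    print(np.mean(img_original != img_noise))
    print("error rate after")
    print(np.mean(img_denoised != img_original))
    # map -1 / +1 back to 0 / 255 and save as an 8-bit image
    Image.fromarray(((img_denoised + 1) // 2 * 255).astype(np.uint8)).save(DENOISED_IMAGE)

if __name__ == '__main__':
    main()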


[Figure: image restored by removing noise]

Terminal output:

error rate before
error rate after

Compared with the original image, the error rate is lower after denoising than before.

Finally

If the goal is simply to remove noise, the graph cut method seems to be the most accurate.
