Deep Kernel Learning with Pyro

Deep Kernel Learning combines deep learning with Gaussian processes and is one form of Bayesian deep learning. The idea is to build a deep kernel by using the features output by a deep neural network (DNN) as the input of the kernel in a Gaussian process. The formula is as follows.

k_{deep}(x,x') = k(f(x),f(x'))
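
To make the formula concrete, here is a minimal sketch (my own addition, with a toy linear feature extractor standing in for the DNN) that evaluates a deep kernel at a single pair of inputs:

import torch
import torch.nn as nn

f = nn.Linear(4, 2)  # toy stand-in for the DNN feature extractor f

def rbf(a, b, lengthscale=1.0):
    # plain RBF kernel k(a, b) = exp(-||a - b||^2 / (2 * lengthscale^2))
    return torch.exp(-((a - b) ** 2).sum() / (2 * lengthscale ** 2))

x, x2 = torch.randn(4), torch.randn(4)
k_deep = rbf(f(x), f(x2))  # k_deep(x, x') = k(f(x), f(x'))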

Since a Gaussian process is equivalent to a neural network with infinitely many units, this looks like attaching such a layer to the end of the DNN. As I found in the previous article (https://qiita.com/takeajioka/items/f24d58d2b13017ab2b18), optimizing the kernel hyperparameters is important when working with Gaussian processes. Deep Kernel Learning optimizes the DNN parameters and the kernel hyperparameters jointly.

Please refer to the following papers for details.

[1] Andrew G. Wilson et al., Deep Kernel Learning, 2015, https://arxiv.org/abs/1511.02222
[2] Andrew G. Wilson et al., Stochastic Variational Deep Kernel Learning, 2016, https://arxiv.org/abs/1611.00336

Try learning MNIST with Deep Kernel Learning

In Pyro, you can easily create a deep kernel with the gp.kernels.Warping class. The official Pyro tutorials include Deep Kernel Learning code, so let's work through it with that code as a reference.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import pyro
import pyro.contrib.gp as gp
import pyro.infer as infer

Since MNIST has a large amount of data, we will train in mini-batches. Set up the dataset.

batch_size = 100
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, ), (0.5, ))])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
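
As a quick check (my own addition), each batch from the loader is a tensor of 100 single-channel 28x28 images together with their labels:

images, labels = next(iter(trainloader))
print(images.shape, labels.shape)  # torch.Size([100, 1, 28, 28]) torch.Size([100])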

First, prepare an ordinary DNN model.

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
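
As a quick check (my own addition), the CNN maps each 28x28 image to a 10-dimensional feature vector, which will become the input of the kernel:

print(CNN()(torch.randn(1, 1, 28, 28)).shape)  # torch.Size([1, 10])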

Combine it with an RBF kernel to create the deep kernel.

rbf = gp.kernels.RBF(input_dim=10, lengthscale=torch.ones(10))
deep_kernel = gp.kernels.Warping(rbf, iwarping_fn=CNN())
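
As a sanity check (my own addition, so treat the exact call as an assumption), evaluating the warped kernel on a batch of images should push them through the CNN and then apply the RBF in feature space, giving an N x N covariance matrix:

X = torch.randn(5, 1, 28, 28)  # dummy batch of 5 MNIST-sized images
K = deep_kernel(X)             # CNN features -> RBF kernel
print(K.shape)                 # expected: torch.Size([5, 5])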

A sparse approximation is used to reduce the computational cost of the Gaussian process. The sparse approximation requires inducing points; here we simply use one batch of training data as the initial inducing points.

Xu, _ = next(iter(trainloader))  # one batch of images as initial inducing points
likelihood = gp.likelihoods.MultiClass(num_classes=10)
gpmodule = gp.models.VariationalSparseGP(
    X=None, y=None, kernel=deep_kernel, Xu=Xu, likelihood=likelihood,
    latent_shape=torch.Size([10]),  # one latent function per class
    num_data=60000)                 # size of the full training set, to scale the ELBO
optimizer = torch.optim.Adam(gpmodule.parameters(), lr=0.01)
elbo = infer.TraceMeanField_ELBO()
loss_fn = elbo.differentiable_loss
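
One point worth verifying (my own addition): because the deep kernel is a submodule of the GP model, gpmodule.parameters() contains both the CNN weights and the kernel hyperparameters, so the single Adam optimizer above realizes the joint optimization mentioned at the beginning:

for name, _ in gpmodule.named_parameters():
    print(name)  # e.g. CNN conv/fc weights and RBF hyperparameters (possibly with _unconstrained suffixes)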

Define the mini-batch training function and a test function.

def train(train_loader, gpmodule, optimizer, loss_fn, epoch):
    total_loss = 0
    for data, target in train_loader:
        gpmodule.set_data(data, target)  # swap in the current mini-batch
        optimizer.zero_grad()
        loss = loss_fn(gpmodule.model, gpmodule.guide)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()  # detach from the graph before accumulating
    return total_loss / len(train_loader)

def test(test_loader, gpmodule):
    correct = 0
    for data, target in test_loader:
        f_loc, f_var = gpmodule(data)             # predictive mean and variance
        pred = gpmodule.likelihood(f_loc, f_var)  # MultiClass turns them into class labels
        correct += pred.eq(target).long().sum().item()
    return 100. * correct / len(test_loader.dataset)
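
To clarify the shapes involved (my own addition; the exact sampling behavior of MultiClass is my reading of Pyro, so treat it as an assumption): with latent_shape=[10], the GP returns one latent function per class, and the likelihood reduces them to one predicted label per image:

data, target = next(iter(testloader))
with torch.no_grad():
    f_loc, f_var = gpmodule(data)
    print(f_loc.shape, f_var.shape)                 # expected: torch.Size([10, 100]) each
    print(gpmodule.likelihood(f_loc, f_var).shape)  # expected: torch.Size([100])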

Now run the training.

import time
losses = []
accuracy = []
epochs = 10
for epoch in range(epochs):
    start_time = time.time()
    loss = train(trainloader, gpmodule, optimizer, loss_fn, epoch)
    losses.append(loss)
    with torch.no_grad():
        acc = test(testloader, gpmodule)
    accuracy.append(acc)
    print("Amount of time spent for epoch {}: {}s\n".format(epoch+1, int(time.time() - start_time)))
print("loss:{:.2f}, accuracy:{}".format(losses[-1],accuracy[-1]))

Each epoch took about 30 seconds to train. The final accuracy was 96.23%. (The official tutorial reports that up to 99.41% is achievable.) Let's display the learning curve.

import matplotlib.pyplot as plt
plt.subplot(2,1,1)
plt.plot(losses)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.subplot(2,1,2)
plt.plot(accuracy)
plt.xlabel("epoch")
plt.ylabel("accuracy")

[Figure: learning curves, loss (top) and accuracy (bottom) per epoch]

Let's display test images side by side with their predicted outputs.

import os
os.makedirs('image', exist_ok=True)  # make sure the output directory exists

data, target = next(iter(testloader))
with torch.no_grad():
    f_loc, f_var = gpmodule(data)
    pred = gpmodule.likelihood(f_loc, f_var)
for i in range(len(data)):
    plt.subplot(1, 2, 1)
    plt.imshow(data[i].reshape(28, 28))
    plt.subplot(1, 2, 2)
    plt.bar(range(10), f_loc[:, i], yerr=f_var[:, i])  # mean with variance as error bars
    ax = plt.gca()
    ax.set_xticks(range(10))
    plt.xlabel("class")
    plt.savefig('image/figure' + str(i) + '.png')
    plt.clf()

[Figures: figure0.png, figure12.png, figure15.png]

The blue bar is the mean and the error bar is the variance. You can see that each of the 10 classes gets an output, and the correct class receives a high value.

Let's also look at images that are difficult to classify. [Figures: figure36.png, figure43.png, figure92.png] For these, the output is high for multiple classes, and considering the error bars there seems to be no significant difference between them.

In closing

Training proceeded just like ordinary deep learning. Being able to output a mean and a variance for each prediction is an advantage that ordinary deep learning does not have. There is also the deep Gaussian process (DGP), which stacks Gaussian processes, and I would like to study that as well.
