[PYTHON] Quantum-inspired machine learning with tensor networks

Introduction

In this post I introduce quantum-inspired machine learning, which has attracted some attention in recent years. To be precise, "quantum-inspired" here means "inspired by the methods used to simulate quantum systems as efficiently as possible on classical computers". Concretely, that means tensor networks.

Background

There is a long history of using tensor networks to simulate quantum systems on classical computers: computing the states of many-body quantum systems efficiently with matrix product states [1], simulating gate-based quantum computation efficiently with a combination of undirected graphs and tensor network contractions [2], and so on.

As these methods show, tensor networks let classical computation handle the very high-dimensional state space of a quantum system, if only approximately. In machine learning, high dimensionality translates into high expressive power.

"Quantum-inspired machine learning" is to apply this feature not only to the simulation of quantum systems but also to classical machine learning problems.

Matrix product state (review)

A matrix product state, abbreviated MPS, is represented as shown in the figure below [1].

[Figure: diagrammatic representation of a matrix product state]

Each black circle is called a "site"; for an $N$-qubit system there are $N$ sites. $\sigma_i$ is called the "physical index"; for qubits it takes the value 0 ($|0\rangle$) or 1 ($|1\rangle$). Each site is a three-dimensional tensor: an ordinary two-dimensional matrix with the physical-index dimension added.

You can also picture each site as holding one two-dimensional matrix for physical index = 0 and another for physical index = 1. If $\sigma_i = 0$ for all $i$, then multiplying together the matrices corresponding to physical index = 0 at every site yields the coefficient of $|00\ldots0\rangle$ in the original quantum state.
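As a minimal sketch of reading off such a coefficient (random tensors and a made-up bond dimension, purely for illustration, not taken from [1]):

```python
import numpy as np

# A toy MPS for n qubits: site i holds a tensor of shape
# (left_bond, physical, right_bond). Boundary bonds have dimension 1.
rng = np.random.default_rng(0)
n, chi = 4, 3  # number of sites, bond dimension
shapes = [(1 if i == 0 else chi, 2, 1 if i == n - 1 else chi) for i in range(n)]
mps = [rng.normal(size=s) for s in shapes]

def amplitude(mps, bits):
    """Coefficient of the basis state |b0 b1 ... b_{n-1}>.

    Fixing physical index b_i at each site leaves an ordinary matrix,
    and the amplitude is the product of those matrices (a 1x1 result)."""
    result = mps[0][:, bits[0], :]          # shape (1, chi)
    for tensor, b in zip(mps[1:], bits[1:]):
        result = result @ tensor[:, b, :]   # chain of matrix products
    return result[0, 0]

print(amplitude(mps, [0, 0, 0, 0]))  # coefficient of |0000>
```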

How to do machine learning

I introduce the approach based on two representative papers, [3] and [4], which treat classification problems by supervised learning. The overall flow is shown in the figure.

[Figure: overall flow of MPS-based classification]

First, encode the input data $x$ into $\sigma_i\ (i \in \{0, \ldots, n-1\})$ in the matrix product state diagram.
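For example, the pixel-wise feature map used in [3] turns each grayscale value into a two-component vector, giving one vector per physical index; a minimal sketch:

```python
import numpy as np

def encode(image):
    """Pixel-wise feature map used in [3]: each grayscale value x in [0, 1]
    becomes the 2-vector [cos(pi/2 * x), sin(pi/2 * x)], so an image with
    n pixels becomes n vectors, one per physical index sigma_i."""
    x = image.reshape(-1)  # flatten to n pixels
    return np.stack([np.cos(np.pi / 2 * x), np.sin(np.pi / 2 * x)], axis=1)

pixels = np.array([[0.0, 0.5], [1.0, 0.25]])
print(encode(pixels))  # shape (4, 2): one 2-vector per pixel
```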

Next, take the tensor contraction between the encoded data and the MPS. The edges between the MPS sites are contracted as well, but as it stands this leaves only a single scalar, which cannot be used for classification. The MPS is therefore given a "label index" in advance so that it outputs one value per class, corresponding to the probability that $x$ belongs to that class; the label index is kept on one existing site or on one newly added site. This way, contracting the $\sigma_i$ and all edges of the MPS leaves a tensor with as many elements as there are classes, and those values can be fed into the loss function.
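Here is a minimal numpy sketch of this contraction (random tensors and hypothetical shapes of my own choosing; the papers use more careful contraction orders):

```python
import numpy as np

rng = np.random.default_rng(1)
n, chi, d, n_classes = 4, 3, 2, 10
label_site = n // 2  # keep the label index on one of the existing sites

# Ordinary sites: (left, physical, right); the label site has one extra leg.
mps = []
for i in range(n):
    left, right = (1 if i == 0 else chi), (1 if i == n - 1 else chi)
    shape = (left, d, n_classes, right) if i == label_site else (left, d, right)
    mps.append(rng.normal(size=shape, scale=0.5))

def class_scores(mps, phi):
    """Contract encoded data phi (n vectors of length d) with the MPS.
    Everything contracts away except the label leg: one value per class."""
    out = None
    for i, (tensor, v) in enumerate(zip(mps, phi)):
        if i == label_site:
            m = np.einsum('p,lpcr->lcr', v, tensor)
            out = m if out is None else np.einsum('al,lcr->acr', out, m)
        else:
            m = np.einsum('p,lpr->lr', v, tensor)
            if out is None:
                out = m
            elif out.ndim == 3:  # the label leg has already appeared
                out = np.einsum('acl,lr->acr', out, m)
            else:
                out = out @ m
    return out.reshape(n_classes)  # boundary bonds have dimension 1

phi = rng.normal(size=(n, d))
print(class_scores(mps, phi))  # feed these values into the loss function
```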

During training, each element of the MPS is updated so that the output of the loss function becomes small. There are two main update policies. One, adopted in [3], is an adaptation of the classic DMRG method: sweep along the chain, repeatedly performing a local optimization with only two adjacent sites as variables. The other, adopted in [4], uses backpropagation to update all elements of the MPS at once.
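As a concrete illustration of the second policy, here is a minimal jax sketch (a toy model with the label leg on the last site; this is not the actual code of [4]): treat every site tensor as a parameter and take one gradient-descent step on a cross-entropy loss.

```python
import jax
import jax.numpy as jnp

# Toy MPS classifier: first site (d, chi), middle sites (chi, d, chi),
# last site carries the label leg (chi, d, n_classes).
def predict(mps, phi):
    out = jnp.einsum('p,pr->r', phi[0], mps[0])[None, :]      # (1, chi)
    for tensor, v in zip(mps[1:-1], phi[1:-1]):
        out = out @ jnp.einsum('p,lpr->lr', v, tensor)        # stays (1, chi)
    return jnp.einsum('al,p,lpc->c', out, phi[-1], mps[-1])   # one score per class

def loss(mps, phi, label):
    return -jax.nn.log_softmax(predict(mps, phi))[label]      # cross-entropy

n, chi, d, n_classes = 6, 4, 2, 10
keys = jax.random.split(jax.random.PRNGKey(0), n)
mps = [0.5 * jax.random.normal(keys[0], (d, chi))]
mps += [0.5 * jax.random.normal(k, (chi, d, chi)) for k in keys[1:-1]]
mps += [0.5 * jax.random.normal(keys[-1], (chi, d, n_classes))]

phi = jnp.ones((n, d)) / jnp.sqrt(d)      # stand-in for one encoded image
grads = jax.grad(loss)(mps, phi, 3)       # gradients for every site at once
mps = [t - 1e-2 * g for t, g in zip(mps, grads)]  # one gradient-descent step
```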

The DMRG-style method has the advantage that redundant bond dimensions can be trimmed dynamically with an SVD during the update. The backpropagation method, on the other hand, is compatible with existing deep-learning and automatic-differentiation frameworks, and likely offers more freedom in defining the network structure and the loss function.
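A sketch of the trimming idea: merge two neighboring sites, split them back with an SVD, and keep only the significant singular values (the cutoff rule here is my own choice, not the one in [3]).

```python
import numpy as np

def split_with_trim(theta, cutoff=1e-10):
    """theta: merged two-site tensor of shape (left, d1, d2, right)."""
    l, d1, d2, r = theta.shape
    u, s, vh = np.linalg.svd(theta.reshape(l * d1, d2 * r), full_matrices=False)
    keep = s > cutoff * s[0]                   # drop negligible singular values
    u, s, vh = u[:, keep], s[keep], vh[keep]
    a = u.reshape(l, d1, -1)                   # new left site
    b = (np.diag(s) @ vh).reshape(-1, d2, r)   # new right site
    return a, b

theta = np.random.default_rng(2).normal(size=(3, 2, 2, 3))
a, b = split_with_trim(theta)
print(a.shape, b.shape)  # the new bond dimension is chosen by the SVD
```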

Implementation

This time I implemented MNIST training with backpropagation, as done in [4]. For the implementation I used the TensorNetwork Python module developed by the authors.

TensorNetwork is, as the name suggests, a library for computing tensor networks. You can choose "tensorflow" or "jax" as the backend. With the tensorflow backend you can train in combination with the TensorFlow framework. Access to TensorFlow's automatic differentiation and built-in functions is convenient, but since most of the model ends up as custom layers, writing code to fit the framework is tedious, and the framework's own overhead is also a concern.
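For reference, the basic usage pattern of TensorNetwork looks roughly like this (a minimal two-node contraction; exact API details may differ between library versions):

```python
import numpy as np
import tensornetwork as tn

tn.set_default_backend("numpy")  # "tensorflow" and "jax" are also available

# Build two nodes, connect their shared edge, and contract it away.
a = tn.Node(np.ones((2, 3)))
b = tn.Node(np.ones((3, 4)))
edge = a[1] ^ b[0]           # join the shared dimension
result = tn.contract(edge)   # equivalent to the matrix product a @ b
print(result.tensor.shape)   # (2, 4)
```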

So this time I used the jax backend. In fact, the follow-up study [5] to [4] also appears to use the jax backend. jax is another Python framework: roughly, numpy with support for automatic differentiation, JIT compilation, and vectorized/parallel computation. It can be a good option when all you want is fast automatic differentiation without the rest of TensorFlow (I think Julia's Flux occupies a similar position, and there is real demand for it).
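A tiny illustration of those three features:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.arange(3.0)
print(jax.grad(f)(x))   # automatic differentiation
print(jax.jit(f)(x))    # JIT-compiled version of f
print(jax.vmap(jnp.dot)(jnp.ones((5, 3)), jnp.ones((5, 3))))  # vectorize over a batch
```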

My implementation is slightly different from [4] in the following points.

  1. 2x2 average pooling of MNIST image data.
  2. The optimizer is plain gradient descent instead of the Adam used in [4] (the learning rate and the number of epochs are adjusted accordingly).

Regarding 1., training at the original image size was difficult. The number of matrices multiplied together equals the number of pixels, and as it grows the output value tends to diverge or collapse to 0 and the gradient tends to vanish, which is a practical obstacle. It could probably be handled with tuning, but this time I compromised. The authors also apply pooling in [5] (perhaps because the network structure and task there are somewhat different).
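The pooling itself is a one-liner in numpy; a sketch, assuming 28x28 inputs:

```python
import numpy as np

def avg_pool_2x2(image):
    """2x2 average pooling: a 28x28 MNIST image becomes 14x14,
    shrinking the MPS from 784 sites to 196."""
    h, w = image.shape
    return image.reshape(h // 2, 2, w // 2, 2).mean(axis=(1, 3))

img = np.arange(28 * 28, dtype=float).reshape(28, 28)
print(avg_pool_2x2(img).shape)  # (14, 14)
```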

Regarding 2., it keeps the sample implementation simple. There is also the history that, when I earlier wrote a version with the tensorflow backend, plain gradient descent was not particularly inferior to the Adam optimizer.

The implementation code is here: https://github.com/ryuNagai/MPS/blob/master/TN_ML/MNIST_ML_jax.ipynb

The learning process looks like this.

[Figure: training-progress plot]

The final result was train accuracy = 0.962 and test accuracy = 0.952. In [4], the train accuracy reached about 0.98 at around 50 epochs, so this falls short of reproducing it. Behind the scenes I spent some time trying to reproduce the result of [4], but it was difficult, so I am settling for this value.

Summary

I implemented quantum-inspired machine learning using tensor networks, which may be about to become popular. Given the many restrictions on the hardware side of current quantum computers, it is a strength of this method that it runs on a classical computer and can therefore handle large problems. Whether it can yield models more useful than conventional machine learning models is, I think, still a question for future research.

In addition, whether machine learning that uses a quantum state space has an advantage over classical machine learning is something that methods like this might let us examine approximately or indirectly, which I think would be valuable.

References

[1] https://arxiv.org/abs/1008.3477
[2] https://arxiv.org/abs/1805.01450
[3] https://papers.nips.cc/paper/6211-supervised-learning-with-tensor-networks
[4] https://arxiv.org/abs/1906.06329
[5] https://arxiv.org/abs/2006.02516
