[PYTHON] I implemented Attention Seq2Seq with PyTorch

Introduction

Following my previous implementation of Seq2Seq, this time I used PyTorch to implement Attention Seq2Seq, which adds Attention to Seq2Seq.

Even a beginner like me couldn't find much source code implementing Attention in PyTorch. There is the PyTorch Attention Tutorial, but it doesn't seem to use mini-batches(?), and it felt customized for its particular task, so I wanted a simpler, plain(?) Attention and tried implementing it myself. I hope this provides some helpful information for anyone having trouble implementing Attention.

For the mechanism of Attention, once again [Deep Learning from scratch ❷ ― Natural language processing](https://www.amazon.co.jp/%E3%82%BC%E3%83%AD%E3%81%8B%E3%82%89%E4%BD%9C%E3%82%8BDeep-Learning-%E2%80%95%E8%87%AA%E7%84%B6%E8%A8%80%E8%AA%9E%E5%87%A6%E7%90%86%E7%B7%A8-%E6%96%8E%E8%97%A4-%E5%BA%B7%E6%AF%85/dp/4873118360/ref=sr_1_2?__mk_ja_JP=%E3%82%AB%E3%82%BF%E3%82%AB%E3%83%8A&keywords=%E3%82%BC%E3%83%AD%E3%81%8B%E3%82%89%E4%BD%9C%E3%82%8B&qid=1568304570&s=gateway&sr=8-2) was overwhelmingly easy to understand.

The implementation I introduce here is (or at least should be) essentially a scratch implementation of the one in Zero Saku 2, the common nickname of the book above, so if this article is hard to understand, I strongly recommend reading Zero Saku 2.

Supplement

I think there are various types of Attention, such as soft Attention and hard Attention, but the Attention here refers to the (soft) Attention in Deep Learning from scratch ❷ ― Natural language processing.

Attention mechanism

Challenges of Seq2Seq

Because the Encoder converts the input into a fixed-length vector regardless of the length of the input sequence, Seq2Seq has trouble capturing the features of long sequences. To solve this, Attention provides a mechanism that lets the model take the length of the Encoder's input sequence into account.

Super rough explanation

Explained very roughly, Attention performs the following operations:

  1. **Pass all the hidden-state vectors on the Encoder side to every time step on the Decoder side**
  2. **At each time step on the Decoder side, select the vector most worth paying attention to from the Encoder hidden-state vectors passed in, and add it to the features**

In 1., the number of Encoder hidden-state vectors depends on the length of the Encoder's input sequence, so the representation takes the sequence length into account. In 2., a hard "selection" cannot be differentiated, so instead each element is weighted stochastically by $softmax$ according to how much attention should be paid to it.
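To make the "soft selection" in 2. concrete, here is a minimal NumPy sketch (my own illustration, not from the book; the vectors and scores are made-up toy values). A hard selection with argmax picks one vector but has no useful gradient, while weighting every vector by $softmax$ of the scores is differentiable, and it approaches the hard selection as the scores get sharper.

```python
import numpy as np

def softmax(x):
    # Subtracting the max avoids overflow and does not change the result
    e = np.exp(x - np.max(x))
    return e / e.sum()

# Toy values (hypothetical): three Encoder vectors and their similarity scores
hs = np.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])
scores = np.array([0.1, 2.0, 0.3])

# Hard selection: take only the highest-scoring vector (not differentiable)
hard = hs[np.argmax(scores)]

# Soft selection: weight every vector by softmax(scores) and add them up
weights = softmax(scores)
soft = weights @ hs

# Sharpening the scores pushes the soft result toward the hard one
sharp = softmax(scores * 100) @ hs

print("weights:", weights)
print("soft:", soft, "hard:", hard, "sharp:", sharp)
```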

The Attention processing flow in a little more detail, with figures

For simplicity, the figure below deals with a case where the Encoder side has three inputs w1, w2, w3 and the Decoder side has two, w'1 and w'2.

① Let the hidden states at each Encoder time step be $h_1, h_2, \cdots, h_n$, and pass $hs = [h_1, h_2, \cdots, h_n]$ to every time step on the Decoder side.

② Compute the inner product between the Decoder's hidden state at each time step (here, $d_i$) and each vector $h_1, h_2, \cdots$ of $hs$. This measures how similar the Decoder vector is to each vector in $hs$. (The inner product is written as $(\cdot, \cdot)$.)

③ Convert the inner products computed in ② into probabilities with $softmax$ (these are called the attention weights).

④ Weight each element of $hs$ by its attention weight and add them all up into a single vector (this is called the context vector).

⑤ Concatenate the context vector and $d_i$ into a single vector.
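In formulas, steps ② to ④ for the Decoder's $i$-th time step are $s_{ij} = (h_j, d_i)$, $a_{ij} = \exp(s_{ij}) / \sum_{k=1}^{n} \exp(s_{ik})$ and $c_i = \sum_{j=1}^{n} a_{ij} h_j$, and step ⑤ concatenates $d_i$ and $c_i$. Below is a minimal NumPy sketch of one Decoder time step under these formulas, using random toy vectors (hypothetical values; the function name `attention_step` is mine). It puts $d_i$ first and the context vector second, which is the order the PyTorch implementation below uses in `torch.cat([output, c], dim=2)`.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

def attention_step(hs, d):
    """One Decoder time step of dot-product Attention (steps 2 to 5).

    hs: (n, H) Encoder hidden states h_1..h_n, all passed to the Decoder (step 1)
    d:  (H,)   Decoder hidden state d_i
    """
    scores = hs @ d                          # step 2: inner products (h_j, d_i), shape (n,)
    weights = softmax(scores)                # step 3: attention weights, sum to 1
    context = weights @ hs                   # step 4: context vector, shape (H,)
    combined = np.concatenate([d, context])  # step 5: shape (2H,)
    return combined, weights

rng = np.random.default_rng(0)
hs = rng.normal(size=(3, 4))   # Encoder states for w1, w2, w3 (hidden size 4)
d = rng.normal(size=4)         # Decoder state for one step, e.g. w'1
combined, weights = attention_step(hs, d)
print(weights, combined.shape)
```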

Implementation

- Add processes ① to ⑤ explained above to the Decoder side and you're done.
- As in Zero Saku 2, this deals with the date format conversion problem (because visualizing the attention weights makes it easy to confirm that the model has learned properly).
- The following is implemented on Google Colab.
- Since I explain this by adding Attention processing to the Seq2Seq implementation explained last time, most of the previous source is reused. Please also refer to the previous source code: I implemented Seq2Seq with PyTorch

Problem setting

Let's use Attention Seq2Seq to solve the task of converting dates written in various formats, such as the following, into the YYYY-MM-DD format.

| Before conversion | After conversion |
|---|---|
| November 30, 1995 | 1995-11-30 |
| Monday, July 9, 2001 | 2001-07-09 |
| 1/23/01 | 2001-01-23 |
| WEDNESDAY, AUGUST 1, 2001 | 2001-08-01 |
| sep 7, 1981 | 1981-09-07 |

Data preparation

We borrow the data from the GitHub repository of Zero Saku 2: https://github.com/oreilly-japan/deep-learning-from-scratch-2/tree/master/dataset

Put this file (date.txt) on Google Drive and split each line into the parts before and after conversion as follows.

from sklearn.model_selection import train_test_split
import random
from sklearn.utils import shuffle

# Mount Google Drive in advance and store date.txt at the following location
file_path = "drive/My Drive/Colab Notebooks/date.txt"

input_date = [] # Date data before conversion
output_date = [] # Date data after conversion

# Read date.txt line by line and split each line into the input (before "_") and the output (after "_")
with open(file_path, "r") as f:
  date_list = f.readlines()
  for date in date_list:
    date = date.rstrip("\n")  # strip the newline (safe even if the last line has none)
    input_date.append(date.split("_")[0])
    output_date.append("_" + date.split("_")[1])  # "_" marks the start of generation

# Get the lengths of the input and output sequences
# Since they all have the same length, take len of the 0th element
input_len = len(input_date[0]) # 29
output_len = len(output_date[0]) # 11 ("_" + the 10 characters of YYYY-MM-DD)

# Assign an ID to every character that appears in date.txt
char2id = {}
for input_chars, output_chars in zip(input_date, output_date):
  for c in input_chars:
    if c not in char2id:
      char2id[c] = len(char2id)
  for c in output_chars:
    if c not in char2id:
      char2id[c] = len(char2id)

input_data = [] # Pre-conversion date data converted to IDs
output_data = [] # Converted date data converted to IDs
for input_chars, output_chars in zip(input_date, output_date):
  input_data.append([char2id[c] for c in input_chars])
  output_data.append([char2id[c] for c in output_chars])

# Split into train and test at 7:3
train_x, test_x, train_y, test_y = train_test_split(input_data, output_data, train_size=0.7)

#Define a function to batch data
def train2batch(input_data, output_data, batch_size=100):
    input_batch = []
    output_batch = []
    input_shuffle, output_shuffle = shuffle(input_data, output_data)
    for i in range(0, len(input_data), batch_size):
      input_batch.append(input_shuffle[i:i+batch_size])
      output_batch.append(output_shuffle[i:i+batch_size])
    return input_batch, output_batch
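Because `train2batch` slices the shuffled data in steps of `batch_size`, the last mini-batch is smaller whenever the number of examples is not a multiple of `batch_size`, while the prediction code later stacks all test mini-batches into one tensor with `torch.tensor(test_input_batch)`, which needs every mini-batch to have the same size. The sketch below runs the same parsing, ID assignment and slicing in plain Python on a few hypothetical lines in the `input_output` layout of date.txt, with a hypothetical helper `batch_sizes`, so these properties can be checked without the real file.

```python
# Hypothetical sample lines in the "input_output" layout of date.txt
lines = ["sep 7, 1981   _1981-09-07\n",
         "1/23/01       _2001-01-23\n",
         "Jul 9, 2001   _2001-07-09\n"]

input_date, output_date = [], []
for line in lines:
    src, tgt = line.rstrip("\n").split("_")
    input_date.append(src)
    output_date.append("_" + tgt)        # "_" marks the start of generation

# Same ID assignment order as above: input characters, then output characters, pair by pair
char2id = {}
for src, tgt in zip(input_date, output_date):
    for c in src + tgt:
        if c not in char2id:
            char2id[c] = len(char2id)
id2char = {i: c for c, i in char2id.items()}

# Encode to IDs and decode back to check the round trip
encoded = [[char2id[c] for c in s] for s in output_date]
decoded = ["".join(id2char[i] for i in ids) for ids in encoded]

def batch_sizes(n_examples, batch_size):
    # Mirrors train2batch: slices [i:i+batch_size] for i in range(0, n, batch_size)
    return [len(range(n_examples)[i:i + batch_size])
            for i in range(0, n_examples, batch_size)]

print(decoded == output_date)   # True
print(batch_sizes(250, 100))    # [100, 100, 50]: the last batch is short
```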

Encoder

- The Encoder side is almost the same as the previously implemented Seq2Seq.
- To have a little fun, I changed the LSTM to a GRU.
- Since the GRU hidden state at every time step is used for Attention on the Decoder side, the first return value of the GRU ($hs$) is also received.

import torch
import torch.nn as nn
import torch.optim as optim

#Various parameters, etc.
embedding_dim = 200
hidden_dim = 128
BATCH_NUM = 100
vocab_size = len(char2id)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#Encoder class
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=char2id[" "])
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
    
    def forward(self, sequence):
        embedding = self.word_embeddings(sequence)
        # hs holds the GRU hidden state at every time step of the sequence;
        # these are the elements Attention uses
        hs, h = self.gru(embedding)
        return hs, h

Decoder

- As on the Encoder side, the LSTM is changed to a GRU compared to last time.
- If you implement it while writing down on paper what each axis of each layer's tensor means, it helps organize your thoughts.
- I also noted the size of each tensor in the Attention layer to help you understand it.

#Attention Decoder class
class AttentionDecoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, batch_size):
        super(AttentionDecoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=char2id[" "])
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        # hidden_dim*2 because the GRU output at each time step is concatenated (torch.cat) with the context vector computed by the Attention layer, which doubles the length
        self.hidden2linear = nn.Linear(hidden_dim * 2, vocab_size)
        # Softmax over dim=1 (the Encoder time axis) so each column becomes a probability distribution
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, sequence, hs, h):
        embedding = self.word_embeddings(sequence)
        output, state = self.gru(embedding, h)

        # Attention layer
        # hs.size() = ([100, 29, 128])
        # output.size() = ([100, 10, 128])

        # To multiply the Encoder output (hs) and the Decoder output (output) per batch element with bmm, transpose the Decoder output while keeping the batch axis fixed
        t_output = torch.transpose(output, 1, 2) # t_output.size() = ([100, 128, 10])

        # Batched matrix product with bmm
        s = torch.bmm(hs, t_output) # s.size() = ([100, 29, 10])

        # Take softmax along the column direction (dim=1) to turn the scores into probabilities
        # This value is also returned because it is used later to visualize the attention
        attention_weight = self.softmax(s) # attention_weight.size() = ([100, 29, 10])

        # Prepare a container to collect the context vectors
        # (hs.size(0) rather than self.batch_size, so a smaller final mini-batch also works)
        c = torch.zeros(hs.size(0), 1, self.hidden_dim, device=hs.device) # c.size() = ([100, 1, 128])

        # I didn't know how to compute the context vectors for all Decoder time steps at once, so
        # take out the attention weights for each time step (the Decoder GRU has 10 steps because the generated string is 10 characters) and build one context vector per step in a for loop.
        # The batch dimension can be handled all at once, so it is left as is
        for i in range(attention_weight.size()[2]): # 10 loops

          # attention_weight[:,:,i].size() = ([100, 29])
          # Take the attention weights for the i-th time step and unsqueeze them so the tensor size lines up with hs
          unsq_weight = attention_weight[:,:,i].unsqueeze(2) # unsq_weight.size() = ([100, 29, 1])

          # Weight each vector of hs by its attention weight
          weighted_hs = hs * unsq_weight # weighted_hs.size() = ([100, 29, 128])

          # Create a context vector by summing all the weighted vectors of hs
          weight_sum = torch.sum(weighted_hs, dim=1).unsqueeze(1) # weight_sum.size() = ([100, 1, 128])

          c = torch.cat([c, weight_sum], dim=1) # c.size() = ([100, i+2, 128]), including the leading zero slot

        # The zero element prepared as a container is still there, so slice it off
        c = c[:,1:,:]
        
        output = torch.cat([output, c], dim=2) # output.size() = ([100, 10, 256])
        output = self.hidden2linear(output)
        return output, state, attention_weight
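About the comment in the for loop on not knowing how to compute all the context vectors at once: step ④ is a weighted sum over the Encoder axis, so it should be expressible as one batched matrix product, `torch.bmm(torch.transpose(attention_weight, 1, 2), hs)`, which would give shape ([100, 10, 128]) directly. That PyTorch line is a suggestion I have not run here; the NumPy sketch below (random toy tensors, with `np.matmul` standing in for `torch.bmm`) only checks that the loop and the batched product agree.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T_enc, T_dec, H = 4, 29, 10, 8                      # batch, Encoder steps, Decoder steps, hidden size
hs = rng.normal(size=(B, T_enc, H))                    # Encoder outputs
w = rng.random(size=(B, T_enc, T_dec))
attention_weight = w / w.sum(axis=1, keepdims=True)    # normalized over axis 1, like Softmax(dim=1)

# Loop version, mirroring the for loop in AttentionDecoder.forward
c_loop = np.concatenate(
    [(hs * attention_weight[:, :, i][:, :, None]).sum(axis=1, keepdims=True)
     for i in range(T_dec)],
    axis=1)                                            # (B, T_dec, H)

# Batched version: (B, T_dec, T_enc) @ (B, T_enc, H) -> (B, T_dec, H)
c_batched = np.matmul(attention_weight.transpose(0, 2, 1), hs)

print(c_batched.shape, np.allclose(c_loop, c_batched))
```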

Model declaration, loss function, optimization

- No particular change from last time.


encoder = Encoder(vocab_size, embedding_dim, hidden_dim).to(device)
attn_decoder = AttentionDecoder(vocab_size, embedding_dim, hidden_dim, BATCH_NUM).to(device)

#Loss function
criterion = nn.CrossEntropyLoss()

#optimisation
encoder_optimizer = optim.Adam(encoder.parameters(), lr=0.001)
attn_decoder_optimizer = optim.Adam(attn_decoder.parameters(), lr=0.001)

Training

- Don't forget to pass the Encoder output $hs$ to the Attention Decoder.
- Since the inputs and outputs of both the Encoder and the Decoder are unchanged, this is almost the same as the previous Seq2Seq.
- The loss decreases with tremendous momentum.
- Training stops once the epoch loss falls below 0.1, which here already happened at epoch 16.

BATCH_NUM=100
EPOCH_NUM = 100

all_losses = []
print("training ...")
for epoch in range(1, EPOCH_NUM+1):
    epoch_loss = 0
    #Divide the data into mini-batch
    input_batch, output_batch = train2batch(train_x, train_y, batch_size=BATCH_NUM)
    for i in range(len(input_batch)):
        
        #Gradient initialization
        encoder_optimizer.zero_grad()
        attn_decoder_optimizer.zero_grad()
        
        #Convert data to tensor
        input_tensor = torch.tensor(input_batch[i], device=device)
        output_tensor = torch.tensor(output_batch[i], device=device)
        
        #Encoder forward propagation
        hs, h = encoder(input_tensor)

        #Attention Decoder Input
        source = output_tensor[:, :-1]
        
        #Correct answer data of Attention Decoder
        target = output_tensor[:, 1:]

        loss = 0
        decoder_output, _, attention_weight= attn_decoder(source, hs, h)
        for j in range(decoder_output.size()[1]):
            loss += criterion(decoder_output[:, j, :], target[:, j])

        epoch_loss += loss.item()
        
        #Backpropagation of error
        loss.backward()

        #Parameter update
        encoder_optimizer.step()
        attn_decoder_optimizer.step()
    
    #Show loss
    print("Epoch %d: %.2f" % (epoch, epoch_loss))
    all_losses.append(epoch_loss)
    if epoch_loss < 0.1: break
print("Done")
# training ...
# Epoch 1: 1500.33
# Epoch 2: 77.53
# Epoch 3: 12.98
# Epoch 4: 3.40
# Epoch 5: 1.78
# Epoch 6: 1.13
# Epoch 7: 0.78
# Epoch 8: 0.56
# Epoch 9: 0.42
# Epoch 10: 0.32
# Epoch 11: 0.25
# Epoch 12: 0.20
# Epoch 13: 0.16
# Epoch 14: 0.13
# Epoch 15: 0.11
# Epoch 16: 0.09
# Done
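The training loop adds up one `nn.CrossEntropyLoss` value (a mean over the mini-batch) per Decoder time step. The NumPy sketch below, with random toy logits and a hand-written `cross_entropy` standing in for `nn.CrossEntropyLoss` with its default mean reduction, checks that this sum equals the mean over all flattened (batch × time) positions multiplied by the number of time steps. Under that reading, `criterion(decoder_output.reshape(-1, vocab_size), target.reshape(-1)) * decoder_output.size(1)` should give the same loss in one call, though I have not run that PyTorch line here.

```python
import numpy as np

def cross_entropy(logits, targets):
    # Mean negative log-likelihood over rows, like nn.CrossEntropyLoss(reduction="mean")
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

rng = np.random.default_rng(0)
B, T, V = 5, 10, 12                        # batch, Decoder steps, vocabulary size (toy)
logits = rng.normal(size=(B, T, V))        # stands in for decoder_output
target = rng.integers(0, V, size=(B, T))   # stands in for target

# As in the training loop: one loss per time step, added up
per_step = sum(cross_entropy(logits[:, j, :], target[:, j]) for j in range(T))

# Equivalent: mean over all B*T positions, times the number of steps
flat = cross_entropy(logits.reshape(-1, V), target.reshape(-1)) * T

print(per_step, flat, np.isclose(per_step, flat))
```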

Loss visualization

import matplotlib.pyplot as plt
%matplotlib inline
plt.plot(all_losses)

Prediction

- Prediction uses almost the same method as in the previous Seq2Seq.

import pandas as pd

# Return the index of the largest element of the Decoder output tensor, i.e., the generated character
def get_max_index(decoder_output):
  # decoder_output.size() = ([batch, vocab_size]) -> result size ([batch, 1])
  return torch.argmax(decoder_output, dim=1).view(-1, 1)
    
# Evaluation data
test_input_batch, test_output_batch = train2batch(test_x, test_y)
input_tensor = torch.tensor(test_input_batch, device=device)

predicts = []
for i in range(len(test_input_batch)):
  with torch.no_grad():
    hs, encoder_state = encoder(input_tensor[i])
    
    # The Decoder's first input is "_", which marks the start of generation, so create a "_" tensor for the whole batch
    start_char_batch = [[char2id["_"]] for _ in range(BATCH_NUM)]
    decoder_input_tensor = torch.tensor(start_char_batch, device=device)

    decoder_hidden = encoder_state
    batch_tmp = torch.zeros(BATCH_NUM, 1, dtype=torch.long, device=device)
    for _ in range(output_len - 1):
      decoder_output, decoder_hidden, _ = attn_decoder(decoder_input_tensor, hs, decoder_hidden)
      # The predicted character is recorded and also fed as-is into the next Decoder step
      decoder_input_tensor = get_max_index(decoder_output.squeeze(1))
      batch_tmp = torch.cat([batch_tmp, decoder_input_tensor], dim=1)
    predicts.append(batch_tmp[:,1:])
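The prediction loop above is greedy decoding: at each step the argmax character is both recorded and fed back as the next Decoder input, starting from "_". Here is a framework-free sketch of the same loop, assuming a hypothetical `step_fn(inputs, state) -> (logits, state)` interface in place of `attn_decoder` and a toy step function that always predicts "previous ID + 1".

```python
import numpy as np

def greedy_decode(step_fn, start_ids, state, n_steps):
    """Feed each step's argmax back in as the next input.

    step_fn(inputs, state) -> (logits, state), logits shape (B, V)  [hypothetical interface]
    start_ids: (B,) start-symbol IDs (the ID of "_" in this article)
    """
    inputs = start_ids
    generated = []
    for _ in range(n_steps):
        logits, state = step_fn(inputs, state)
        inputs = logits.argmax(axis=1)    # predicted IDs, which are also the next inputs
        generated.append(inputs)
    return np.stack(generated, axis=1)    # (B, n_steps)

# Toy step function (hypothetical): always predicts "previous ID + 1" modulo V
V = 5
def toy_step(inputs, state):
    logits = np.zeros((len(inputs), V))
    logits[np.arange(len(inputs)), (inputs + 1) % V] = 1.0
    return logits, state

out = greedy_decode(toy_step, np.array([0, 3]), None, n_steps=4)
print(out)   # [[1 2 3 4]
             #  [4 0 1 2]]
```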


# Raw IDs are hard to read when checking the predictions, so define a dictionary that maps IDs back to characters to restore the original strings
id2char = {}
for k, v in char2id.items():
  id2char[v] = k

row = []
for i in range(len(test_input_batch)):
  batch_input = test_input_batch[i]
  batch_output = test_output_batch[i]
  batch_predict = predicts[i]
  for inp, output, predict in zip(batch_input, batch_output, batch_predict):
    x = [id2char[idx] for idx in inp]
    y = [id2char[idx] for idx in output[1:]]
    p = [id2char[idx.item()] for idx in predict]
    
    x_str = "".join(x)
    y_str = "".join(y)
    p_str = "".join(p)
    
    judge = "O" if y_str == p_str else "X"
    row.append([x_str, y_str, p_str, judge])
predict_df = pd.DataFrame(row, columns=["input", "answer", "predict", "judge"])
predict_df.head()

Accuracy

- It just happened not to be 100% this time, but the accuracy should be close to 100%.

print(len(predict_df.query('judge == "O"')) / len(predict_df))
# 0.9999333333333333

predict_df.query('judge == "X"').head(10)

- The model got only the following one case wrong.
- When the model makes mistakes on this task, they seem to be mostly on slash-separated date formats like the one below.

Attention weight visualization

- Let's visualize the attention weights, which are one of the real thrills of Attention.
- Looking at the attention weights lets you check whether the model has really learned.
- Heatmaps are commonly used to visualize attention weights, so I use seaborn's heatmap.
- We feed in the first mini-batch of the test data (the 3 of the 7:3 split).

import seaborn as sns
import pandas as pd

input_batch, output_batch = train2batch(test_x, test_y, batch_size=BATCH_NUM)
input_minibatch, output_minibatch = input_batch[0], output_batch[0]

with torch.no_grad():
  #Convert data to tensor
  input_tensor = torch.tensor(input_minibatch, device=device)
  output_tensor = torch.tensor(output_minibatch, device=device)
  hs, h = encoder(input_tensor)
  source = output_tensor[:, :-1]
  decoder_output, _, attention_weight= attn_decoder(source, hs, h)


for i in range(3):
  with torch.no_grad():
    df = pd.DataFrame(data=torch.transpose(attention_weight[i], 0, 1).cpu().numpy(), 
                      columns=[id2char[idx.item()] for idx in input_tensor[i]], 
                      index=[id2char[idx.item()] for idx in output_tensor[i][1:]])
    plt.figure(figsize=(12, 8)) 
    sns.heatmap(df, xticklabels = 1, yticklabels = 1, square=True, linewidths=.3,cbar_kws = dict(use_gridspec=False,location="top"))

Introducing some visualizations

It's a little hard to see, but "Tuesday, March 27, 2012" along the bottom of the figure above is the string before conversion (the Encoder input), and "2012-03-27" arranged vertically on the left is the generated string. As for how to read the heatmap: going through the characters generated by the Decoder one row at a time, the brightest cells in a row show which input characters the model paid the most attention to when generating that character. (Please point out if I have this wrong...) (Of course, the values in each row add up to 1.)
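The row-sum property follows from how the heatmap is built: `Softmax(dim=1)` normalizes over the 29 Encoder positions, and `torch.transpose(attention_weight[i], 0, 1)` turns each example into a (10, 29) matrix whose rows are the generated characters. A small NumPy sketch with random scores (toy values, not the trained model's weights) checks that each row sums to 1 and picks the brightest cell of each row.

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.normal(size=(29, 10))             # (Encoder steps, Decoder steps) for one example
e = np.exp(scores - scores.max(axis=0, keepdims=True))
weights = e / e.sum(axis=0, keepdims=True)     # softmax over the Encoder axis

heatmap = weights.T                            # (10, 29): rows = generated chars, columns = input chars
print(np.allclose(heatmap.sum(axis=1), 1.0))   # True: each row sums to 1
print(heatmap.argmax(axis=1))                  # brightest input position for each generated char
```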

In the example above, you can see the following.

- When generating YYYY, the model attends to the year part as a whole, and when generating MM, to the month part.
- Since the task converts to YYYY-MM-DD, the day of the week is not carried over, so none of the generated characters attend to "Tuesday".
- The "0" of the month attends to the "a" in "March". "May" becomes "05" and "March" becomes "03", so once the letters "Ma" appear the "0" is settled, and only after the letters "rch" follow is the month determined, which is perhaps why the final "3" attends to the "h"?

Here are some other examples of what the Attention looks like:

In conclusion

- As described in Zero Saku 2, there seem to be various patterns of Attention.
- Next time, I plan to cover(?) Self-Attention, which is more versatile than Attention!

end
