[PYTHON] I tried to implement sentence classification & Attention visualization with Japanese BERT in PyTorch

Introduction

Thanks to Hugging Face's transformers library, Japanese BERT models can now be handled very easily with PyTorch.

Many people have already posted articles about using Japanese BERT with huggingface/transformers, but I decided to write my own as a record of what I studied.

Reference

The following articles, posted by the author of [Learn while making! Deep learning by PyTorch](https://www.amazon.co.jp/%E3%81%A4%E3%81%8F%E3%82%8A%E3%81%AA%E3%81%8C%E3%82%89%E5%AD%A6%E3%81%B6-PyTorch%E3%81%AB%E3%82%88%E3%82%8B%E7%99%BA%E5%B1%95%E3%83%87%E3%82%A3%E3%83%BC%E3%83%97%E3%83%A9%E3%83%BC%E3%83%8B%E3%83%B3%E3%82%B0-%E5%B0%8F%E5%B7%9D%E9%9B%84%E5%A4%AA%E9%83%8E/dp/4839970254), are overwhelmingly easy to understand. They carefully explain the points where BERT beginners like me are likely to get stuck.

- [Implementation explanation] How to use the Japanese version of BERT in Google Colaboratory (PyTorch)
- [Implementation explanation] Livedoor news classification with the Japanese version of BERT in Google Colaboratory (PyTorch)

With reference to the above book and Qiita articles (much of the code is copied almost verbatim), I will also implement sentence classification with BERT and touch on visualizing Attention. This is for those who just want to classify sentences with BERT and see Attention visualized; BERT's theory is not covered at all.

Problem setting

As usual, the livedoor news corpus is used as the validation data. The reference article classifies the body text of livedoor news articles, but doing exactly the same thing would not be interesting, so, as in an article I wrote in the past, I will try to classify sentences using only the titles of the livedoor news corpus.

Implementation

As in the reference article, everything is implemented on Google Colab.

Data preparation

First, mount Google Drive on Colab.

from google.colab import drive
drive.mount('/content/drive')

Get the livedoor news corpus by referring to the article linked here, extract the title and category of each article into a DataFrame, and save it to Google Drive. A rough sketch of this extraction step is shown below, followed by loading and inspecting the saved data.
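This is a minimal sketch of that extraction step, not the original preprocessing code. It assumes the corpus archive has been extracted to drive_dir + "text/", with one sub-directory per category, and that the third line of each article file holds the title (the standard livedoor corpus layout).

import glob
import os
import pickle
import pandas as pd

drive_dir = "drive/My Drive/Colab Notebooks/livedoor_data/"

rows = []
for category_dir in glob.glob(os.path.join(drive_dir, "text", "*")):
    # One sub-directory per category (assumed layout)
    if not os.path.isdir(category_dir):
        continue
    category = os.path.basename(category_dir)
    for path in glob.glob(os.path.join(category_dir, "*.txt")):
        if os.path.basename(path) == "LICENSE.txt":
            continue
        with open(path, encoding="utf-8") as f:
            lines = f.read().splitlines()
        # Assumed file layout: line 1 = URL, line 2 = date, line 3 = title
        rows.append({"title": lines[2], "category": category})

livedoor_data = pd.DataFrame(rows)
with open(drive_dir + "livedoor_title_category.pickle", "wb") as f:
    pickle.dump(livedoor_data, f)

With the pickle saved, it can be loaded and inspected like this.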

import pickle
import pandas as pd

#Data set storage location
drive_dir = "drive/My Drive/Colab Notebooks/livedoor_data/"

with open(drive_dir + "livedoor_title_category.pickle", 'rb') as f:
  livedoor_data = pickle.load(f)

livedoor_data.head()
#title	category
#0 Comfortable internet even overseas! KDDI expands "au Wi-Fi SPOT" services	it-life-hack
#1 [Special feature/JOURNEY] To an exciting and gentle Arab country (4/8)	livedoor-homme
#2 Twitter for a single woman, a surprising way to enjoy dokujo-tsushin
#3 The story that the pyramid was built in 20 years is a lie movie-enter
#4 Ayame Goriki presents a “lovely” handmade chocolate cake movie-enter

Let's assign an ID to each category.

#Get a list of categories from a dataset
categories = list(set(livedoor_data['category']))
print(categories)
#['topic-news', 'movie-enter', 'livedoor-homme', 'it-life-hack', 'dokujo-tsushin', 'sports-watch', 'kaden-channel', 'peachy', 'smax']

#Create a category ID dictionary
id2cat = dict(zip(list(range(len(categories))), categories))
cat2id = dict(zip(categories, list(range(len(categories)))))
print(id2cat)
print(cat2id)
#{0: 'topic-news', 1: 'movie-enter', 2: 'livedoor-homme', 3: 'it-life-hack', 4: 'dokujo-tsushin', 5: 'sports-watch', 6: 'kaden-channel', 7: 'peachy', 8: 'smax'}
#{'topic-news': 0, 'movie-enter': 1, 'livedoor-homme': 2, 'it-life-hack': 3, 'dokujo-tsushin': 4, 'sports-watch': 5, 'kaden-channel': 6, 'peachy': 7, 'smax': 8}

#Added category ID column to DataFrame
livedoor_data['category_id'] = livedoor_data['category'].map(cat2id)

#Shuffle just in case
livedoor_data = livedoor_data.sample(frac=1).reset_index(drop=True)

#Make the dataset only title and category ID columns
livedoor_data = livedoor_data[['title', 'category_id']]
livedoor_data.head()
#title	category_id
#0 Ninety-nine Okamura rejects AKB special program appearance request "Who appears in such a place ..." 0
#1	C-"Star Wars in Concert" where 3PO introduces famous scenes landed in Japan 1
#2 Deliver the voyeur scene!?A shocking moment was discovered that should be a free event broadcast [Topic] 6
#3 "To Mitsuhiro Oikawa in the final episode of my partner," Ruthless treatment "" and the woman herself 0
#4 Above Hasebe and Kazu? There are 5 surprising athletes in "athletes who like elementary school students"

Since torchtext is used for data preprocessing, split the dataset into training and test sets and save each as a tsv file.

#Divide into training data and test data
from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(livedoor_data, train_size=0.8)
print("Training data size", train_df.shape[0])
print("Test data size", test_df.shape[0])
#Training data size 5900
#Test data size 1476

#Save as tsv file
train_df.to_csv(drive_dir + 'train.tsv', sep='\t', index=False, header=None)
test_df.to_csv(drive_dir + 'test.tsv', sep='\t', index=False, header=None)

Install MeCab and huggingface / transformers

As I mentioned here, some caution seems to be required when installing MeCab. Currently, installing with pip as shown below works without error.

#Prepare MeCab and transformers
!apt install aptitude swig
!aptitude install mecab libmecab-dev mecab-ipadic-utf8 git make curl xz-utils file -y
#As reported below, mecab-python3 must be pinned to version 0.996.5, otherwise the tokenizer fails
# https://stackoverflow.com/questions/62860717/huggingface-for-japanese-tokenizer
!pip install mecab-python3==0.996.5
!pip install unidic-lite #Without this, it will fail with an error when executing MeCab
!pip install transformers

Create an iterator with torchtext

You can use tokenizer.encode to perform the tokenization used by the Japanese BERT model, and tokenizer.convert_ids_to_tokens to convert the resulting ID sequence back into morphemes and subwords. Very convenient.

import torch
import torchtext
from transformers.modeling_bert import BertModel
from transformers.tokenization_bert_japanese import BertJapaneseTokenizer

#Declares a tokenizer for Japanese BERT word-separation
tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

#Try tokenizing one title as a test
text = list(train_df['title'])[0]
wakati_ids = tokenizer.encode(text, return_tensors='pt')
print(tokenizer.convert_ids_to_tokens(wakati_ids[0].tolist()))
print(wakati_ids)
print(wakati_ids.size())
#['[CLS]', 'height', 'But', 'Low', 'Female', 'Is', 'marriage', 'To', 'Unfavorable', '?', '[SEP]']
#tensor([[   2, 7236,   14, 3458,  969,    9, 1519,    7, 9839, 2935,    3]])
#torch.Size([1, 11])

The pretrained Japanese model from Tohoku University, available through Hugging Face, can handle sentences of at most 512 tokens (morphemes and subwords). So if your data exceeds 512 tokens, specify max_length=512 when encoding (a hedged example follows). However, the titles of the livedoor news corpus are at most 76 tokens, as shown below, so max_length is not specified this time.
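For reference, here is a hedged example of truncating a longer text; long_text is a hypothetical placeholder, and depending on your transformers version, truncation=True may be required in addition to max_length.

# Hypothetical example: truncate a long text to 512 tokens when encoding
long_text = "これは非常に長い記事本文のダミーです。" * 300
ids = tokenizer.encode(long_text, max_length=512, truncation=True, return_tensors='pt')
print(ids.size())
# torch.Size([1, 512]) at most

Back to the titles: the length distribution below confirms that truncation is unnecessary here.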

#Japanese BERT can handle up to 512 tokens, but the longest livedoor title is only 76 tokens, even including the [CLS] and [SEP] tokens
import seaborn as sns
title_length = livedoor_data['title'].map(tokenizer.encode).map(len)
print(max(title_length))
# 76

sns.distplot(title_length)

Create the iterators as follows. Since tokenizer.encode returns a tensor of size (1 × sentence length), we need to take element [0].

#Create iterators for training and test data using torchtext
def bert_tokenizer(text):
  return tokenizer.encode(text, return_tensors='pt')[0]

TEXT = torchtext.data.Field(sequential=True, tokenize=bert_tokenizer, use_vocab=False, lower=False,
                            include_lengths=True, batch_first=True, pad_token=0)
LABEL = torchtext.data.Field(sequential=False, use_vocab=False)

train_data, test_data = torchtext.data.TabularDataset.splits(
    path=drive_dir, train='train.tsv', test='test.tsv', format='tsv', fields=[('Text', TEXT), ('Label', LABEL)])

#BERT seems to use a mini-batch size of 16 or 32, but livedoor titles have a short sentence length, so even 32 will work on colab.
BATCH_SIZE = 32
train_iter, test_iter = torchtext.data.Iterator.splits((train_data, test_data), batch_sizes=(BATCH_SIZE, BATCH_SIZE), repeat=False, sort=False)

Classification model declaration

Before building the classifier, let's check the input and output formats of the pretrained Japanese BERT. The BERT model can be declared in a single line as follows. Almost too convenient.

from transformers.modeling_bert import BertModel
model = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

You can see the structure of BERT by printing the model itself. The output is long, so the repeated layers are omitted below.

BERT model structure
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(32000, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      ... (BertLayer modules (1) through (11) have the same structure as (0) and are omitted) ...
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)

As this output shows, there is first an embedding layer that turns words into vectors, followed by 12 BertLayer modules. You can also confirm that both the word embedding dimension and the internal hidden size are 768.
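As a small sanity check (not in the original article), the same numbers can also be read from the model config:

print(model.config.hidden_size)          # 768
print(model.config.num_hidden_layers)    # 12
print(model.config.num_attention_heads)  # 12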

Let's check the input and output formats of BertModel against the reference documentation.

The input format of the BERT model is documented as (batch_size, sequence_length). By default the output returns last_hidden_state and pooler_output, and the attention weights can additionally be obtained by specifying output_attentions=True. The attentions output contains, for each of the 12 BertLayer modules, the weights of all 12 attention heads.

#From the test data iterator created above
batch = next(iter(test_iter))
print(batch.Text[0].size())
# torch.Size([32, 48]) ←(batch_size, sequence_length)

#Passing output_attentions=True to the forward pass also returns the attention weights
last_hidden_state, pooler_output, attentions = model(batch.Text[0], output_attentions=True)
print(last_hidden_state.size())
print(pooler_output.size())
print(len(attentions), attentions[-1].size())
#torch.Size([32, 48, 768]) ← (batch_size, sequence_length, hidden_size)
#torch.Size([32, 768])
#12 torch.Size([32, 12, 48, 48]) ← (batch_size, num_heads, sequence_length, sequence_length)

When obtaining a sentence vector with BERT, the vector of the [CLS] token, i.e. the first token of last_hidden_state, is treated as the sentence vector.
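As a minimal sketch using the last_hidden_state obtained above, extracting that sentence vector looks like this:

# Take the first ([CLS]) token vector of every sequence in the batch
cls_vec = last_hidden_state[:, 0, :]
print(cls_vec.size())
# torch.Size([32, 768])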

Now that we more or less understand the input and output formats of the BERT model, let's build a model that actually classifies sentences with BERT. As in the reference article, I implement the classification head myself rather than using the classification class provided by huggingface, since I think that makes the structure easier to understand.

from torch import nn
import torch.nn.functional as F
from transformers.modeling_bert import BertModel

class BertClassifier(nn.Module):
  def __init__(self):
    super(BertClassifier, self).__init__()
    self.bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
    #BERT's hidden layer has 768 dimensions; livedoor news has 9 categories
    self.linear = nn.Linear(768, 9)
    #Weight initialization processing
    nn.init.normal_(self.linear.weight, std=0.02)
    nn.init.normal_(self.linear.bias, 0)

  def forward(self, input_ids):
    #Receive last_hidden_state and attentions
    vec, _, attentions = self.bert(input_ids, output_attentions=True)
    #Take only the vector of the first token ([CLS])
    vec = vec[:,0,:]
    vec = vec.view(-1, 768)
    #Project to the number of classes with the fully connected layer
    out = self.linear(vec)
    return F.log_softmax(out, dim=1), attentions

classifier = BertClassifier()

Fine tuning settings

I had not done fine-tuning before, but as in the reference article, all parameters are first frozen and then gradients are turned back on only for the parts we want to update. I learned a lot from this. Also, for the learning rates: the last BERT layer is already pretrained, so it only receives small updates, while the fully connected layer added at the end for classification gets a larger learning rate. Makes sense.

#Fine tuning settings
#Compute gradients only for the last BertLayer module and the added classification layer

#First of all OFF
for param in classifier.parameters():
    param.requires_grad = False

#Update only the last layer of BERT ON
for param in classifier.bert.encoder.layer[-1].parameters():
    param.requires_grad = True

#Class classification is also ON
for param in classifier.linear.parameters():
    param.requires_grad = True

import torch.optim as optim

#The learning rate should be small for the pre-learned part, and large for the last fully connected layer.
optimizer = optim.Adam([
    {'params': classifier.bert.encoder.layer[-1].parameters(), 'lr': 5e-5},
    {'params': classifier.linear.parameters(), 'lr': 1e-4}
])

#Loss function settings
loss_function = nn.NLLLoss()

Learning

As the reference article notes, it would actually be better to write separate training and validation loops, but since I just want to get it running for now, I loop with only the minimum code needed to train, as follows (a sketch of the separated loop follows the training output below). The final accuracy did not change much between 5 and 10 epochs, so I set the number of epochs to 5. The loss decreases steadily, so it seems fine for now.

#GPU settings
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#Send network to GPU
classifier.to(device)
losses = []

#The number of epochs is 5
for epoch in range(5):
  all_loss = 0
  for idx, batch in enumerate(train_iter):
    batch_loss = 0
    classifier.zero_grad()
    input_ids = batch.Text[0].to(device)
    label_ids = batch.Label.to(device)
    out, _ = classifier(input_ids)
    batch_loss = loss_function(out, label_ids)
    batch_loss.backward()
    optimizer.step()
    all_loss += batch_loss.item()
  print("epoch", epoch, "\t" , "loss", all_loss)
#epoch 0 	 loss 246.03703904151917
#epoch 1 	 loss 108.01931090652943
#epoch 2 	 loss 80.69403756409883
#epoch 3 	 loss 62.87365382164717
#epoch 4 	 loss 50.78619819134474
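For reference, a minimal sketch of the separated train/validation loop mentioned above might look like the following. This is not the article's code; it reuses the classifier, loss_function, optimizer, and iterators defined earlier and simply treats test_iter as a stand-in validation set.

for epoch in range(5):
  # Training phase: dropout active, gradients tracked
  classifier.train()
  train_loss = 0.0
  for batch in train_iter:
    classifier.zero_grad()
    out, _ = classifier(batch.Text[0].to(device))
    loss = loss_function(out, batch.Label.to(device))
    loss.backward()
    optimizer.step()
    train_loss += loss.item()

  # Validation phase: dropout off, no gradient tracking
  classifier.eval()
  val_loss = 0.0
  with torch.no_grad():
    for batch in test_iter:
      out, _ = classifier(batch.Text[0].to(device))
      val_loss += loss_function(out, batch.Label.to(device)).item()

  print("epoch", epoch, "\t", "train_loss", train_loss, "\t", "val_loss", val_loss)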

Accuracy check

Let's look at the F-score. Classifying the article body apparently exceeds 90%, while classifying only the title gives 85%. A title does summarize its article, but I was still impressed that such short sentences can reach 85%.

from sklearn.metrics import classification_report

answer = []
prediction = []
with torch.no_grad():
    for batch in test_iter:

        text_tensor = batch.Text[0].to(device)
        label_tensor = batch.Label.to(device)

        score, _ = classifier(text_tensor)
        _, pred = torch.max(score, 1)

        prediction += list(pred.cpu().numpy())
        answer += list(label_tensor.cpu().numpy())
print(classification_report(answer, prediction, target_names=categories))
#                precision    recall  f1-score   support
#
#    topic-news       0.80      0.82      0.81       158
#   movie-enter       0.85      0.82      0.83       178
#livedoor-homme       0.68      0.73      0.70       108
#  it-life-hack       0.88      0.82      0.85       179
#dokujo-tsushin       0.82      0.85      0.84       144
#  sports-watch       0.89      0.87      0.88       180
# kaden-channel       0.91      0.97      0.94       180
#        peachy       0.78      0.77      0.78       172
#          smax       0.94      0.91      0.92       177
#
#      accuracy                           0.85      1476
#     macro avg       0.84      0.84      0.84      1476
#  weighted avg       0.85      0.85      0.85      1476

Visualization of Attention

Finally, let's look at the basis for the model's classification decisions by visualizing Attention. During fine-tuning we updated the parameters of the last BertLayer, which means the last layer's attention weights were trained for this title-classification task, so those weights can reasonably be used as the basis for explaining the model's decisions.

Since the BertClassifier model declared above returns all of the attention weights, take only the last layer as follows and check its size.

batch = next(iter(test_iter))
score, attentions = classifier(batch.Text[0].to(device))
#Get only the Attention weight of the last layer and check the size
print(attentions[-1].size())
# torch.Size([32, 12, 48, 48])

Checking the reference again, this size means (batch_size, num_heads, sequence_length, sequence_length). Since BertEncoder uses self-attention, the weights express how much each word on the first sequence_length axis attends to each word on the second sequence_length axis. Because classification here is based on the first token [CLS], visualizing which words that token attends to can be regarded as the model's basis for its decision. Also, BERT's self-attention has 12 heads, so for visualization I sum the attention weights of all 12 heads (a small sketch follows).
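As a small sketch of that aggregation, using the attentions obtained above (the index variable is hypothetical): for one sentence in the batch, sum the last layer's 12 heads and keep the row corresponding to the [CLS] token.

idx = 0  # hypothetical: first sentence in the batch
last_layer = attentions[-1][idx]           # (num_heads, seq_len, seq_len)
cls_attention = last_layer.sum(dim=0)[0]   # (seq_len,): [CLS] attention to every token
print(cls_attention.size())
# e.g. torch.Size([48])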

I tried to implement the visualization part as follows with reference to the reference book.

def highlight(word, attn):
  html_color = '#%02X%02X%02X' % (255, int(255*(1 - attn)), int(255*(1 - attn)))
  return '<span style="background-color: {}">{}</span>'.format(html_color, word)

def mk_html(index, batch, preds, attention_weight):
  sentence = batch.Text[0][index]
  label =batch.Label[index].item()
  pred = preds[index].item()

  label_str = id2cat[label]
  pred_str = id2cat[pred]

  html = "Correct answer category: {}<br>Forecast category: {}<br>".format(label_str, pred_str)

  #Declare zero tensor for the length of the sentence
  seq_len = attention_weight.size()[2]
  all_attens = torch.zeros(seq_len).to(device)

  for i in range(12):
    all_attens += attention_weight[index, i, 0, :]

  for word, attn in zip(sentence, all_attens):
    if tokenizer.convert_ids_to_tokens([word.tolist()])[0] == "[SEP]":
      break
    html += highlight(tokenizer.convert_ids_to_tokens([word.numpy().tolist()])[0], attn)
  html += "<br><br>"
  return html

batch = next(iter(test_iter))
score, attentions = classifier(batch.Text[0].to(device))
_, pred = torch.max(score, 1)

from IPython.display import display, HTML
for i in range(BATCH_SIZE):
  html_output = mk_html(i, batch, pred, attentions[-1])
  display(HTML(html_output))

Here are some visualization results.

- "Yodobashi Camera Umeda store" is split into subwords, but since it is home-appliance related, the model still attends solidly to those pieces. (attention visualization image)

- It is interesting that the model judges this as kaden-channel based on "Takahashi Meijin" (a famous rapid-fire game player). (attention visualization image)

- peachy (articles about romance aimed at women). This one also looks good. (attention visualization image)

- Was the model pulled toward peachy by the "real talk" part? (attention visualization image)

I mainly showed the good examples, but honestly the attention felt hit-or-miss overall. (I am a little worried whether my implementation is really correct...) Still, it was interesting that the model attends to relevant parts even when they are split into subwords.

In conclusion

Thanks to huggingface/transformers and the reference articles, I can now get BERT running, albeit somewhat shakily. I want to try using BERT for various tasks.

end
