[PYTHON] [PyTorch] Introduction to document classification using BERT

Introduction

In this article, we walk through the process of fine-tuning a pre-trained BERT model on the task of categorizing English news article headlines. Unlike English, Japanese requires morphological analysis, but the overall flow is the same as in this article.

This implementation also serves as an answer to Question 89 of the Language Processing 100 Knock 2020. For sample answers to the other questions, see [Language Processing 100 Knock 2020] Summary of Answer Examples in Python.

Advance preparation

Google Colaboratory is used for the implementation. For details on how to set up and use Google Colaboratory, see [this article](https://cpp-fulearning.com/python_colaboratory/). **If you want to reproduce the results on a GPU, change the hardware accelerator to "GPU" via "Runtime" -> "Change runtime type" and save the setting in advance.** The notebook containing the execution results is available on GitHub.

Document classification by BERT

Using the public News Aggregator Data Set, we will implement a BERT document classification model for the task of classifying news article headlines into the four categories "Business", "Science and Technology", "Entertainment", and "Health".

Data reading

First, download the target data.

!wget https://archive.ics.uci.edu/ml/machine-learning-databases/00359/NewsAggregatorDataset.zip
!unzip NewsAggregatorDataset.zip
#Check the number of lines
!wc -l ./newsCorpora.csv

output


422937 ./newsCorpora.csv

#Check the first 10 lines
!head -10 ./newsCorpora.csv

output


1	Fed official says weak data caused by weather, should not slow taper	http://www.latimes.com/business/money/la-fi-mo-federal-reserve-plosser-stimulus-economy-20140310,0,1312750.story\?track=rss	Los Angeles Times	b	ddUyU0VZz0BRneMioxUPQVP6sIxvM	www.latimes.com	1394470370698
2	Fed's Charles Plosser sees high bar for change in pace of tapering	http://www.livemint.com/Politics/H2EvwJSK2VE6OF7iK1g3PP/Feds-Charles-Plosser-sees-high-bar-for-change-in-pace-of-ta.html	Livemint	b	ddUyU0VZz0BRneMioxUPQVP6sIxvM	www.livemint.com	1394470371207
3	US open: Stocks fall after Fed official hints at accelerated tapering	http://www.ifamagazine.com/news/us-open-stocks-fall-after-fed-official-hints-at-accelerated-tapering-294436	IFA Magazine	b	ddUyU0VZz0BRneMioxUPQVP6sIxvM	www.ifamagazine.com	1394470371550
4	Fed risks falling 'behind the curve', Charles Plosser says	http://www.ifamagazine.com/news/fed-risks-falling-behind-the-curve-charles-plosser-says-294430	IFA Magazine	b	ddUyU0VZz0BRneMioxUPQVP6sIxvM	www.ifamagazine.com	1394470371793
5	Fed's Plosser: Nasty Weather Has Curbed Job Growth	http://www.moneynews.com/Economy/federal-reserve-charles-plosser-weather-job-growth/2014/03/10/id/557011	Moneynews	b	ddUyU0VZz0BRneMioxUPQVP6sIxvM	www.moneynews.com	1394470372027
6	Plosser: Fed May Have to Accelerate Tapering Pace	http://www.nasdaq.com/article/plosser-fed-may-have-to-accelerate-tapering-pace-20140310-00371	NASDAQ	b	ddUyU0VZz0BRneMioxUPQVP6sIxvM	www.nasdaq.com	1394470372212
7	Fed's Plosser: Taper pace may be too slow	http://www.marketwatch.com/story/feds-plosser-taper-pace-may-be-too-slow-2014-03-10\?reflink=MW_news_stmp	MarketWatch	b	ddUyU0VZz0BRneMioxUPQVP6sIxvM	www.marketwatch.com	1394470372405
8	Fed's Plosser expects US unemployment to fall to 6.2% by the end of 2014	http://www.fxstreet.com/news/forex-news/article.aspx\?storyid=23285020-b1b5-47ed-a8c4-96124bb91a39	FXstreet.com	b	ddUyU0VZz0BRneMioxUPQVP6sIxvM	www.fxstreet.com	1394470372615
9	US jobs growth last month hit by weather:Fed President Charles Plosser	http://economictimes.indiatimes.com/news/international/business/us-jobs-growth-last-month-hit-by-weatherfed-president-charles-plosser/articleshow/31788000.cms	Economic Times	b	ddUyU0VZz0BRneMioxUPQVP6sIxvM	economictimes.indiatimes.com	1394470372792
10	ECB unlikely to end sterilisation of SMP purchases - traders	http://www.iii.co.uk/news-opinion/reuters/news/152615	Interactive Investor	b	dPhGU51DcrolUIMxbRm0InaHGA2XM	www.iii.co.uk	1394470501265

#Replace double quotes with single quotes to avoid errors when reading
!sed -e 's/"/'\''/g' ./newsCorpora.csv > ./newsCorpora_re.csv
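
As an aside, the same read error could also be avoided on the pandas side by disabling quote handling. This is a minimal sketch and is not used in the rest of the article:

#Alternative (not used below): read the original file directly with quote handling disabled
import csv
import pandas as pd

df_raw = pd.read_csv('./newsCorpora.csv', header=None, sep='\t', quoting=csv.QUOTE_NONE,
                     names=['ID', 'TITLE', 'URL', 'PUBLISHER', 'CATEGORY', 'STORY', 'HOSTNAME', 'TIMESTAMP'])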

Next, read it as a data frame, extract only the articles whose source (PUBLISHER) is Reuters, Huffington Post, Businessweek, Contactmusic.com, or Daily Mail, and then split the data into training, validation, and evaluation sets.

import pandas as pd
from sklearn.model_selection import train_test_split

#Data reading
df = pd.read_csv('./newsCorpora_re.csv', header=None, sep='\t', names=['ID', 'TITLE', 'URL', 'PUBLISHER', 'CATEGORY', 'STORY', 'HOSTNAME', 'TIMESTAMP'])

#Data extraction
df = df.loc[df['PUBLISHER'].isin(['Reuters', 'Huffington Post', 'Businessweek', 'Contactmusic.com', 'Daily Mail']), ['TITLE', 'CATEGORY']]

#Data split
train, valid_test = train_test_split(df, test_size=0.2, shuffle=True, random_state=123, stratify=df['CATEGORY'])
valid, test = train_test_split(valid_test, test_size=0.5, shuffle=True, random_state=123, stratify=valid_test['CATEGORY'])
train.reset_index(drop=True, inplace=True)
valid.reset_index(drop=True, inplace=True)
test.reset_index(drop=True, inplace=True)

print(train.head())

output


                                               TITLE CATEGORY
0  REFILE-UPDATE 1-European car sales up for sixt...        b
1  Amazon Plans to Fight FTC Over Mobile-App Purc...        t
2  Kids Still Get Codeine In Emergency Rooms Desp...        m
3  What On Earth Happened Between Solange And Jay...        e
4  NATO Missile Defense Is Flight Tested Over Hawaii        b

#Check the number of samples in each split
print('[Learning data]')
print(train['CATEGORY'].value_counts())
print('[Verification data]')
print(valid['CATEGORY'].value_counts())
print('[Evaluation data]')
print(test['CATEGORY'].value_counts())

output


[Learning data]
b    4501
e    4235
t    1220
m     728
Name: CATEGORY, dtype: int64
[Verification data]
b    563
e    529
t    153
m     91
Name: CATEGORY, dtype: int64
[Evaluation data]
b    563
e    530
t    152
m     91
Name: CATEGORY, dtype: int64

(b: Business, e: Entertainment, t: Science and Technology, m: Health)

Preparing for learning

Install the `transformers` library to use the BERT model. Through `transformers`, many pre-trained models besides BERT can be used very easily with a small amount of code.

!pip install transformers

Import the libraries needed to train and evaluate your model.

import numpy as np
import transformers
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from torch import optim
from torch import cuda
import time
from matplotlib import pyplot as plt

Next, shape the data into a form that can be fed to the model. First, define a class that creates a `Dataset`, a PyTorch construct that holds the feature vectors and label vectors together. By passing a `tokenizer` to this class, the input text is preprocessed, padded to the specified maximum sequence length, and converted to word IDs. The `tokenizer` itself, which contains all of the BERT-specific processing, is obtained later through `transformers`, so all the class has to do is pass the text to the `tokenizer` and receive the result.

#Dataset definition
class CreateDataset(Dataset):
  def __init__(self, X, y, tokenizer, max_len):
    self.X = X
    self.y = y
    self.tokenizer = tokenizer
    self.max_len = max_len

  def __len__(self):  # Specify the value returned by len(Dataset)
    return len(self.y)

  def __getitem__(self, index):  # Specify the value returned by Dataset[index]
    text = self.X[index]
    inputs = self.tokenizer.encode_plus(
      text,
      add_special_tokens=True,
      max_length=self.max_len,
      pad_to_max_length=True
    )
    ids = inputs['input_ids']
    mask = inputs['attention_mask']

    return {
      'ids': torch.LongTensor(ids),
      'mask': torch.LongTensor(mask),
      'labels': torch.Tensor(self.y[index])
    }

Create a `Dataset` using the class above. Pre-trained English BERT models come in four variants: LARGE, which aims for the highest accuracy, and BASE, which has fewer parameters, each available in a lowercase-only (Uncased) and a case-sensitive (Cased) version. This time we will use BASE Uncased, which is easy to try.

#One-hot encode the correct labels
y_train = pd.get_dummies(train, columns=['CATEGORY'])[['CATEGORY_b', 'CATEGORY_e', 'CATEGORY_t', 'CATEGORY_m']].values
y_valid = pd.get_dummies(valid, columns=['CATEGORY'])[['CATEGORY_b', 'CATEGORY_e', 'CATEGORY_t', 'CATEGORY_m']].values
y_test = pd.get_dummies(test, columns=['CATEGORY'])[['CATEGORY_b', 'CATEGORY_e', 'CATEGORY_t', 'CATEGORY_m']].values

#Creating a Dataset
max_len = 20
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset_train = CreateDataset(train['TITLE'], y_train, tokenizer, max_len)
dataset_valid = CreateDataset(valid['TITLE'], y_valid, tokenizer, max_len)
dataset_test = CreateDataset(test['TITLE'], y_test, tokenizer, max_len)

for var in dataset_train[0]:
  print(f'{var}: {dataset_train[0][var]}')

output


ids: tensor([  101, 25416,  9463,  1011, 10651,  1015,  1011,  2647,  2482,  4341,
         2039,  2005,  4369,  3204,  2004, 18730,  8980,   102,     0,     0])
mask: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0])
labels: tensor([1., 0., 0., 0.])

The information of the first sample is shown. You can see that the input string has been converted to a sequence of IDs in `ids`. During the conversion, BERT inserts the special delimiter tokens [CLS] and [SEP] at the beginning and end of the original sentence, so they appear in the sequence as `101` and `102`. `0` represents padding. The correct label is held in one-hot format as `labels`. We also keep a `mask` representing the padding positions so that it can be passed to the model together with `ids` during training.
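
To double-check this, the ID sequence can be converted back to tokens with the `tokenizer`. A minimal sketch (the exact subword tokens depend on the sentence):

#Convert the ID sequence of the first sample back to tokens to confirm that
#[CLS] and [SEP] were inserted and that the trailing 0s correspond to [PAD]
sample = dataset_train[0]
print(tokenizer.convert_ids_to_tokens(sample['ids'].tolist()))
# e.g. ['[CLS]', ..., '[SEP]', '[PAD]', '[PAD]']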

Next, define the network. By using `transformers`, the entire BERT part can be expressed with `BertModel`. To handle the classification task, add a dropout layer that receives BERT's output vector and a fully connected layer on top of it, and the model is complete.

#Definition of BERT classification model
class BERTClass(torch.nn.Module):
  def __init__(self, drop_rate, output_size):
    super().__init__()
    self.bert = BertModel.from_pretrained('bert-base-uncased')
    self.drop = torch.nn.Dropout(drop_rate)
    self.fc = torch.nn.Linear(768, output_size)  #Specify 768 dimensions to match the hidden size of BERT BASE

  def forward(self, ids, mask):
    #return_dict=False makes BertModel return a tuple (last_hidden_state, pooler_output);
    #newer versions of transformers return a ModelOutput object by default
    _, out = self.bert(ids, attention_mask=mask, return_dict=False)
    out = self.fc(self.drop(out))
    return out

Learning the BERT classification model

Now that the `Dataset` and the network are ready, it's time to create the usual training loop. Here, the whole flow is defined as a `train_model` function. For the meaning of the components that appear, please refer to the explanation in the article [Language Processing 100 Knock 2020] Chapter 8: Neural Net, which follows the flow of the problems there.

def calculate_loss_and_accuracy(model, criterion, loader, device):
  """Calculate loss / correct answer rate"""
  model.eval()
  loss = 0.0
  total = 0
  correct = 0
  with torch.no_grad():
    for data in loader:
      #Device specification
      ids = data['ids'].to(device)
      mask = data['mask'].to(device)
      labels = data['labels'].to(device)

      #Forward propagation
      outputs = model.forward(ids, mask)

      #Loss calculation
      loss += criterion(outputs, labels).item()

      #Correct answer rate calculation
      pred = torch.argmax(outputs, dim=-1).cpu().numpy() #Predicted label array for batch size length
      labels = torch.argmax(labels, dim=-1).cpu().numpy()  #Batch size length correct label array
      total += len(labels)
      correct += (pred == labels).sum().item()
      
  return loss / len(loader), correct / total
  

def train_model(dataset_train, dataset_valid, batch_size, model, criterion, optimizer, num_epochs, device=None):
  """Executes model training and returns a log of loss / correct answer rate"""
  #Device specification
  model.to(device)

  #Creating a dataloader
  dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
  dataloader_valid = DataLoader(dataset_valid, batch_size=len(dataset_valid), shuffle=False)

  #Learning
  log_train = []
  log_valid = []
  for epoch in range(num_epochs):
    #Record start time
    s_time = time.time()

    #Set to training mode
    model.train()
    for data in dataloader_train:
      #Device specification
      ids = data['ids'].to(device)
      mask = data['mask'].to(device)
      labels = data['labels'].to(device)

      #Initialize gradient to zero
      optimizer.zero_grad()

      #Forward propagation+Backpropagation of error+Weight update
      outputs = model.forward(ids, mask)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      
    #Calculation of loss and correct answer rate
    loss_train, acc_train = calculate_loss_and_accuracy(model, criterion, dataloader_train, device)
    loss_valid, acc_valid = calculate_loss_and_accuracy(model, criterion, dataloader_valid, device)
    log_train.append([loss_train, acc_train])
    log_valid.append([loss_valid, acc_valid])

    #Save checkpoint
    torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, f'checkpoint{epoch + 1}.pt')

    #Record end time
    e_time = time.time()

    #Output log
    print(f'epoch: {epoch + 1}, loss_train: {loss_train:.4f}, accuracy_train: {acc_train:.4f}, loss_valid: {loss_valid:.4f}, accuracy_valid: {acc_valid:.4f}, {(e_time - s_time):.4f}sec') 

  return {'train': log_train, 'valid': log_valid}

Set the parameters and perform fine tuning.

#Parameter setting
DROP_RATE = 0.4
OUTPUT_SIZE = 4
BATCH_SIZE = 32
NUM_EPOCHS = 4
LEARNING_RATE = 2e-5

#Model definition
model = BERTClass(DROP_RATE, OUTPUT_SIZE)

#Definition of loss function
criterion = torch.nn.BCEWithLogitsLoss()

#Optimizer definition
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

#Device specification
device = 'cuda' if cuda.is_available() else 'cpu'

#Model learning
log = train_model(dataset_train, dataset_valid, BATCH_SIZE, model, criterion, optimizer, NUM_EPOCHS, device=device)

output


epoch: 1, loss_train: 0.0859, accuracy_train: 0.9516, loss_valid: 0.1142, accuracy_valid: 0.9229, 49.9137sec
epoch: 2, loss_train: 0.0448, accuracy_train: 0.9766, loss_valid: 0.1046, accuracy_valid: 0.9259, 49.7376sec
epoch: 3, loss_train: 0.0316, accuracy_train: 0.9831, loss_valid: 0.1082, accuracy_valid: 0.9266, 49.5454sec
epoch: 4, loss_train: 0.0170, accuracy_train: 0.9932, loss_valid: 0.1179, accuracy_valid: 0.9289, 49.4525sec
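
Since a checkpoint is saved at the end of every epoch, the model and optimizer state can be restored later. A minimal sketch, assuming the `checkpoint{epoch}.pt` naming used above and NUM_EPOCHS = 4:

#Restore the model and optimizer state from the checkpoint of the last epoch
checkpoint = torch.load('checkpoint4.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
print(f"restored epoch: {checkpoint['epoch'] + 1}")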

Check the result.

#Log visualization
x_axis = [x for x in range(1, len(log['train']) + 1)]
fig, ax = plt.subplots(1, 2, figsize=(15, 5))
ax[0].plot(x_axis, np.array(log['train']).T[0], label='train')
ax[0].plot(x_axis, np.array(log['valid']).T[0], label='valid')
ax[0].set_xlabel('epoch')
ax[0].set_ylabel('loss')
ax[0].legend()
ax[1].plot(x_axis, np.array(log['train']).T[1], label='train')
ax[1].plot(x_axis, np.array(log['valid']).T[1], label='valid')
ax[1].set_xlabel('epoch')
ax[1].set_ylabel('accuracy')
ax[1].legend()
plt.show()

(Figure: loss (left) and accuracy (right) per epoch for the training and validation data)

#Calculation of correct answer rate
def calculate_accuracy(model, dataset, device):
  #Creating a Dataloader
  loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)

  model.eval()
  total = 0
  correct = 0
  with torch.no_grad():
    for data in loader:
      #Device specification
      ids = data['ids'].to(device)
      mask = data['mask'].to(device)
      labels = data['labels'].to(device)

      #Forward propagation+Get predicted value+Counting the number of correct answers
      outputs = model.forward(ids, mask)
      pred = torch.argmax(outputs, dim=-1).cpu().numpy()
      labels = torch.argmax(labels, dim=-1).cpu().numpy()
      total += len(labels)
      correct += (pred == labels).sum().item()

  return correct / total

print(f'Correct answer rate (learning data):{calculate_accuracy(model, dataset_train, device):.3f}')
print(f'Correct answer rate (verification data):{calculate_accuracy(model, dataset_valid, device):.3f}')
print(f'Correct answer rate (evaluation data):{calculate_accuracy(model, dataset_test, device):.3f}')

output


Correct answer rate (learning data): 0.993
Correct answer rate (verification data): 0.929
Correct answer rate (evaluation data): 0.948

The correct answer rate on the evaluation data was about 95%.
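
For reference, the trained model can also be used to classify a single new headline. This is a minimal sketch with a made-up example headline; the category order follows the one-hot columns defined earlier (b, e, t, m):

#Classify one new headline with the fine-tuned model (sketch)
def predict(text, model, tokenizer, max_len, device):
  model.eval()
  inputs = tokenizer.encode_plus(text, add_special_tokens=True, max_length=max_len, pad_to_max_length=True)
  ids = torch.LongTensor(inputs['input_ids']).unsqueeze(0).to(device)
  mask = torch.LongTensor(inputs['attention_mask']).unsqueeze(0).to(device)
  with torch.no_grad():
    outputs = model(ids, mask)
  categories = ['b', 'e', 't', 'm']  #Same order as the one-hot columns above
  return categories[torch.argmax(outputs, dim=-1).item()]

print(predict('Stocks rally as tech earnings beat expectations', model, tokenizer, max_len, device))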

Normally, hyperparameters such as whether to freeze the weights of each BERT layer and the learning rate would be tuned while checking the accuracy on the validation data. This time the parameters were fixed, but the accuracy was still relatively high, a result that shows the strength of pre-training.
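
As one example of such a setting, the BERT weights could be frozen so that only the added classification layer is trained. A minimal sketch, not used in the experiment above:

#Freeze all BERT parameters and train only the added fully connected layer (sketch)
for param in model.bert.parameters():
  param.requires_grad = False
optimizer = torch.optim.AdamW(params=[p for p in model.parameters() if p.requires_grad], lr=LEARNING_RATE)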

References

transformers BERT (official)
BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Devlin, J. et al. (2018) (original paper)
[Language Processing 100 Knock 2020] Summary of Answer Examples in Python
