[PYTHON] Document classification with torchtext from PyTorch

Introduction

In this article, I explain the implementation flow of document classification with torchtext, following the official tutorial. The Google Colaboratory notebook that accompanies the official tutorial raises an error, so I post the code with that part corrected. Finally, I walk through the source code of torchtext.datasets.text_classification.

Development environment

Google Colaboratory

Prior knowledge

Basic natural language processing terms, such as N-grams

Document classification flow

When classifying documents with torchtext, the implementation flow is as follows. The code is covered in the next section, so this section gives only an overview.

  1. pip install
  2. module import
  3. Storage of dataset, division into train and test
  4. Model definition
  5. Model instantiation and batch-generation function
  6. Function definition for train, test
  7. Run train, test

Code

Let's go through the flow above using the tutorial's code.

1. pip install

Somewhat anticlimactically, the official code itself causes an error. Specifically, the second line is the cause.

!pip install torch<=1.2.0
!pip install torchtext
%matplotlib inline

If you run it as is, importing the module (shown later) fails with the following error:

from torchtext.datasets import text_classification

ImportError: cannot import name 'text_classification'

The corrected code is below. Because it changes the torchtext version, Colab may prompt you to restart the runtime. In that case, click RESTART RUNTIME and run the cells again from top to bottom (on this second run you do not need to press RESTART RUNTIME after the pip install).

!pip install torch<=1.2.0
!pip install torchtext==0.5
%matplotlib inline

The cause is the torchtext version. At the time of writing, pip install without a version specifier installed 0.3.1. text_classification was added in 0.4, so it cannot be imported from 0.3. The code above pins 0.5, but 0.4 works as well. (Much later torchtext releases reorganized torchtext.datasets and no longer ship text_classification, so pinning the version is the safe choice.)
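If you are unsure which version actually got installed, checking it before the import avoids a confusing ImportError. The helper below is a hypothetical sketch: `parse_version` and `has_text_classification` are my own names, and the check encodes only the lower bound stated above (0.4); it knows nothing about later releases that dropped the module.

```python
import re


def parse_version(version: str) -> tuple:
    """Turn a version string like '0.5.0' or '0.4.0+cu101' into a tuple of ints."""
    # Keep only the leading dotted-number part (drops local tags such as '+cu101').
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def has_text_classification(version: str) -> bool:
    """Hypothetical check: text_classification needs torchtext 0.4 or later."""
    return parse_version(version) >= (0, 4)


# Usage in Colab (requires torchtext to be installed):
#   import torchtext
#   print(has_text_classification(torchtext.__version__))
```

Comparing tuples of integers avoids the classic string-comparison trap where "0.10" sorts before "0.4".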

2. module import

import os

import torch
import torchtext
from torchtext.datasets import text_classification

NGRAMS = 2

3. Storage of dataset, division into train and test

if not os.path.isdir('./.data'):
    os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
    root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

4. Model definition

The model is a simple EmbeddingBag → Linear pipeline. In init_weights, the embedding and linear weights are initialized with values drawn from a uniform distribution over [-0.5, 0.5], and the linear bias is set to zero.

import torch.nn as nn
import torch.nn.functional as F
class TextSentiment(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        
    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)
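The key piece is nn.EmbeddingBag. Instead of a padded 2-D batch, it takes all token IDs of the batch concatenated into one 1-D tensor, plus an offsets tensor marking where each document starts, and (with the default mode='mean') averages the embeddings within each document. A minimal check of that behaviour, using toy sizes chosen only for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes for illustration: vocabulary of 10 tokens, 4-dim embeddings.
bag = nn.EmbeddingBag(10, 4)  # mode defaults to 'mean'

# Two documents concatenated: doc 0 = [1, 2], doc 1 = [3, 4, 5].
text = torch.tensor([1, 2, 3, 4, 5])
offsets = torch.tensor([0, 2])  # start index of each document in `text`

out = bag(text, offsets)  # shape: (number of documents, embed_dim)

# Same result computed by hand: mean of each document's embedding rows.
manual = torch.stack([
    bag.weight[text[0:2]].mean(dim=0),
    bag.weight[text[2:5]].mean(dim=0),
])
print(out.shape)                    # torch.Size([2, 4])
print(torch.allclose(out, manual))  # True
```

This is why the model needs no padding, and why generate_batch returns offsets alongside the text.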

5. Model instantiation and batch-generation function

VOCAB_SIZE = len(train_dataset.get_vocab())
EMBED_DIM = 32
NUM_CLASS = len(train_dataset.get_labels())
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)

def generate_batch(batch):
    label = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
    offsets = [0] + [len(entry) for entry in text]
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text = torch.cat(text)
    return text, offsets, label
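To see what generate_batch produces, here it is applied to a hand-made batch of two (label, token-ID tensor) pairs, the same shape of entries the dataset returns. The labels and token IDs are made up for illustration.

```python
import torch


def generate_batch(batch):
    # Same function as above: labels, concatenated tokens, and start offsets.
    label = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
    offsets = [0] + [len(entry) for entry in text]
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text = torch.cat(text)
    return text, offsets, label


# Hypothetical batch: document 0 has 3 tokens, document 1 has 2 tokens.
batch = [
    (2, torch.tensor([11, 12, 13])),
    (0, torch.tensor([21, 22])),
]
text, offsets, label = generate_batch(batch)
print(text)     # tensor([11, 12, 13, 21, 22])
print(offsets)  # tensor([0, 3])
print(label)    # tensor([2, 0])
```

The document lengths [3, 2] are shifted one position right (prefixed with 0, last one dropped) and cumulatively summed, so each offset is the index in text where that document begins.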

6. Function definition for train, test

from torch.utils.data import DataLoader

def train_func(sub_train_):

    # Train the model
    train_loss = 0
    train_acc = 0
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    for i, (text, offsets, cls) in enumerate(data):
        optimizer.zero_grad()
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        output = model(text, offsets)
        loss = criterion(output, cls)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == cls).sum().item()

    # Adjust the learning rate
    scheduler.step()
    
    return train_loss / len(sub_train_), train_acc / len(sub_train_)

def test(data_):
    loss = 0
    acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    for text, offsets, cls in data:
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        with torch.no_grad():
            output = model(text, offsets)
            batch_loss = criterion(output, cls)
            loss += batch_loss.item()  # accumulate instead of overwriting the running total
            acc += (output.argmax(1) == cls).sum().item()

    return loss / len(data_), acc / len(data_)

7. Run train, test

If training works correctly, you should reach an accuracy of 0.9 or higher.

import time
from torch.utils.data.dataset import random_split
N_EPOCHS = 5
min_valid_loss = float('inf')

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = \
    random_split(train_dataset, [train_len, len(train_dataset) - train_len])

for epoch in range(N_EPOCHS):

    start_time = time.time()
    train_loss, train_acc = train_func(sub_train_)
    valid_loss, valid_acc = test(sub_valid_)

    secs = int(time.time() - start_time)
    mins = secs // 60
    secs = secs % 60

    print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs))
    print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
    print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')
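train_func calls scheduler.step() once per epoch, and StepLR(optimizer, 1, gamma=0.9) multiplies the learning rate by 0.9 at every step. Starting from lr=4.0, the rate used in epoch k (counting from 0) is therefore 4.0 × 0.9^k. A small check with a dummy parameter standing in for the real model:

```python
import torch

# Dummy parameter standing in for the model; only the schedule matters here.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

lrs = []
for epoch in range(5):  # N_EPOCHS = 5 in the tutorial
    lrs.append(optimizer.param_groups[0]["lr"])  # lr used during this epoch
    optimizer.step()   # no gradients here, so `param` is left unchanged
    scheduler.step()   # decay once per epoch, as train_func does

print([round(lr, 4) for lr in lrs])  # [4.0, 3.6, 3.24, 2.916, 2.6244]
```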

Commentary

torchtext.datasets.text_classification does exactly what its name says: it provides the data needed for text classification, and nothing else. In other words, the module's goal is to format a variety of datasets into the form required for training. This section therefore focuses on how the train and test datasets are produced. The source code quoted below comes from torchtext's text_classification.py. First, here is the dataset code from step 3 again.

3. Storage of dataset, division into train and test

if not os.path.isdir('./.data'):
    os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
    root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

This code creates a directory called .data and uses it as root to generate the train and test datasets. On its own, though, it leaves several questions open, such as what actually goes into .data. So let's read the source code and see what it does in detail.

Data provided by torchtext.datasets.text_classification

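Several datasets are provided for document classification. The ones currently available are AG_NEWS, SogouNews, DBpedia, YelpReviewPolarity, YelpReviewFull, YahooAnswers, AmazonReviewPolarity, and AmazonReviewFull.

If you want to download a dataset directly, use the URL listed for it in the URLS variable.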

URLS = {
    'AG_NEWS':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUDNpeUdjb0wxRms',
    'SogouNews':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUkVqNEszd0pHaFE',
    'DBpedia':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k',
    'YelpReviewPolarity':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbNUpYQ2N3SGlFaDg',
    'YelpReviewFull':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZlU4dXhHTFhZQU0',
    'YahooAnswers':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9Qhbd2JNdDBsQUdocVU',
    'AmazonReviewPolarity':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM',
    'AmazonReviewFull':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZVhsUnRWRDhETzA'
}

Now let's follow how the data is processed through the source code. The first step is the definition of a function for each dataset.

def AG_NEWS(*args, **kwargs):
    """ Defines AG_NEWS datasets.
        The labels includes:
            - 1 : World
            - 2 : Sports
            - 3 : Business
            - 4 : Sci/Tech

    Create supervised learning dataset: AG_NEWS

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        ngrams: a contiguous sequence of n items from s string text.
            Default: 1
        vocab: Vocabulary used for dataset. If None, it will generate a new
            vocabulary based on the train data set.
        include_unk: include unknown token in the data (Default: False)

    Examples:
        >>> train_dataset, test_dataset = torchtext.datasets.AG_NEWS(ngrams=3)

    """

    return _setup_datasets(*(("AG_NEWS",) + args), **kwargs)

The AG_NEWS function simply returns the formatted data produced by _setup_datasets, passing the dataset name as the first argument. From here on, only AG_NEWS is covered, but the other datasets go through the same processing. Next, the defined functions are registered in the DATASETS dict.

DATASETS = {
    'AG_NEWS': AG_NEWS,
    'SogouNews': SogouNews,
    'DBpedia': DBpedia,
    'YelpReviewPolarity': YelpReviewPolarity,
    'YelpReviewFull': YelpReviewFull,
    'YahooAnswers': YahooAnswers,
    'AmazonReviewPolarity': AmazonReviewPolarity,
    'AmazonReviewFull': AmazonReviewFull
}

The label names for each dataset are likewise stored in the LABELS dict.

LABELS = {
    'AG_NEWS': {1: 'World',
                2: 'Sports',
                3: 'Business',
                4: 'Sci/Tech'},
}

The labels for the datasets other than AG_NEWS are omitted here but stored in the same format. Because DATASETS maps each name to its function, the following two expressions refer to the same function:

text_classification.DATASETS['AG_NEWS']
text_classification.AG_NEWS
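The dispatch pattern itself is small: each dataset function prepends its own name to the positional arguments and forwards everything to _setup_datasets, and DATASETS maps names to those same function objects. The sketch below reproduces only that dispatch; `_setup_datasets_stub` is a hypothetical stand-in that echoes its arguments instead of downloading anything.

```python
def _setup_datasets_stub(dataset_name, root=".data", ngrams=1, vocab=None,
                         include_unk=False):
    # Hypothetical stand-in for _setup_datasets: just echo the arguments.
    return dataset_name, root, ngrams, vocab, include_unk


def AG_NEWS(*args, **kwargs):
    # Same forwarding as the real AG_NEWS: put the dataset name first.
    return _setup_datasets_stub(*(("AG_NEWS",) + args), **kwargs)


DATASETS = {"AG_NEWS": AG_NEWS}

# Looking the function up by name gives back the very same object.
print(DATASETS["AG_NEWS"] is AG_NEWS)  # True
print(DATASETS["AG_NEWS"](root="./.data", ngrams=2))
# ('AG_NEWS', './.data', 2, None, False)
```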

Now let's look at the _setup_datasets function to see how the data is processed.

def _setup_datasets(dataset_name, root='.data', ngrams=1, vocab=None, include_unk=False):
    dataset_tar = download_from_url(URLS[dataset_name], root=root)
    extracted_files = extract_archive(dataset_tar)

    for fname in extracted_files:
        if fname.endswith('train.csv'):
            train_csv_path = fname
        if fname.endswith('test.csv'):
            test_csv_path = fname

    if vocab is None:
        logging.info('Building Vocab based on {}'.format(train_csv_path))
        vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
    else:
        if not isinstance(vocab, Vocab):
            raise TypeError("Passed vocabulary is not of type Vocab")
    logging.info('Vocab has {} entries'.format(len(vocab)))
    logging.info('Creating training data')
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk)
    logging.info('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))

The main processing steps are as follows.

  1. Download the dataset archive into the root directory with download_from_url, then unpack it with extract_archive to locate train.csv and test.csv.
  2. If no vocab is passed, build the vocabulary (the N-gram tokens appearing in the data) from the training CSV with build_vocab_from_iterator.
  3. Convert the train and test CSV rows into (label, token-ID tensor) pairs with _create_data_from_iterator, which also collects the set of labels. A ValueError is raised if the train and test label sets differ.
  4. Pass the vocabulary, the train (test) data, and the train (test) labels to the TextClassificationDataset class, and return the two instances together.

download_from_url downloads the file from the Google Drive URL registered in URLS. Finally, let's look at the TextClassificationDataset class.

class TextClassificationDataset(torch.utils.data.Dataset):
    """Defines an abstract text classification datasets.
       Currently, we only support the following datasets:

             - AG_NEWS
             - SogouNews
             - DBpedia
             - YelpReviewPolarity
             - YelpReviewFull
             - YahooAnswers
             - AmazonReviewPolarity
             - AmazonReviewFull

    """

    def __init__(self, vocab, data, labels):
        """Initiate text-classification dataset.

        Arguments:
            vocab: Vocabulary object used for dataset.
            data: a list of label/tokens tuple. tokens are a tensor after
                numericalizing the string tokens. label is an integer.
                [(label1, tokens1), (label2, tokens2), (label2, tokens3)]
            label: a set of the labels.
                {label1, label2}

        Examples:
            See the examples in examples/text_classification/

        """

        super(TextClassificationDataset, self).__init__()
        self._data = data
        self._labels = labels
        self._vocab = vocab


    def __getitem__(self, i):
        return self._data[i]

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        for x in self._data:
            yield x

    def get_labels(self):
        return self._labels

    def get_vocab(self):
        return self._vocab

In other words, this class only retrieves the stored data; it does not process anything new. As _setup_datasets and TextClassificationDataset show, the dataset is stored not as raw documents but after conversion to N-gram token IDs. Therefore, if you want a representation other than N-grams, you need to write that processing yourself, starting from the files saved in .data or downloaded from the URLs in URLS.
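To make the N-gram conversion concrete, here is a simplified, pure-Python sketch of what happens to each row: tokenize, extend the tokens with N-grams, build a vocabulary from the training rows, look each N-gram up, and keep the label. This is not torchtext's code. It lowercases and splits on whitespace instead of using torchtext's tokenizer, uses a plain dict instead of Vocab, returns lists instead of tensors, and drops unknown tokens (which is what the include_unk=False default suggests). All function names are hypothetical.

```python
from collections import Counter


def ngrams_sketch(tokens, n):
    """Yield the unigrams, then all 2..n-grams joined with spaces."""
    yield from tokens
    for size in range(2, n + 1):
        for i in range(len(tokens) - size + 1):
            yield " ".join(tokens[i:i + size])


def tokenize(text):
    # Simplification: lowercase + whitespace split.
    return text.lower().split()


def build_vocab(texts, n):
    """Map each N-gram seen in training to an ID (most frequent first)."""
    counts = Counter(g for t in texts for g in ngrams_sketch(tokenize(t), n))
    ordered = sorted(counts, key=lambda g: (-counts[g], g))
    return {gram: idx for idx, gram in enumerate(ordered)}


def numericalize(rows, vocab, n):
    """Turn (label, text) rows into (label, [IDs]) pairs plus the label set."""
    data, labels = [], set()
    for label, text in rows:
        # Unknown N-grams are dropped, like include_unk=False.
        ids = [vocab[g] for g in ngrams_sketch(tokenize(text), n) if g in vocab]
        data.append((label, ids))
        labels.add(label)
    return data, labels


train_rows = [(0, "Stocks rise"), (1, "Team wins"), (0, "Stocks fall")]
vocab = build_vocab([text for _, text in train_rows], n=2)
data, labels = numericalize(train_rows, vocab, n=2)
print(list(ngrams_sketch(["here", "we", "are"], 2)))
# ['here', 'we', 'are', 'here we', 'we are']
```

With NGRAMS = 2 as in the tutorial, each document therefore feeds both its words and its adjacent word pairs into the EmbeddingBag.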

Conclusion

Tracing the source code reveals things that are hard to understand just by printing objects. I plan to keep reading source code and writing up what I find.
