[PYTHON] Classification des documents avec texte toch de PyTorch

introduction

Le flux d'implémentation pour classer les documents à l'aide de torchtext est expliqué avec Official tutorial. De plus, le Google Colabolatry qui accompagne le tutoriel officiel donne une erreur. Je posterai le code après avoir corrigé la partie qui est. Enfin, j'expliquerai le code source de torchtext.datasets.text_classification.

Environnement de développement

Google Colabolatry

Connaissance préalable

Termes de base pour le traitement du langage naturel tels que N-gramme

Flux de classification des documents

Lors de la classification de documents à l'aide de torchtext, le flux de mise en œuvre est le suivant. Nous examinerons le code dans la section suivante, donc cette théorie ne donne qu'un aperçu.

  1. pip install
  2. module d'importation
  3. Stockage du jeu de données, division en train et test
  4. Définition du modèle
  5. Conversion d'instance de modèle, définition de fonction pour la génération par lots
  6. Définition de la fonction pour le train, test
  7. Conduisez le train, testez

code

Vérifions le flux ci-dessus avec le code dans le didacticiel.

  1. pip install C'est presque hors de question, mais officiellement ce code est à l'origine d'une erreur. Plus précisément, la deuxième ligne est la cause.
!pip install torch<=1.2.0
!pip install torchtext
%matplotlib inline

Si vous l'exécutez tel quel, l'erreur suivante se produira lors de l'importation du module décrit plus loin.

from torchtext.datasets import text_classification

ImportError: cannot import name 'text_classification'

Le code correct ressemble à ceci: En outre, il peut être nécessaire d'initialiser le runtime en raison de changements dans la version de torchtext. Dans ce cas, exécutez simplement le redémarrage et réexécutez les cellules de haut en bas (vous n'avez pas besoin d'appuyer sur le redémarrage après la deuxième installation de pip).

!pip install torch<=1.2.0
!pip install torchtext==0.5
%matplotlib inline

La cause est la version de torchtext. Si vous installez pip sans rien spécifier, la version 0.3.1 sera installée. Étant donné que text_classification est implémentée dans la version 0.4 ou ultérieure, elle ne peut pas être utilisée telle quelle dans la version 0.3. Dans ce qui précède, il est fixé à 0,5, mais s'il est égal ou supérieur à 0,4, il n'y a pas de problème.

2. module d'importation

import torch
import torchtext
from torchtext.datasets import text_classification
NGRAMS = 2
import os

3. Stockage du jeu de données, division en train et test

if not os.path.isdir('./.data'):
	os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
    root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

4. Définition du modèle

Il a un flux simple d'intégration → linéaire. Dans init_weight, les poids sont initialisés avec les poids générés à partir de la distribution uniforme.

import torch.nn as nn
import torch.nn.functional as F
class TextSentiment(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        
    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

5. Conversion d'instance de modèle, définition de fonction pour la génération par lots

VOCAB_SIZE = len(train_dataset.get_vocab())
EMBED_DIM = 32
NUN_CLASS = len(train_dataset.get_labels())
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)

def generate_batch(batch):
    label = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
    offsets = [0] + [len(entry) for entry in text]
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text = torch.cat(text)
    return text, offsets, label

6. Définition de la fonction pour le train, test

from torch.utils.data import DataLoader

def train_func(sub_train_):

    # Train the model
    train_loss = 0
    train_acc = 0
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    for i, (text, offsets, cls) in enumerate(data):
        optimizer.zero_grad()
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        output = model(text, offsets)
        loss = criterion(output, cls)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == cls).sum().item()

    # Adjust the learning rate
    scheduler.step()
    
    return train_loss / len(sub_train_), train_acc / len(sub_train_)

def test(data_):
    loss = 0
    acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    for text, offsets, cls in data:
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        with torch.no_grad():
            output = model(text, offsets)
            loss = criterion(output, cls)
            loss += loss.item()
            acc += (output.argmax(1) == cls).sum().item()

    return loss / len(data_), acc / len(data_)

7. Conduisez le train, testez

Si vous apprenez correctement, vous pouvez obtenir une précision de 0,9 ou plus.

import time
from torch.utils.data.dataset import random_split
N_EPOCHS = 5
min_valid_loss = float('inf')

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = \
    random_split(train_dataset, [train_len, len(train_dataset) - train_len])

for epoch in range(N_EPOCHS):

    start_time = time.time()
    train_loss, train_acc = train_func(sub_train_)
    valid_loss, valid_acc = test(sub_valid_)

    secs = int(time.time() - start_time)
    mins = secs / 60
    secs = secs % 60

    print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs))
    print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
    print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')

Commentaire

Dans TORCHTEXT.DATASETS.TEXT_CLASSIFICATION, le traitement est effectué pour fournir littéralement les données nécessaires. Au contraire, aucune autre opération n'est effectuée. En d'autres termes, l'objectif de ce module est de formater les données nécessaires à la formation pour divers ensembles de données. Par conséquent, cette fois, nous nous concentrerons sur le flux de fourniture d'ensembles de données de train et de test. Le code source décrit dans la description suivante est ici. Tout d'abord, je republierai le code suivant.

3. Stockage du jeu de données, division en train et test

if not os.path.isdir('./.data'):
	os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
    root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Ici, vous pouvez voir qu'un répertoire appelé .data est créé et que ce répertoire est utilisé comme racine pour générer des ensembles de données de train et de test. Cependant, cela seul a divers points peu clairs, y compris les données. Alors, lisons le code et voyons un traitement plus spécifique.

Données fournies par TORCHTEXT.DATASETS.TEXT_CLASSIFICATION

Certaines données sont fournies pour la classification des documents. Les données actuellement fournies sont les suivantes.

Si vous souhaitez obtenir directement chaque donnée, vous pouvez la télécharger à partir de l'url décrite dans la variable URLS.

URLS = {
    'AG_NEWS':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUDNpeUdjb0wxRms',
    'SogouNews':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUkVqNEszd0pHaFE',
    'DBpedia':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k',
    'YelpReviewPolarity':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbNUpYQ2N3SGlFaDg',
    'YelpReviewFull':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZlU4dXhHTFhZQU0',
    'YahooAnswers':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9Qhbd2JNdDBsQUdocVU',
    'AmazonReviewPolarity':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM',
    'AmazonReviewFull':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZVhsUnRWRDhETzA'
}

Maintenant, suivons réellement le traitement des données à travers le code source. La première chose à faire est la définition de la fonction.

def AG_NEWS(*args, **kwargs):
    """ Defines AG_NEWS datasets.
        The labels includes:
            - 1 : World
            - 2 : Sports
            - 3 : Business
            - 4 : Sci/Tech

    Create supervised learning dataset: AG_NEWS

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        ngrams: a contiguous sequence of n items from s string text.
            Default: 1
        vocab: Vocabulary used for dataset. If None, it will generate a new
            vocabulary based on the train data set.
        include_unk: include unknown token in the data (Default: False)

    Examples:
        >>> train_dataset, test_dataset = torchtext.datasets.AG_NEWS(ngrams=3)

    """

    return _setup_datasets(*(("AG_NEWS",) + args), **kwargs)

Vous pouvez voir que les données formatées sont renvoyées à l'aide de la fonction _setup_datasets. Désormais, seul AG_NEWS est ciblé, mais le même traitement est effectué pour les autres ensembles de données. Ensuite, enregistrez la fonction définie dans la variable DATASETS au format dict.

DATASETS = {
    'AG_NEWS': AG_NEWS,
    'SogouNews': SogouNews,
    'DBpedia': DBpedia,
    'YelpReviewPolarity': YelpReviewPolarity,
    'YelpReviewFull': YelpReviewFull,
    'YahooAnswers': YahooAnswers,
    'AmazonReviewPolarity': AmazonReviewPolarity,
    'AmazonReviewFull': AmazonReviewFull
}

En outre, la variable LABELS stocke les informations d'étiquette pour chaque ensemble de données au format dict.

LABELS = {
    'AG_NEWS': {1: 'World',
                2: 'Sports',
                3: 'Business',
                4: 'Sci/Tech'},
}

Bien que omis ici, les étiquettes autres que AG_NEWS sont stockées dans le même format. Puisque la fonction est enregistrée au format dict avec la variable DATASETS ci-dessus, les deux suivantes font référence à la même chose.

text_classification.DATASETS['AG_NEWS']
text_classification.AG_NEWS

Vérifiez le traitement des données en regardant la fonction _setup_datasets.

def _setup_datasets(dataset_name, root='.data', ngrams=1, vocab=None, include_unk=False):
    dataset_tar = download_from_url(URLS[dataset_name], root=root)
    extracted_files = extract_archive(dataset_tar)

    for fname in extracted_files:
        if fname.endswith('train.csv'):
            train_csv_path = fname
        if fname.endswith('test.csv'):
            test_csv_path = fname

    if vocab is None:
        logging.info('Building Vocab based on {}'.format(train_csv_path))
        vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
    else:
        if not isinstance(vocab, Vocab):
            raise TypeError("Passed vocabulary is not of type Vocab")
    logging.info('Vocab has {} entries'.format(len(vocab)))
    logging.info('Creating training data')
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk)
    logging.info('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))

Le traitement principal est le suivant.

  1. Enregistrez les données du document dans le répertoire spécifié par la fonction download_from_url.
  2. Créez les données de mot utilisées dans les données par la fonction build_vocab_from_iterator
  3. Créez des données csv pour l'apprentissage et le test à partir des données du document par _create_data_from_iterator.
  4. Passez les données de mot, les données d'entraînement (test) et l'étiquette d'entraînement (test) à la classe TextClassificationDataset, instanciez-les et renvoyez-les toutes ensemble. La fonction download_from_url est une fonction qui télécharge le fichier défini pour Google Drive. Enfin, jetons un coup d'œil à la classe TextClassificationDataset.
class TextClassificationDataset(torch.utils.data.Dataset):
    """Defines an abstract text classification datasets.
       Currently, we only support the following datasets:

             - AG_NEWS
             - SogouNews
             - DBpedia
             - YelpReviewPolarity
             - YelpReviewFull
             - YahooAnswers
             - AmazonReviewPolarity
             - AmazonReviewFull

    """

[docs]    def __init__(self, vocab, data, labels):
        """Initiate text-classification dataset.

        Arguments:
            vocab: Vocabulary object used for dataset.
            data: a list of label/tokens tuple. tokens are a tensor after
                numericalizing the string tokens. label is an integer.
                [(label1, tokens1), (label2, tokens2), (label2, tokens3)]
            label: a set of the labels.
                {label1, label2}

        Examples:
            See the examples in examples/text_classification/

        """

        super(TextClassificationDataset, self).__init__()
        self._data = data
        self._labels = labels
        self._vocab = vocab


    def __getitem__(self, i):
        return self._data[i]

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        for x in self._data:
            yield x

    def get_labels(self):
        return self._labels

    def get_vocab(self):
        return self._vocab

Vous pouvez voir qu'il s'agit d'une classe pour récupérer chaque donnée, pas pour traiter de nouvelles données. Comme vous pouvez le voir à partir de la fonction _setup_datasets et de la classe TextClassificationDataset, l'ensemble de données est converti en N-gramme et en état stocké plutôt que dans le document brut. Par conséquent, si vous souhaitez utiliser un format de données autre que N-gramme, vous devez écrire votre propre traitement en fonction des données enregistrées dans .data ou des données téléchargées à partir de l'url décrite dans URLS.

À la fin

Les informations difficiles à comprendre par simple impression peuvent être comprises en traçant le code source. Je souhaite continuer à lire le code source et à compiler des informations.

Recommended Posts

Classification des documents avec texte toch de PyTorch
Classification des documents avec une phrase
Classification de texte non supervisée avec Doc2Vec et k-means
[PyTorch] Introduction à la classification de documents à l'aide de BERT
Génération automatique de documents à partir de docstring avec sphinx
Extraire du texte japonais d'un PDF avec PDFMiner
Jouez avec PyTorch
[PyTorch] Introduction à la classification des documents japonais à l'aide de BERT
Classification des textes du défi par Naive Bayes avec sklearn
Validation croisée avec PyTorch
À partir de PyTorch
Texte extrait de l'image
Text mining avec Python-Scraping-
Installer la diffusion de la torche avec PyTorch 1.7
Pythonbrew avec Sublime Text
Classification d'images avec un réseau de neurones auto-fabriqué par Keras et PyTorch