[PYTHON] Dokumentklassifizierung mit toch Text von PyTorch

Einführung

Der Implementierungsablauf für die Klassifizierung von Dokumenten mithilfe von torchtext wird zusammen mit Offizielles Tutorial erläutert. Auch die Google Colabolatry, die dem offiziellen Tutorial beiliegt, gibt einen Fehler aus. Ich werde den Code veröffentlichen, nachdem ich den Teil korrigiert habe. Abschließend werde ich den Quellcode von torchtext.datasets.text_classification erläutern.

Entwicklungsumgebung

Google Colabolatry

Vorherige Kenntniss

Grundbegriffe für die Verarbeitung natürlicher Sprache wie N-Gramm

Dokumentklassifizierungsablauf

Bei der Klassifizierung von Dokumenten mit torchtext ist der Implementierungsablauf wie folgt. Wir werden uns den Code im nächsten Abschnitt ansehen, daher gibt diese Theorie nur einen Überblick.

  1. pip install
  2. Modul importieren
  3. Speicherung des Datensatzes, Aufteilung in Zug und Test
  4. Modelldefinition
  5. Modellinstanzkonvertierung, Funktionsdefinition für die Stapelgenerierung
  6. Funktionsdefinition für Zug, Test
  7. Zug fahren, testen

Code

Lassen Sie uns den obigen Ablauf mit dem Code im Tutorial überprüfen.

  1. pip install Es kommt fast nicht in Frage, aber offiziell verursacht dieser Code einen Fehler. Insbesondere ist die zweite Zeile die Ursache.
!pip install torch<=1.2.0
!pip install torchtext
%matplotlib inline

Wenn Sie es so ausführen, wie es ist, tritt beim Importieren des später beschriebenen Moduls der folgende Fehler auf.

from torchtext.datasets import text_classification

ImportError: cannot import name 'text_classification'

Der richtige Code sieht folgendermaßen aus: Außerdem muss die Laufzeit möglicherweise aufgrund von Änderungen in der torchtext-Version initialisiert werden. In diesem Fall führen Sie einfach die Neustart-Laufzeit aus und führen Sie die Zellen erneut von oben nach unten aus (Sie müssen die Neustart-Laufzeit nach der zweiten Pip-Installation nicht drücken).

!pip install torch<=1.2.0
!pip install torchtext==0.5
%matplotlib inline

Die Ursache ist die Version von torchtext. Wenn Sie eine Pip-Installation durchführen, ohne etwas anzugeben, wird 0.3.1 installiert. Da text_classification in 0.4 oder höher implementiert ist, kann es nicht wie in 0.3 verwendet werden. Oben ist es auf 0,5 festgelegt, aber wenn es 0,4 oder höher ist, gibt es kein Problem.

2. Modul importieren

import torch
import torchtext
from torchtext.datasets import text_classification
NGRAMS = 2
import os

3. Speicherung des Datensatzes, Aufteilung in Zug und Test

if not os.path.isdir('./.data'):
	os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
    root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

4. Modelldefinition

Es hat einen einfachen Fluss von Einbettung → linear. In init_weight werden die Gewichte mit den aus der Gleichverteilung generierten Gewichten initialisiert.

import torch.nn as nn
import torch.nn.functional as F
class TextSentiment(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        
    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

5. Modellinstanzkonvertierung, Funktionsdefinition für die Stapelgenerierung

VOCAB_SIZE = len(train_dataset.get_vocab())
EMBED_DIM = 32
NUN_CLASS = len(train_dataset.get_labels())
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)

def generate_batch(batch):
    label = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
    offsets = [0] + [len(entry) for entry in text]
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text = torch.cat(text)
    return text, offsets, label

6. Funktionsdefinition für Zug, Test

from torch.utils.data import DataLoader

def train_func(sub_train_):

    # Train the model
    train_loss = 0
    train_acc = 0
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    for i, (text, offsets, cls) in enumerate(data):
        optimizer.zero_grad()
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        output = model(text, offsets)
        loss = criterion(output, cls)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == cls).sum().item()

    # Adjust the learning rate
    scheduler.step()
    
    return train_loss / len(sub_train_), train_acc / len(sub_train_)

def test(data_):
    loss = 0
    acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    for text, offsets, cls in data:
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        with torch.no_grad():
            output = model(text, offsets)
            loss = criterion(output, cls)
            loss += loss.item()
            acc += (output.argmax(1) == cls).sum().item()

    return loss / len(data_), acc / len(data_)

7. Zug fahren, testen

Wenn Sie richtig lernen, können Sie eine Genauigkeit von 0,9 oder höher erreichen.

import time
from torch.utils.data.dataset import random_split
N_EPOCHS = 5
min_valid_loss = float('inf')

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = \
    random_split(train_dataset, [train_len, len(train_dataset) - train_len])

for epoch in range(N_EPOCHS):

    start_time = time.time()
    train_loss, train_acc = train_func(sub_train_)
    valid_loss, valid_acc = test(sub_valid_)

    secs = int(time.time() - start_time)
    mins = secs / 60
    secs = secs % 60

    print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs))
    print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
    print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')

Kommentar

In TORCHTEXT.DATASETS.TEXT_CLASSIFICATION wird die Verarbeitung durchgeführt, um die erforderlichen Daten buchstäblich bereitzustellen. Im Gegenteil, es wird keine andere Operation ausgeführt. Mit anderen Worten, das Ziel dieses Moduls ist es, die Daten zu formatieren, die für das Training für verschiedene Datensätze erforderlich sind. Daher konzentrieren wir uns dieses Mal auf den Ablauf der Bereitstellung von Zug- und Testdatensätzen. Der in der folgenden Diskussion bereitgestellte Quellcode lautet hier. Zuerst werde ich den folgenden Code erneut veröffentlichen.

3. Speicherung des Datensatzes, Aufteilung in Zug und Test

if not os.path.isdir('./.data'):
	os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
    root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Hier sehen Sie, dass ein Verzeichnis mit dem Namen .data erstellt wird und dieses Verzeichnis als Root zum Generieren von Zug- und Testdatensätzen verwendet wird. Dies allein hat jedoch verschiedene unklare Punkte, einschließlich .data. Lesen wir also den Code und sehen uns eine genauere Verarbeitung an.

Daten bereitgestellt von TORCHTEXT.DATASETS.TEXT_CLASSIFICATION

Einige Daten werden für die Klassifizierung von Dokumenten bereitgestellt. Die derzeit bereitgestellten Daten lauten wie folgt.

Wenn Sie alle Daten direkt abrufen möchten, können Sie sie von der in der URLS-Variablen beschriebenen URL herunterladen.

URLS = {
    'AG_NEWS':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUDNpeUdjb0wxRms',
    'SogouNews':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUkVqNEszd0pHaFE',
    'DBpedia':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k',
    'YelpReviewPolarity':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbNUpYQ2N3SGlFaDg',
    'YelpReviewFull':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZlU4dXhHTFhZQU0',
    'YahooAnswers':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9Qhbd2JNdDBsQUdocVU',
    'AmazonReviewPolarity':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM',
    'AmazonReviewFull':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZVhsUnRWRDhETzA'
}

Lassen Sie uns nun die Verarbeitung von Daten über den Quellcode verfolgen. Das erste, was getan wird, ist die Definition der Funktion.

def AG_NEWS(*args, **kwargs):
    """ Defines AG_NEWS datasets.
        The labels includes:
            - 1 : World
            - 2 : Sports
            - 3 : Business
            - 4 : Sci/Tech

    Create supervised learning dataset: AG_NEWS

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        ngrams: a contiguous sequence of n items from s string text.
            Default: 1
        vocab: Vocabulary used for dataset. If None, it will generate a new
            vocabulary based on the train data set.
        include_unk: include unknown token in the data (Default: False)

    Examples:
        >>> train_dataset, test_dataset = torchtext.datasets.AG_NEWS(ngrams=3)

    """

    return _setup_datasets(*(("AG_NEWS",) + args), **kwargs)

Sie können sehen, dass die formatierten Daten mit der Funktion _setup_datasets zurückgegeben werden. Von nun an wird nur noch AG_NEWS als Ziel ausgewählt, aber die gleiche Verarbeitung wird für andere Datensätze durchgeführt. Registrieren Sie als Nächstes die definierte Funktion in der Variablen DATASETS im Diktatformat.

DATASETS = {
    'AG_NEWS': AG_NEWS,
    'SogouNews': SogouNews,
    'DBpedia': DBpedia,
    'YelpReviewPolarity': YelpReviewPolarity,
    'YelpReviewFull': YelpReviewFull,
    'YahooAnswers': YahooAnswers,
    'AmazonReviewPolarity': AmazonReviewPolarity,
    'AmazonReviewFull': AmazonReviewFull
}

Darüber hinaus speichert die Variable LABELS die Beschriftungsinformationen für jeden Datensatz im Diktatformat.

LABELS = {
    'AG_NEWS': {1: 'World',
                2: 'Sports',
                3: 'Business',
                4: 'Sci/Tech'},
}

Obwohl hier weggelassen, werden andere Bezeichnungen als AG_NEWS im selben Format gespeichert. Da die Funktion im Diktatformat mit der obigen Variablen DATASETS registriert ist, beziehen sich die folgenden beiden auf dasselbe.

text_classification.DATASETS['AG_NEWS']
text_classification.AG_NEWS

Überprüfen Sie die Verarbeitung der Daten anhand der Funktion _setup_datasets.

def _setup_datasets(dataset_name, root='.data', ngrams=1, vocab=None, include_unk=False):
    dataset_tar = download_from_url(URLS[dataset_name], root=root)
    extracted_files = extract_archive(dataset_tar)

    for fname in extracted_files:
        if fname.endswith('train.csv'):
            train_csv_path = fname
        if fname.endswith('test.csv'):
            test_csv_path = fname

    if vocab is None:
        logging.info('Building Vocab based on {}'.format(train_csv_path))
        vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
    else:
        if not isinstance(vocab, Vocab):
            raise TypeError("Passed vocabulary is not of type Vocab")
    logging.info('Vocab has {} entries'.format(len(vocab)))
    logging.info('Creating training data')
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk)
    logging.info('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))

Die Hauptverarbeitung ist wie folgt.

  1. Speichern Sie die Dokumentdaten in dem von der Funktion download_from_url angegebenen Verzeichnis.
  2. Erstellen Sie die Wortdaten, die von der Funktion build_vocab_from_iterator in den Daten verwendet werden
  3. Erstellen Sie CSV-Daten für Zug und Test aus Dokumentdaten mit _create_data_from_iterator.
  4. Übergeben Sie Wortdaten, Zugdaten (Testdaten) und Zugetiketten (Testetiketten) an die TextClassificationDataset-Klasse, instanziieren Sie sie und geben Sie sie alle zusammen zurück. Die Funktion download_from_url ist eine Funktion, die die für Google Drive definierte Datei herunterlädt. Schauen wir uns zum Schluss die TextClassificationDataset-Klasse an.
class TextClassificationDataset(torch.utils.data.Dataset):
    """Defines an abstract text classification datasets.
       Currently, we only support the following datasets:

             - AG_NEWS
             - SogouNews
             - DBpedia
             - YelpReviewPolarity
             - YelpReviewFull
             - YahooAnswers
             - AmazonReviewPolarity
             - AmazonReviewFull

    """

[docs]    def __init__(self, vocab, data, labels):
        """Initiate text-classification dataset.

        Arguments:
            vocab: Vocabulary object used for dataset.
            data: a list of label/tokens tuple. tokens are a tensor after
                numericalizing the string tokens. label is an integer.
                [(label1, tokens1), (label2, tokens2), (label2, tokens3)]
            label: a set of the labels.
                {label1, label2}

        Examples:
            See the examples in examples/text_classification/

        """

        super(TextClassificationDataset, self).__init__()
        self._data = data
        self._labels = labels
        self._vocab = vocab


    def __getitem__(self, i):
        return self._data[i]

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        for x in self._data:
            yield x

    def get_labels(self):
        return self._labels

    def get_vocab(self):
        return self._vocab

Sie können sehen, dass es sich um eine Klasse zum Abrufen der einzelnen Daten handelt, nicht zum Verarbeiten neuer Daten. Wie Sie der Funktion _setup_datasets und der Klasse TextClassificationDataset entnehmen können, wird das Dataset in das N-Gramm und den gespeicherten Status anstatt in das Rohdokument konvertiert. Wenn Sie ein anderes Datenformat als N-Gramm verwenden möchten, müssen Sie daher Ihre eigene Verarbeitung basierend auf den in .data gespeicherten Daten oder den Daten schreiben, die von der in URLS beschriebenen URL heruntergeladen wurden.

Am Ende

Informationen, die nur durch Drucken schwer zu verstehen sind, können durch Nachverfolgen des Quellcodes verstanden werden. Ich möchte den Quellcode weiter lesen und Informationen zusammenstellen.

Recommended Posts

Dokumentklassifizierung mit toch Text von PyTorch
Dokumentenklassifizierung mit Satzstück
Unüberwachte Textklassifizierung mit Doc2Vec und k-means
[PyTorch] Einführung in die Dokumentklassifizierung mit BERT
Automatische Dokumentenerstellung aus Docstring mit Sphinx
Extrahieren Sie japanischen Text aus PDF mit PDFMiner
Spiele mit PyTorch
[PyTorch] Einführung in die Klassifizierung japanischer Dokumente mit BERT
Fordern Sie die Textklassifizierung von Naive Bayes mit sklearn heraus
Kreuzvalidierung mit PyTorch
Beginnend mit PyTorch
Extrahierter Text aus dem Bild
Text Mining mit Python-Scraping-
Installieren Sie Fackelstreuung mit PyTorch 1.7
Pythonbrew mit erhabenem Text
Bildklassifizierung mit selbst erstelltem neuronalen Netzwerk von Keras und PyTorch