Der Implementierungsablauf für die Klassifizierung von Dokumenten mithilfe von torchtext wird zusammen mit Offizielles Tutorial erläutert. Auch die Google Colabolatry, die dem offiziellen Tutorial beiliegt, gibt einen Fehler aus. Ich werde den Code veröffentlichen, nachdem ich den Teil korrigiert habe. Abschließend werde ich den Quellcode von torchtext.datasets.text_classification erläutern.
Google Colabolatry
Grundbegriffe für die Verarbeitung natürlicher Sprache wie N-Gramm
Bei der Klassifizierung von Dokumenten mit torchtext ist der Implementierungsablauf wie folgt. Wir werden uns den Code im nächsten Abschnitt ansehen, daher gibt diese Theorie nur einen Überblick.
Lassen Sie uns den obigen Ablauf mit dem Code im Tutorial überprüfen.
!pip install torch<=1.2.0
!pip install torchtext
%matplotlib inline
Wenn Sie es so ausführen, wie es ist, tritt beim Importieren des später beschriebenen Moduls der folgende Fehler auf.
from torchtext.datasets import text_classification
ImportError: cannot import name 'text_classification'
Der richtige Code sieht folgendermaßen aus: Außerdem muss die Laufzeit möglicherweise aufgrund von Änderungen in der torchtext-Version initialisiert werden. In diesem Fall führen Sie einfach die Neustart-Laufzeit aus und führen Sie die Zellen erneut von oben nach unten aus (Sie müssen die Neustart-Laufzeit nach der zweiten Pip-Installation nicht drücken).
!pip install torch<=1.2.0
!pip install torchtext==0.5
%matplotlib inline
Die Ursache ist die Version von torchtext. Wenn Sie eine Pip-Installation durchführen, ohne etwas anzugeben, wird 0.3.1 installiert. Da text_classification in 0.4 oder höher implementiert ist, kann es nicht wie in 0.3 verwendet werden. Oben ist es auf 0,5 festgelegt, aber wenn es 0,4 oder höher ist, gibt es kein Problem.
import torch
import torchtext
from torchtext.datasets import text_classification
NGRAMS = 2
import os
if not os.path.isdir('./.data'):
os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Es hat einen einfachen Fluss von Einbettung → linear. In init_weight werden die Gewichte mit den aus der Gleichverteilung generierten Gewichten initialisiert.
import torch.nn as nn
import torch.nn.functional as F
class TextSentiment(nn.Module):
def __init__(self, vocab_size, embed_dim, num_class):
super().__init__()
self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
self.fc = nn.Linear(embed_dim, num_class)
self.init_weights()
def init_weights(self):
initrange = 0.5
self.embedding.weight.data.uniform_(-initrange, initrange)
self.fc.weight.data.uniform_(-initrange, initrange)
self.fc.bias.data.zero_()
def forward(self, text, offsets):
embedded = self.embedding(text, offsets)
return self.fc(embedded)
VOCAB_SIZE = len(train_dataset.get_vocab())
EMBED_DIM = 32
NUN_CLASS = len(train_dataset.get_labels())
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)
def generate_batch(batch):
label = torch.tensor([entry[0] for entry in batch])
text = [entry[1] for entry in batch]
offsets = [0] + [len(entry) for entry in text]
offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
text = torch.cat(text)
return text, offsets, label
from torch.utils.data import DataLoader
def train_func(sub_train_):
# Train the model
train_loss = 0
train_acc = 0
data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
collate_fn=generate_batch)
for i, (text, offsets, cls) in enumerate(data):
optimizer.zero_grad()
text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
output = model(text, offsets)
loss = criterion(output, cls)
train_loss += loss.item()
loss.backward()
optimizer.step()
train_acc += (output.argmax(1) == cls).sum().item()
# Adjust the learning rate
scheduler.step()
return train_loss / len(sub_train_), train_acc / len(sub_train_)
def test(data_):
loss = 0
acc = 0
data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
for text, offsets, cls in data:
text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
with torch.no_grad():
output = model(text, offsets)
loss = criterion(output, cls)
loss += loss.item()
acc += (output.argmax(1) == cls).sum().item()
return loss / len(data_), acc / len(data_)
Wenn Sie richtig lernen, können Sie eine Genauigkeit von 0,9 oder höher erreichen.
import time
from torch.utils.data.dataset import random_split
N_EPOCHS = 5
min_valid_loss = float('inf')
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)
train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = \
random_split(train_dataset, [train_len, len(train_dataset) - train_len])
for epoch in range(N_EPOCHS):
start_time = time.time()
train_loss, train_acc = train_func(sub_train_)
valid_loss, valid_acc = test(sub_valid_)
secs = int(time.time() - start_time)
mins = secs / 60
secs = secs % 60
print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs))
print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')
In TORCHTEXT.DATASETS.TEXT_CLASSIFICATION wird die Verarbeitung durchgeführt, um die erforderlichen Daten buchstäblich bereitzustellen. Im Gegenteil, es wird keine andere Operation ausgeführt. Mit anderen Worten, das Ziel dieses Moduls ist es, die Daten zu formatieren, die für das Training für verschiedene Datensätze erforderlich sind. Daher konzentrieren wir uns dieses Mal auf den Ablauf der Bereitstellung von Zug- und Testdatensätzen. Der in der folgenden Diskussion bereitgestellte Quellcode lautet hier. Zuerst werde ich den folgenden Code erneut veröffentlichen.
if not os.path.isdir('./.data'):
os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Hier sehen Sie, dass ein Verzeichnis mit dem Namen .data erstellt wird und dieses Verzeichnis als Root zum Generieren von Zug- und Testdatensätzen verwendet wird. Dies allein hat jedoch verschiedene unklare Punkte, einschließlich .data. Lesen wir also den Code und sehen uns eine genauere Verarbeitung an.
Einige Daten werden für die Klassifizierung von Dokumenten bereitgestellt. Die derzeit bereitgestellten Daten lauten wie folgt.
Wenn Sie alle Daten direkt abrufen möchten, können Sie sie von der in der URLS-Variablen beschriebenen URL herunterladen.
URLS = {
'AG_NEWS':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUDNpeUdjb0wxRms',
'SogouNews':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUkVqNEszd0pHaFE',
'DBpedia':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k',
'YelpReviewPolarity':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbNUpYQ2N3SGlFaDg',
'YelpReviewFull':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZlU4dXhHTFhZQU0',
'YahooAnswers':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9Qhbd2JNdDBsQUdocVU',
'AmazonReviewPolarity':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM',
'AmazonReviewFull':
'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZVhsUnRWRDhETzA'
}
Lassen Sie uns nun die Verarbeitung von Daten über den Quellcode verfolgen. Das erste, was getan wird, ist die Definition der Funktion.
def AG_NEWS(*args, **kwargs):
""" Defines AG_NEWS datasets.
The labels includes:
- 1 : World
- 2 : Sports
- 3 : Business
- 4 : Sci/Tech
Create supervised learning dataset: AG_NEWS
Separately returns the training and test dataset
Arguments:
root: Directory where the datasets are saved. Default: ".data"
ngrams: a contiguous sequence of n items from s string text.
Default: 1
vocab: Vocabulary used for dataset. If None, it will generate a new
vocabulary based on the train data set.
include_unk: include unknown token in the data (Default: False)
Examples:
>>> train_dataset, test_dataset = torchtext.datasets.AG_NEWS(ngrams=3)
"""
return _setup_datasets(*(("AG_NEWS",) + args), **kwargs)
Sie können sehen, dass die formatierten Daten mit der Funktion _setup_datasets zurückgegeben werden. Von nun an wird nur noch AG_NEWS als Ziel ausgewählt, aber die gleiche Verarbeitung wird für andere Datensätze durchgeführt. Registrieren Sie als Nächstes die definierte Funktion in der Variablen DATASETS im Diktatformat.
DATASETS = {
'AG_NEWS': AG_NEWS,
'SogouNews': SogouNews,
'DBpedia': DBpedia,
'YelpReviewPolarity': YelpReviewPolarity,
'YelpReviewFull': YelpReviewFull,
'YahooAnswers': YahooAnswers,
'AmazonReviewPolarity': AmazonReviewPolarity,
'AmazonReviewFull': AmazonReviewFull
}
Darüber hinaus speichert die Variable LABELS die Beschriftungsinformationen für jeden Datensatz im Diktatformat.
LABELS = {
'AG_NEWS': {1: 'World',
2: 'Sports',
3: 'Business',
4: 'Sci/Tech'},
}
Obwohl hier weggelassen, werden andere Bezeichnungen als AG_NEWS im selben Format gespeichert. Da die Funktion im Diktatformat mit der obigen Variablen DATASETS registriert ist, beziehen sich die folgenden beiden auf dasselbe.
text_classification.DATASETS['AG_NEWS']
text_classification.AG_NEWS
Überprüfen Sie die Verarbeitung der Daten anhand der Funktion _setup_datasets.
def _setup_datasets(dataset_name, root='.data', ngrams=1, vocab=None, include_unk=False):
dataset_tar = download_from_url(URLS[dataset_name], root=root)
extracted_files = extract_archive(dataset_tar)
for fname in extracted_files:
if fname.endswith('train.csv'):
train_csv_path = fname
if fname.endswith('test.csv'):
test_csv_path = fname
if vocab is None:
logging.info('Building Vocab based on {}'.format(train_csv_path))
vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
else:
if not isinstance(vocab, Vocab):
raise TypeError("Passed vocabulary is not of type Vocab")
logging.info('Vocab has {} entries'.format(len(vocab)))
logging.info('Creating training data')
train_data, train_labels = _create_data_from_iterator(
vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk)
logging.info('Creating testing data')
test_data, test_labels = _create_data_from_iterator(
vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)
if len(train_labels ^ test_labels) > 0:
raise ValueError("Training and test labels don't match")
return (TextClassificationDataset(vocab, train_data, train_labels),
TextClassificationDataset(vocab, test_data, test_labels))
Die Hauptverarbeitung ist wie folgt.
class TextClassificationDataset(torch.utils.data.Dataset):
"""Defines an abstract text classification datasets.
Currently, we only support the following datasets:
- AG_NEWS
- SogouNews
- DBpedia
- YelpReviewPolarity
- YelpReviewFull
- YahooAnswers
- AmazonReviewPolarity
- AmazonReviewFull
"""
[docs] def __init__(self, vocab, data, labels):
"""Initiate text-classification dataset.
Arguments:
vocab: Vocabulary object used for dataset.
data: a list of label/tokens tuple. tokens are a tensor after
numericalizing the string tokens. label is an integer.
[(label1, tokens1), (label2, tokens2), (label2, tokens3)]
label: a set of the labels.
{label1, label2}
Examples:
See the examples in examples/text_classification/
"""
super(TextClassificationDataset, self).__init__()
self._data = data
self._labels = labels
self._vocab = vocab
def __getitem__(self, i):
return self._data[i]
def __len__(self):
return len(self._data)
def __iter__(self):
for x in self._data:
yield x
def get_labels(self):
return self._labels
def get_vocab(self):
return self._vocab
Sie können sehen, dass es sich um eine Klasse zum Abrufen der einzelnen Daten handelt, nicht zum Verarbeiten neuer Daten. Wie Sie der Funktion _setup_datasets und der Klasse TextClassificationDataset entnehmen können, wird das Dataset in das N-Gramm und den gespeicherten Status anstatt in das Rohdokument konvertiert. Wenn Sie ein anderes Datenformat als N-Gramm verwenden möchten, müssen Sie daher Ihre eigene Verarbeitung basierend auf den in .data gespeicherten Daten oder den Daten schreiben, die von der in URLS beschriebenen URL heruntergeladen wurden.
Informationen, die nur durch Drucken schwer zu verstehen sind, können durch Nachverfolgen des Quellcodes verstanden werden. Ich möchte den Quellcode weiter lesen und Informationen zusammenstellen.
Recommended Posts