In dem zuvor geschriebenen Verfahren zum Lernen und Ableiten des englisch-japanischen Übersetzungsmodells des Transformators mit CloudTPU wurde das englisch-japanische Übersetzungsmodell des Transformators mit CloudTPU gelernt und abgeleitet. Es war. Dieses Mal werde ich erklären, wie ein von Cloud TPU trainierter Transformator in einem lokalen Docker-Container ausgeführt wird. Der Code ist hier. https://github.com/yolo-kiyoshi/transformer_python_exec
Angenommen, Dateien befinden sich lokal in der folgenden Verzeichnisstruktur.
Verzeichnisaufbau
bucket
├── training/
│ └── transformer_ende/
│ ├── checkpoint
│ ├── model.ckpt-****.data-00000-of-00001
│ ├── model.ckpt-****.index
│ └── model.ckpt-****.meta
└── transformer/
└── vocab.translate_jpen.****.subwords
Klonen Sie das Repository.
git clone https://github.com/yolo-kiyoshi/transformer_python_exec.git
Verzeichnisaufbau
.
├── Dockerfile
├── .env.sample
├── Pipfile
├── Pipfile.lock
├── README.md
├── decode.ipynb
├── docker-compose.yml
├── training/
│ └── transformer_ende/
└── transformer/
Laden Sie die Anmeldeinformationsdatei des Dienstkontos (json) herunter und legen Sie sie im selben Verzeichnis wie README.md ab.
Dupliziere und benenne .env.sample
um, um .env
zu erstellen.
.env
#Beschreiben Sie den Pfad der oben platzierten Anmeldeinformationsdatei
GOOGLE_APPLICATION_CREDENTIALS=*****.json
BUDGET_NAME=
#Gleiche Einstellungen wie beim Lernen mit CloudTPU
PROBLEM=translate_jpen
DATA_DIR=transformer
TRAIN_DIR=training/transformer_ende/
HPARAMS=transformer_tpu
MODEL=transformer
Nachdem Sie den folgenden Befehl ausgeführt haben, können Sie Jupyter lab betreiben, indem Sie auf http: // localhost: 8080 / lab zugreifen.
docker-compose up -d
Notebook
Laden Sie den Satz von "Checkpoint" -Dateien und "Vocal" -Dateien, die während des Transformator-Lernprozesses erstellt wurden, von GCS lokal herunter.
#Methode zum Herunterladen von Dateien von GCS(https://cloud.google.com/storage/docs/downloading-objects?hl=ja)
def download_blob(bucket_name, source_blob_name, destination_file_name):
"""Downloads a blob from the bucket."""
storage_client = storage.Client()
bucket = storage_client.get_bucket(bucket_name)
blob = bucket.blob(source_blob_name)
blob.download_to_filename(destination_file_name)
print('Blob {} downloaded to {}.'.format(
source_blob_name,
destination_file_name))
#Siehe Erfassungsmethode für GCS-Dateilisten
# https://cloud.google.com/storage/docs/listing-objects?hl=ja#storage-list-objects-python
def list_match_file_with_prefix(bucket_name, prefix, search_path):
"""Lists all the blobs in the bucket that begin with the prefix."""
storage_client = storage.Client()
# Note: Client.list_blobs requires at least package version 1.17.0.
blobs = storage_client.list_blobs(bucket_name, prefix=prefix, delimiter=None)
file_list = [blob.name for blob in blobs if search_path in blob.name]
return file_list
#Umgebungsvariablen festlegen
BUDGET_NAME = os.environ['BUDGET_NAME']
PROBLEM = os.environ['PROBLEM']
DATA_DIR = os.environ['DATA_DIR']
TRAIN_DIR = os.environ['TRAIN_DIR']
HPARAMS = os.environ['HPARAMS']
MODEL = os.environ['MODEL']
#Pfad der Prüfpunktdatei
src_file_name = os.path.join(TRAIN_DIR, 'checkpoint')
dist_file_name = os.path.join(TRAIN_DIR, 'checkpoint')
#Laden Sie die Checkpoint-Datei von GCS herunter
download_blob(BUDGET_NAME, src_file_name, dist_file_name)
#Letzte Prüfpunktsequenz aus Prüfpunktdatei(prefix)Bekommen
import re
with open(dist_file_name) as f:
l = f.readlines(1)
ckpt_name = re.findall('model_checkpoint_path: "(.*?)"', l[0])[0]
ckpt_path = os.path.join(TRAIN_DIR, ckpt_name)
#Rufen Sie die Dateiliste mit dem neuesten Prüfpunkt von GCS ab
ckpt_file_list = list_match_file_with_prefix(BUDGET_NAME, TRAIN_DIR, ckpt_path)
# checkpoint.Laden Sie eine Reihe von Variablen herunter
for ckpt_file in ckpt_file_list:
download_blob(BUDGET_NAME, ckpt_file, ckpt_file)
#Rufen Sie den Vokabeldateipfad von GCS ab
vocab_file = list_match_file_with_prefix(BUDGET_NAME, DATA_DIR, os.path.join(DATA_DIR, 'vocab'))[0]
#Laden Sie die Vokabeldatei von GCS herunter
download_blob(BUDGET_NAME, vocab_file, vocab_file)
Laden Sie das Transformatormodell basierend auf den von GCS heruntergeladenen Transformator-Trainingsergebnissen.
#Initialisieren
tfe = tf.contrib.eager
tfe.enable_eager_execution()
Modes = tf.estimator.ModeKeys
import pickle
import numpy as np
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry
#Vorverarbeitung&Verwenden Sie denselben Klassennamen wie PROBLE, der beim Lernen definiert wurde
@registry.register_problem
class Translate_JPEN(text_problems.Text2TextProblem):
@property
def approx_vocab_size(self):
return 2**13
enfr_problem = problems.problem(PROBLEM)
# Get the encoders from the problem
encoders = enfr_problem.feature_encoders(DATA_DIR)
from functools import wraps
import time
def stop_watch(func) :
@wraps(func)
def wrapper(*args, **kargs) :
start = time.time()
print(f'{func.__name__} started ...')
result = func(*args,**kargs)
elapsed_time = time.time() - start
print(f'elapsed_time:{elapsed_time}')
print(f'{func.__name__} completed')
return result
return wrapper
@stop_watch
def translate(inputs):
encoded_inputs = encode(inputs)
with tfe.restore_variables_on_create(ckpt_path):
model_output = translate_model.infer(features=encoded_inputs)["outputs"]
return decode(model_output)
def encode(input_str, output_str=None):
"""Input str to features dict, ready for inference"""
inputs = encoders["inputs"].encode(input_str) + [1]
batch_inputs = tf.reshape(inputs, [1, -1, 1])
return {"inputs": batch_inputs}
def decode(integers):
"""List of ints to str"""
integers = list(np.squeeze(integers))
if 1 in integers:
integers = integers[:integers.index(1)]
return encoders["inputs"].decode(np.squeeze(integers))
hparams = trainer_lib.create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)
translate_model = registry.model(MODEL)(hparams, Modes.PREDICT)
Schliessen Sie mit dem geladenen Transformatormodell. Bei lokaler Ausführung dauert ein Satz etwa 30 Sekunden.
inputs = "My cat is so cute."
outputs = translate(inputs)
print(outputs)
Ergebnis
>Meine Katze ist sehr süß.