[PYTHON] How to run a transformer model trained on Cloud TPU locally

In the previous post, Steps to learn & infer transformer English-Japanese translation model with CloudTPU, I trained a transformer English-Japanese translation model on Cloud TPU and also ran inference there. This time, I will explain how to run a transformer trained on Cloud TPU inside a local Docker container. The code is here: https://github.com/yolo-kiyoshi/transformer_python_exec

Prerequisites

Assume the files are laid out in GCS and in the local working directory with the following directory structures.

GCS directory

Directory structure


bucket
├── training/
│   └── transformer_ende/
│       ├── checkpoint
│       ├── model.ckpt-****.data-00000-of-00001
│       ├── model.ckpt-****.index
│       └── model.ckpt-****.meta
└── transformer/
    └── vocab.translate_jpen.****.subwords

Local directory

Clone the repository.

git clone https://github.com/yolo-kiyoshi/transformer_python_exec.git

Directory structure


.
├── Dockerfile
├── .env.sample
├── Pipfile
├── Pipfile.lock
├── README.md
├── decode.ipynb
├── docker-compose.yml
├── training/
│   └── transformer_ende/
└── transformer/

Preparation

Google Credential file

Download the service account credential file (JSON) and place it in the same directory as README.md.
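
The google-cloud-storage client finds this file automatically through the GOOGLE_APPLICATION_CREDENTIALS environment variable set in .env below. Once the notebook is running, a minimal sketch to verify the credential works:

# Sketch: verify the credential is usable. google-cloud-storage reads
# GOOGLE_APPLICATION_CREDENTIALS automatically; no explicit key loading is needed.
from google.cloud import storage

client = storage.Client()
print(client.project)  # should print the GCP project of the service account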

Environment variable

Duplicate and rename .env.sample to create .env.

.env


# Path to the credential file placed above
GOOGLE_APPLICATION_CREDENTIALS=*****.json
# Name of the GCS bucket (the variable is spelled BUDGET_NAME throughout this repo)
BUDGET_NAME=
# Same settings as used when training on Cloud TPU
PROBLEM=translate_jpen
DATA_DIR=transformer
TRAIN_DIR=training/transformer_ende/
HPARAMS=transformer_tpu
MODEL=transformer
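
docker-compose passes these variables into the container, judging by the os.environ lookups in the notebook. If you would rather run the notebook outside Docker, one option is python-dotenv; a minimal sketch, assuming that package is installed (it is not necessarily in this repo's Pipfile):

# Sketch: load .env manually when running outside docker-compose.
# Assumes the python-dotenv package, which may not be in this repo's Pipfile.
import os

from dotenv import load_dotenv

load_dotenv()  # reads .env from the current directory into os.environ
print(os.environ['BUDGET_NAME'])  # the bucket name configured above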

Creating a Docker image and launching a container

After running the following command, you can use JupyterLab at http://localhost:8080/lab.

docker-compose up -d
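
For reference, a hypothetical sketch of what the docker-compose.yml could look like; the actual file in the repository is authoritative, and the port and mount details below are assumptions:

# Hypothetical sketch only; the repository's docker-compose.yml is authoritative.
version: '3'
services:
  notebook:
    build: .            # build the image from the repo's Dockerfile
    env_file: .env      # inject the variables defined in .env
    ports:
      - "8080:8888"     # assumption: JupyterLab listens on 8888 in the container
    volumes:
      - .:/work         # assumption: mount the repo so downloaded files persist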

Notebook

Download the transformer training results from GCS

Download the set of checkpoint files and the vocab file created during transformer training from GCS to the local machine.

import os

from google.cloud import storage

# Download a file from GCS
# (https://cloud.google.com/storage/docs/downloading-objects?hl=ja)
def download_blob(bucket_name, source_blob_name, destination_file_name):
    """Downloads a blob from the bucket."""
    storage_client = storage.Client()
    bucket = storage_client.get_bucket(bucket_name)
    blob = bucket.blob(source_blob_name)

    blob.download_to_filename(destination_file_name)

    print('Blob {} downloaded to {}.'.format(
        source_blob_name,
        destination_file_name))

# List the files in a GCS bucket; see
# https://cloud.google.com/storage/docs/listing-objects?hl=ja#storage-list-objects-python
def list_match_file_with_prefix(bucket_name, prefix, search_path):
    """Lists all the blobs in the bucket that begin with the prefix."""
    
    storage_client = storage.Client()

    # Note: Client.list_blobs requires at least package version 1.17.0.
    blobs = storage_client.list_blobs(bucket_name, prefix=prefix, delimiter=None)

    file_list = [blob.name for blob in blobs if search_path in blob.name]
    
    return file_list

# Read the settings from the environment variables defined in .env
BUDGET_NAME = os.environ['BUDGET_NAME']
PROBLEM = os.environ['PROBLEM']
DATA_DIR = os.environ['DATA_DIR']
TRAIN_DIR = os.environ['TRAIN_DIR']
HPARAMS = os.environ['HPARAMS']
MODEL = os.environ['MODEL']

# The 'checkpoint' bookkeeping file uses the same relative path locally and in GCS
src_file_name = os.path.join(TRAIN_DIR, 'checkpoint')
dist_file_name = os.path.join(TRAIN_DIR, 'checkpoint')
# Download the checkpoint file from GCS
download_blob(BUDGET_NAME, src_file_name, dist_file_name)

# Get the name (prefix) of the latest checkpoint from the checkpoint file
import re

with open(dist_file_name) as f:
    first_line = f.readline()
    ckpt_name = re.findall('model_checkpoint_path: "(.*?)"', first_line)[0]
    ckpt_path = os.path.join(TRAIN_DIR, ckpt_name)
#Get the file list associated with the latest checkpoint from GCS
ckpt_file_list = list_match_file_with_prefix(BUDGET_NAME, TRAIN_DIR, ckpt_path)

# Download the set of files that make up the latest checkpoint
for ckpt_file in ckpt_file_list:
    download_blob(BUDGET_NAME, ckpt_file, ckpt_file)
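
Note that download_to_filename raises an error if the local destination directory does not exist. The cloned repository already contains training/transformer_ende/ and transformer/, but when starting from a clean directory, a small guard like this sketch (my addition, not in the original notebook) avoids that:

# Sketch: create the local destination directories before downloading.
for ckpt_file in ckpt_file_list:
    os.makedirs(os.path.dirname(ckpt_file), exist_ok=True)
    download_blob(BUDGET_NAME, ckpt_file, ckpt_file)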

#Get the vocab file path from GCS
vocab_file = list_match_file_with_prefix(BUDGET_NAME, DATA_DIR, os.path.join(DATA_DIR, 'vocab'))[0]
#Download vocab file from GCS
download_blob(BUDGET_NAME, vocab_file, vocab_file)
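
At this point the checkpoint set and the vocab file should exist locally; a quick sanity check:

# Sanity check: list the files that were just downloaded.
import glob

print(glob.glob(os.path.join(TRAIN_DIR, 'model.ckpt-*')))  # .data / .index / .meta files
print(glob.glob(os.path.join(DATA_DIR, 'vocab.*')))        # the subword vocab file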

Load trained transformer model

Load the transformer model from the training results downloaded from GCS.

# Initialization (requires TF 1.x: tf.contrib was removed in TF 2.x)
import tensorflow as tf

tfe = tf.contrib.eager
tfe.enable_eager_execution()
Modes = tf.estimator.ModeKeys
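
Since tf.contrib only exists in TensorFlow 1.x, a fail-fast version guard can save some confusion (my addition, not in the original notebook; in practice you would place it before the tf.contrib access above):

# Guard: tf.contrib exists only in TF 1.x.
assert tf.__version__.startswith('1.'), f'TensorFlow 1.x required, found {tf.__version__}'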


import pickle

import numpy as np

# problems and trainer_lib are needed below (problems.problem, trainer_lib.create_hparams)
from tensor2tensor import problems
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry
from tensor2tensor.utils import trainer_lib

# Preprocessing: register a problem class whose name matches the PROBLEM used for training
# (the registry lowercases Translate_JPEN to translate_jpen)
@registry.register_problem
class Translate_JPEN(text_problems.Text2TextProblem):
    @property
    def approx_vocab_size(self):
        return 2**13

enfr_problem = problems.problem(PROBLEM)
# Get the encoders from the problem
encoders = enfr_problem.feature_encoders(DATA_DIR)
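
encoders now maps feature names to subword text encoders built from the downloaded vocab file. A quick round-trip check (illustrative input; the actual ids depend on your vocab):

# Round-trip a string through the subword encoder as a sanity check.
ids = encoders["inputs"].encode("こんにちは")  # ids depend on your vocab file
print(ids)
print(encoders["inputs"].decode(ids))  # should reproduce the input string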

from functools import wraps
import time

def stop_watch(func) :
    @wraps(func)
    def wrapper(*args, **kargs) :
        start = time.time()
        print(f'{func.__name__} started ...')
        result = func(*args,**kargs)
        elapsed_time =  time.time() - start
        print(f'elapsed_time:{elapsed_time}')
        print(f'{func.__name__} completed')
        return result
    return wrapper

@stop_watch
def translate(inputs):
    encoded_inputs = encode(inputs)
    with tfe.restore_variables_on_create(ckpt_path):
        model_output = translate_model.infer(features=encoded_inputs)["outputs"]
    return decode(model_output)

def encode(input_str, output_str=None):
    """Input str to features dict, ready for inference"""
    inputs = encoders["inputs"].encode(input_str) + [1]  # append the EOS id (1)
    batch_inputs = tf.reshape(inputs, [1, -1, 1])  # make a batch of size 1
    return {"inputs": batch_inputs}

def decode(integers):
    """List of ints to str"""
    integers = list(np.squeeze(integers))
    if 1 in integers:
        integers = integers[:integers.index(1)]  # truncate at the EOS id (1)
    return encoders["inputs"].decode(np.squeeze(integers))

hparams = trainer_lib.create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)
translate_model = registry.model(MODEL)(hparams, Modes.PREDICT)

Inference

Run inference with the loaded transformer model. When executed locally, translating a single sentence takes roughly 30 seconds.

inputs = "My cat is so cute."
outputs = translate(inputs)
print(outputs)

Result


>My cat is very cute.

Reference

Welcome to the Tensor2Tensor Colab
