[PYTHON] I want to use my own network in PPO2 of Stable Baselines

In my earlier post, "Until you clear Super Mario Bros. 1-1 using Stable Baselines", I trained the agent with the PPO2 implementation provided by Stable Baselines.

PPO2 is only given the name 'CnnPolicy', so it is not clear exactly what network architecture is used (although there is a description saying that it conforms to the original PPO paper). I also did not know what to do if I wanted to modify the network, so I traced through the original code.


class PPO2(ActorCriticRLModel):
    def __init__(self, policy, env, gamma=0.99, n_steps=128, ent_coef=0.01, learning_rate=2.5e-4, vf_coef=0.5,
                 max_grad_norm=0.5, lam=0.95, nminibatches=4, noptepochs=4, cliprange=0.2, cliprange_vf=None,
                 verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None,
                 full_tensorboard_log=False, seed=None, n_cpu_tf_sess=None):
        super().__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True,
                         _init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs,
                         seed=seed, n_cpu_tf_sess=n_cpu_tf_sess)

Names such as 'CnnPolicy' are passed in as the policy argument. However, the __init__ of this class does not use it directly; it only forwards it to the __init__ of the parent class.

Let's go see the parent class ActorCriticRLModel. It is located in stable-baselines/stable_baselines/common/base_class.py.


class ActorCriticRLModel(BaseRLModel):
    def __init__(self, policy, env, _init_setup_model, verbose=0, policy_base=ActorCriticPolicy,
                 requires_vec_env=False, policy_kwargs=None, seed=None, n_cpu_tf_sess=None):
        super(ActorCriticRLModel, self).__init__(policy, env, verbose=verbose, requires_vec_env=requires_vec_env,
                                                 policy_base=policy_base, policy_kwargs=policy_kwargs,
                                                 seed=seed, n_cpu_tf_sess=n_cpu_tf_sess)

Again, it is only forwarded to the __init__ of the parent class. The parent class BaseRLModel is defined as an abstract class in the same file.


class BaseRLModel(ABC):
    def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base,
                 policy_kwargs=None, seed=None, n_cpu_tf_sess=None):
        if isinstance(policy, str) and policy_base is not None:
            self.policy = get_policy_from_name(policy_base, policy)
        else:
            self.policy = policy

When policy is a string, the policy class is apparently looked up with a function called get_policy_from_name; otherwise the object you passed is used as-is. get_policy_from_name is defined in stable-baselines/stable_baselines/common/policies.py.


def get_policy_from_name(base_policy_type, name):
    if base_policy_type not in _policy_registry:
        raise ValueError("Error: the policy type {} is not registered!".format(base_policy_type))
    if name not in _policy_registry[base_policy_type]:
        raise ValueError("Error: unknown policy type {}, the only registed policy type are: {}!"
                         .format(name, list(_policy_registry[base_policy_type].keys())))
    return _policy_registry[base_policy_type][name]

It uses a dictionary called _policy_registry to look up the class that corresponds to the policy name. _policy_registry is written just above it in the same file.


_policy_registry = {
    ActorCriticPolicy: {
        "CnnPolicy": CnnPolicy,
        "CnnLstmPolicy": CnnLstmPolicy,
        "CnnLnLstmPolicy": CnnLnLstmPolicy,
        "MlpPolicy": MlpPolicy,
        "MlpLstmPolicy": MlpLstmPolicy,
        "MlpLnLstmPolicy": MlpLnLstmPolicy,
    }
}
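To check my understanding of the lookup, here is a minimal, framework-free sketch of the same mechanism. The classes below are empty stand-ins that I made up for illustration (they are not the real Stable Baselines policies), and resolve_policy is a hypothetical helper that mirrors the branch in BaseRLModel.__init__.

```python
# Stand-in classes: hypothetical placeholders, not the real Stable Baselines policies.
class ActorCriticPolicy:
    pass


class CnnPolicy(ActorCriticPolicy):
    pass


class MlpPolicy(ActorCriticPolicy):
    pass


# Two-level registry: base policy type -> {name -> policy class}.
_policy_registry = {
    ActorCriticPolicy: {
        "CnnPolicy": CnnPolicy,
        "MlpPolicy": MlpPolicy,
    }
}


def get_policy_from_name(base_policy_type, name):
    """Look up a policy class by name, with the same two checks as the excerpt above."""
    if base_policy_type not in _policy_registry:
        raise ValueError("the policy type {} is not registered".format(base_policy_type))
    if name not in _policy_registry[base_policy_type]:
        raise ValueError("unknown policy type {}, registered: {}".format(
            name, list(_policy_registry[base_policy_type].keys())))
    return _policy_registry[base_policy_type][name]


def resolve_policy(policy, policy_base=ActorCriticPolicy):
    """Hypothetical helper: strings are looked up, anything else passes through."""
    if isinstance(policy, str) and policy_base is not None:
        return get_policy_from_name(policy_base, policy)
    return policy


class MyPolicy(ActorCriticPolicy):  # hypothetical user-defined policy class
    pass


print(resolve_policy("CnnPolicy") is CnnPolicy)  # True
print(resolve_policy(MyPolicy) is MyPolicy)      # True

try:
    resolve_policy("CNNPolicy")  # the keys are case-sensitive
except ValueError as err:
    print("lookup failed:", err)
```

Two things stand out in this sketch: the name must match a registry key exactly, and a class passed instead of a string skips the registry entirely.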

What I want to know is what CnnPolicy actually contains, so let's look at it next. It is also defined in the same file.


class CnnPolicy(FeedForwardPolicy):
    """
    Policy object that implements actor critic, using a CNN (the nature CNN)

    :param sess: (TensorFlow session) The current TensorFlow session
    :param ob_space: (Gym Space) The observation space of the environment
    :param ac_space: (Gym Space) The action space of the environment
    :param n_env: (int) The number of environments to run
    :param n_steps: (int) The number of steps to run for each environment
    :param n_batch: (int) The number of batch to run (n_envs * n_steps)
    :param reuse: (bool) If the policy is reusable or not
    :param _kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction
    """

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, **_kwargs):
        super(CnnPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse,
                                        feature_extraction="cnn", **_kwargs)

The actual network is not written here either, so let's look at the parent class, FeedForwardPolicy.


class FeedForwardPolicy(ActorCriticPolicy):

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, layers=None, net_arch=None,
                 act_fun=tf.tanh, cnn_extractor=nature_cnn, feature_extraction="cnn", **kwargs):
        super(FeedForwardPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse,
                                                scale=(feature_extraction == "cnn"))

        self._kwargs_check(feature_extraction, kwargs)

        if layers is not None:
            warnings.warn("Usage of the `layers` parameter is deprecated! Use net_arch instead "
                          "(it has a different semantics though).", DeprecationWarning)
            if net_arch is not None:
                warnings.warn("The new `net_arch` parameter overrides the deprecated `layers` parameter!",
                              DeprecationWarning)

        if net_arch is None:
            if layers is None:
                layers = [64, 64]
            net_arch = [dict(vf=layers, pi=layers)]

        with tf.variable_scope("model", reuse=reuse):
            if feature_extraction == "cnn":
                pi_latent = vf_latent = cnn_extractor(self.processed_obs, **kwargs)
            else:
                pi_latent, vf_latent = mlp_extractor(tf.layers.flatten(self.processed_obs), net_arch, act_fun)

            self._value_fn = linear(vf_latent, 'vf', 1)

            self._proba_distribution, self._policy, self.q_value = \
                self.pdtype.proba_distribution_from_latent(pi_latent, vf_latent, init_scale=0.01)


    def step(self, obs, state=None, mask=None, deterministic=False):
        if deterministic:
            action, value, neglogp = self.sess.run([self.deterministic_action, self.value_flat, self.neglogp],
                                                   {self.obs_ph: obs})
        else:
            action, value, neglogp = self.sess.run([self.action, self.value_flat, self.neglogp],
                                                   {self.obs_ph: obs})
        return action, value, self.initial_state, neglogp

    def proba_step(self, obs, state=None, mask=None):
        return self.sess.run(self.policy_proba, {self.obs_ph: obs})

    def value(self, obs, state=None, mask=None):
        return self.sess.run(self.value_flat, {self.obs_ph: obs})

I finally found something that looks like a network definition. The structure is that features extracted by the feature extraction network are fed into both the value estimation head and the action probability head. The following part selects the network used for feature extraction.


            if feature_extraction == "cnn":
                pi_latent = vf_latent = cnn_extractor(self.processed_obs, **kwargs)
            else:
                pi_latent, vf_latent = mlp_extractor(tf.layers.flatten(self.processed_obs), net_arch, act_fun)
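To make the data flow concrete for myself, here is a toy, framework-free sketch of the CNN branch: one feature extractor, whose output is shared by a value head and a policy head. Everything in it (extract_features, value_head, policy_head, and the weights) is a hypothetical stand-in in plain Python, not the library's TensorFlow code.

```python
import math


def extract_features(obs):
    """Toy stand-in for cnn_extractor: maps an observation to a latent vector."""
    return [sum(obs) / len(obs), max(obs) - min(obs)]


def value_head(latent, weights, bias=0.0):
    """Linear layer with one output, playing the role of linear(vf_latent, 'vf', 1)."""
    return sum(w * x for w, x in zip(weights, latent)) + bias


def policy_head(latent, weight_rows):
    """Linear layer to one logit per action, followed by a softmax."""
    logits = [sum(w * x for w, x in zip(row, latent)) for row in weight_rows]
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(logit - top) for logit in logits]
    total = sum(exps)
    return [e / total for e in exps]


obs = [0.0, 0.5, 1.0]

# CNN branch: both heads read the very same latent vector.
pi_latent = vf_latent = extract_features(obs)

value = value_head(vf_latent, weights=[1.0, -1.0])
action_probs = policy_head(pi_latent, weight_rows=[[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])

print(value)         # a single scalar state value
print(action_probs)  # one probability per action, summing to 1
```

In the MLP branch the two latents come back separately from mlp_extractor, so the policy and value networks can have different layers there; in the CNN branch they are the same tensor.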

The cnn_extractor function is an argument, which defaults to nature_cnn. The nature_cnn function is also defined in the same file.


def nature_cnn(scaled_images, **kwargs):
    """
    CNN from Nature paper.

    :param scaled_images: (TensorFlow Tensor) Image input placeholder
    :param kwargs: (dict) Extra keywords parameters for the convolutional layers of the CNN
    :return: (TensorFlow Tensor) The CNN output layer
    """
    activ = tf.nn.relu
    layer_1 = activ(conv(scaled_images, 'c1', n_filters=32, filter_size=8, stride=4, init_scale=np.sqrt(2), **kwargs))
    layer_2 = activ(conv(layer_1, 'c2', n_filters=64, filter_size=4, stride=2, init_scale=np.sqrt(2), **kwargs))
    layer_3 = activ(conv(layer_2, 'c3', n_filters=64, filter_size=3, stride=1, init_scale=np.sqrt(2), **kwargs))
    layer_3 = conv_to_fc(layer_3)
    return activ(linear(layer_3, 'fc1', n_hidden=512, init_scale=np.sqrt(2)))
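As a sanity check on what this network does to an image, the feature-map sizes can be worked out by hand. The sketch below assumes an 84x84 input (the usual Atari preprocessing size, which is my assumption and not stated above) and assumes the conv helper applies no padding; conv_out and nature_cnn_shapes are hypothetical helper names.

```python
def conv_out(size, filter_size, stride):
    """Output length of one spatial dimension for a convolution without padding."""
    return (size - filter_size) // stride + 1


def nature_cnn_shapes(height=84, width=84):
    """Return the (height, width, channels) after c1, c2, c3 and the flattened size."""
    # (n_filters, filter_size, stride) for c1, c2, c3 as listed in nature_cnn.
    layers = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]
    shapes = []
    h, w = height, width
    for n_filters, filter_size, stride in layers:
        h = conv_out(h, filter_size, stride)
        w = conv_out(w, filter_size, stride)
        shapes.append((h, w, n_filters))
    flat_size = h * w * layers[-1][0]  # what conv_to_fc would hand to the dense layer
    return shapes, flat_size


shapes, flat_size = nature_cnn_shapes()
print(shapes)     # [(20, 20, 32), (9, 9, 64), (7, 7, 64)]
print(flat_size)  # 3136
```

Under those assumptions, the final linear layer maps 3136 flattened features to the 512-unit output that the policy and value heads consume.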

There it is at last. To use your own network, it seems you should write a function that returns the network structure you want and use it in place of nature_cnn.

With that approach, I tried Mario 1-1 using a network with self-attention.
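As a rough sketch of what that could look like (my own illustration, not code from the library): the function small_cnn and its layer sizes are hypothetical, and it uses plain tf.layers calls from TensorFlow 1.x rather than the library's conv and linear helpers. I am assuming that policy_kwargs is forwarded to the policy constructor, so that cnn_extractor reaches the FeedForwardPolicy argument shown earlier. The imports sit inside the functions only so that the snippet can be loaded without the libraries installed.

```python
def small_cnn(scaled_images, **kwargs):
    """Hypothetical replacement for nature_cnn: two conv layers and one dense layer."""
    import tensorflow as tf  # TensorFlow 1.x, which Stable Baselines 2 is built on

    # Explicit layer names keep variable names stable, so that the enclosing
    # tf.variable_scope("model", reuse=...) can share weights between policy copies.
    out = tf.layers.conv2d(scaled_images, filters=16, kernel_size=8, strides=4,
                           activation=tf.nn.relu, name="c1")
    out = tf.layers.conv2d(out, filters=32, kernel_size=4, strides=2,
                           activation=tf.nn.relu, name="c2")
    out = tf.layers.flatten(out)
    return tf.layers.dense(out, 256, activation=tf.nn.relu, name="fc1")


# Assumption: PPO2 passes these keyword arguments on to CnnPolicy, which forwards
# them to FeedForwardPolicy, where cnn_extractor replaces the nature_cnn default.
policy_kwargs = dict(cnn_extractor=small_cnn)


def make_model(env):
    """Build a PPO2 model that uses small_cnn; env must be a vectorized image environment."""
    from stable_baselines import PPO2

    return PPO2("CnnPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
```

The function keeps the same signature as nature_cnn (an image tensor in, a flat feature tensor out), so the value and policy heads built after it do not need to change.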

However, there was no particular improvement in learning speed.
