[PYTHON] Renforcer l'apprentissage 6 First Chainer RL

Guide de référence rapide ChainerRL https://chainer-colab-notebook.readthedocs.io/ja/latest/notebook/hands_on/chainerrl/quickstart.html

On suppose que vous avez terminé le renforcement de l'apprentissage 5. Créez les fichiers suivants en vous référant à la référence rapide.

train.py


import chainer
import chainer.functions as F
import chainer.links as L
import chainerrl
import gym
import numpy as np
env = gym.make('CartPole-v0')
print('observation space:', env.observation_space)
print('action space:', env.action_space)

obs = env.reset()
env.render()
print('initial observation:', obs)

action = env.action_space.sample()
obs, r, done, info = env.step(action)
print('next observation:', obs)
print('reward:', r)
print('done:', done)
print('info:', info)

class QFunction(chainer.Chain):

    def __init__(self, obs_size, n_actions, n_hidden_channels=50):
        super().__init__()
        with self.init_scope():
            self.l0 = L.Linear(obs_size, n_hidden_channels)
            self.l1 = L.Linear(n_hidden_channels, n_hidden_channels)
            self.l2 = L.Linear(n_hidden_channels, n_actions)

    def __call__(self, x, test=False):
        """
        Args:
            x (ndarray or chainer.Variable): An observation
            test (bool): a flag indicating whether it is in test mode
        """
        h = F.tanh(self.l0(x))
        h = F.tanh(self.l1(h))
        return chainerrl.action_value.DiscreteActionValue(self.l2(h))

obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n
q_func = QFunction(obs_size, n_actions)

optimizer = chainer.optimizers.Adam(eps=1e-2)
optimizer.setup(q_func)

# Set the discount factor that discounts future rewards.
gamma = 0.95

# Use epsilon-greedy for exploration
explorer = chainerrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.3, random_action_func=env.action_space.sample)

# DQN uses Experience Replay.
# Specify a replay buffer and its capacity.
replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 6)

# Since observations from CartPole-v0 is numpy.float64 while
# Chainer only accepts numpy.float32 by default, specify
# a converter as a feature extractor function phi.
phi = lambda x: x.astype(np.float32, copy=False)

# Now create an agent that will interact with the environment.
agent = chainerrl.agents.DoubleDQN(
    q_func, optimizer, replay_buffer, gamma, explorer,
    replay_start_size=500, update_interval=1,
    target_update_interval=100, phi=phi)

n_episodes = 200
max_episode_len = 200
for i in range(1, n_episodes + 1):
    obs = env.reset()
    reward = 0
    done = False
    R = 0  # return (sum of rewards)
    t = 0  # time step
    while not done and t < max_episode_len:
        # Uncomment to watch the behaviour
        # env.render()
        action = agent.act_and_train(obs, reward)
        obs, reward, done, _ = env.step(action)
        R += reward
        t += 1
    if i % 10 == 0:
        print('episode:', i,
              'R:', R,
              'statistics:', agent.get_statistics())
    agent.stop_episode_and_train(obs, reward, done)
print('Finished.')
agent.save('agent')
env.close()

Si cela fonctionne correctement, la sortie sera la suivante.

observation space: Box(4,)
action space: Discrete(2)
initial observation: [-0.04736688 -0.01970095 -0.00356997  0.01937746]
next observation: [-0.0477609   0.17547202 -0.00318242 -0.27442969]
reward: 1.0
done: False
info: {}
episode: 10 R: 54.0 statistics: [('average_q', 0.02431227437087797), ('average_loss', 0), ('n_updates', 0)]
episode: 20 R: 38.0 statistics: [('average_q', 0.6243046798922441), ('average_loss', 0.08867046155807262), ('n_updates', 378)]
episode: 30 R: 46.0 statistics: [('average_q', 2.2610338271644586), ('average_loss', 0.09550022600040467), ('n_updates', 784)]
episode: 40 R: 84.0 statistics: [('average_q', 5.323362298387771), ('average_loss', 0.16771472656243475), ('n_updates', 1399)]
episode: 50 R: 91.0 statistics: [('average_q', 9.851513830734694), ('average_loss', 0.19145745620246343), ('n_updates', 2351)]
episode: 60 R: 99.0 statistics: [('average_q', 14.207080180752635), ('average_loss', 0.22097823899753388), ('n_updates', 3584)]
episode: 70 R: 200.0 statistics: [('average_q', 17.49337381852232), ('average_loss', 0.18525375351216344), ('n_updates', 5285)]
episode: 80 R: 124.0 statistics: [('average_q', 18.933387631649587), ('average_loss', 0.1511605453710412), ('n_updates', 7063)]
episode: 90 R: 200.0 statistics: [('average_q', 19.55727598346719), ('average_loss', 0.167370220872378), ('n_updates', 8496)]
episode: 100 R: 200.0 statistics: [('average_q', 19.92113421424675), ('average_loss', 0.15092426599174535), ('n_updates', 10351)]
episode: 110 R: 161.0 statistics: [('average_q', 19.870179660112395), ('average_loss', 0.1369066775700466), ('n_updates', 12169)]
episode: 120 R: 200.0 statistics: [('average_q', 19.985680296882315), ('average_loss', 0.13667809001004586), ('n_updates', 13991)]
episode: 130 R: 200.0 statistics: [('average_q', 20.016279858512945), ('average_loss', 0.14053696154447365), ('n_updates', 15938)]
episode: 140 R: 180.0 statistics: [('average_q', 19.870299413261478), ('average_loss', 0.1270716956269478), ('n_updates', 17593)]
episode: 150 R: 200.0 statistics: [('average_q', 19.990808581945565), ('average_loss', 0.1228807602095278), ('n_updates', 19442)]
episode: 160 R: 130.0 statistics: [('average_q', 19.954955203815164), ('average_loss', 0.14701205384726732), ('n_updates', 21169)]
episode: 170 R: 133.0 statistics: [('average_q', 19.994069560095422), ('average_loss', 0.12502104946859763), ('n_updates', 22709)]
episode: 180 R: 200.0 statistics: [('average_q', 19.973195015705674), ('average_loss', 0.1227321977377075), ('n_updates', 24522)]
episode: 190 R: 200.0 statistics: [('average_q', 20.050942533128573), ('average_loss', 0.09264820379188309), ('n_updates', 26335)]
episode: 200 R: 191.0 statistics: [('average_q', 19.81062306392066), ('average_loss', 0.11778217212419012), ('n_updates', 28248)]
Finished.

Recommended Posts

Renforcer l'apprentissage 6 First Chainer RL
Renforcer l'apprentissage 4 CartPole première étape
Renforcer l'apprentissage 8 Essayez d'utiliser l'interface utilisateur de Chainer
[Introduction] Renforcer l'apprentissage
Apprentissage par renforcement futur_2
Apprentissage par renforcement futur_1
Premier apprentissage profond ~ Lutte ~
Apprentissage amélioré 1 installation de Python
Renforcer l'apprentissage 3 Installation d'OpenAI
Renforcer l'apprentissage de la troisième ligne
Premier apprentissage profond ~ Préparation ~
Première solution d'apprentissage en profondeur ~
[Renforcer l'apprentissage] Tâche de bandit
Apprentissage amélioré Python + Unity (apprentissage)
Renforcer l'apprentissage 1 édition introductive
Renforcer l'apprentissage 18 Colaboratory + Acrobat + ChainerRL
Apprentissage amélioré 7 Sortie du journal des données d'apprentissage
Renforcer l'apprentissage 17 Colaboratory + CartPole + ChainerRL
Renforcer l'apprentissage 28 collaboratif + OpenAI + chainerRL
Renforcement de l'apprentissage 2 Installation de chainerrl
[Renforcer l'apprentissage] Suivi par multi-agents
Apprentissage amélioré à partir de Python
Renforcer l'apprentissage 20 Colaboratoire + Pendule + ChainerRL
Apprentissage par renforcement 5 Essayez de programmer CartPole?
Apprentissage par renforcement 9 Remodelage magique ChainerRL
Renforcer l'apprentissage Apprendre d'aujourd'hui
Apprentissage par renforcement profond 1 Introduction au renforcement de l'apprentissage
Apprentissage par renforcement profond 2 Mise en œuvre de l'apprentissage par renforcement
DeepMind Enhanced Learning Framework Acme
Apprentissage par renforcement: accélérer l'itération de la valeur
Renforcer l'apprentissage 21 Colaboratoire + Pendule + ChainerRL + A2C
TF2RL: bibliothèque d'apprentissage améliorée pour TensorFlow2.x
Apprentissage par renforcement 34 Créez des vidéos d'agent en continu
Renforcer l'apprentissage 13 Essayez Mountain_car avec ChainerRL.
Construction d'un environnement d'apprentissage amélioré Python + Unity
[Chainer] Apprentissage de XOR avec perceptron multicouche
Renforcer l'apprentissage 22 Colaboratory + CartPole + ChainerRL + A3C
Explorez le labyrinthe avec l'apprentissage augmenté
Mémorandum d'introduction au tutoriel d'apprentissage automatique de Chainer
Première reconnaissance faciale d'anime avec Chainer
Renforcer l'apprentissage 24 Colaboratory + CartPole + ChainerRL + ACER
Apprentissage par renforcement 3 Méthode de planification dynamique / méthode TD
Deep Strengthening Learning 3 Édition pratique: Briser des blocs
J'ai essayé l'apprentissage par renforcement avec PyBrain
Essayez l'apprentissage de la représentation commune avec le chainer
(python) Principes de base du chaînage de la bibliothèque d'apprentissage en profondeur
Apprenez en faisant! Apprentissage par renforcement profond_1