ChainerRL-Kurzreferenz https://chainer-colab-notebook.readthedocs.io/ja/latest/notebook/hands_on/chainerrl/quickstart.html
Es wird vorausgesetzt, dass Sie „Bestärkendes Lernen 5“ (Reinforcement Learning 5) bereits abgeschlossen haben. Erstellen Sie die folgende Datei anhand der oben verlinkten Kurzreferenz (Quickstart).
train.py
import chainer
import chainer.functions as F
import chainer.links as L
import chainerrl
import gym
import numpy as np
# Create the CartPole environment and print its two spaces.
env = gym.make('CartPole-v0')
for label, space in (('observation space:', env.observation_space),
                     ('action space:', env.action_space)):
    print(label, space)

# Reset the environment, render it once and print the first observation.
obs = env.reset()
env.render()
print('initial observation:', obs)

# Apply one randomly sampled action and print the four values that
# env.step() hands back.
action = env.action_space.sample()
obs, r, done, info = env.step(action)
for label, value in (('next observation:', obs),
                     ('reward:', r),
                     ('done:', done),
                     ('info:', info)):
    print(label, value)
class QFunction(chainer.Chain):
    """Q-function made of three linear layers with tanh activations.

    Args:
        obs_size (int): Input size of the first linear layer.
        n_actions (int): Output size of the last linear layer.
        n_hidden_channels (int): Output size of the first two linear
            layers (the hidden width).
    """

    def __init__(self, obs_size, n_actions, n_hidden_channels=50):
        super().__init__()
        # Create the layers inside init_scope() so they are registered
        # as child links of this chain.
        with self.init_scope():
            self.l0 = L.Linear(obs_size, n_hidden_channels)
            self.l1 = L.Linear(n_hidden_channels, n_hidden_channels)
            self.l2 = L.Linear(n_hidden_channels, n_actions)

    def __call__(self, x, test=False):
        """Evaluate the Q-function on an observation.

        Args:
            x (ndarray or chainer.Variable): An observation
            test (bool): a flag indicating whether it is in test mode.
                This implementation does not read it; it is kept so
                the call signature stays unchanged.

        Returns:
            chainerrl.action_value.DiscreteActionValue: wrapper around
            the output of the last linear layer.
        """
        h = F.tanh(self.l0(x))
        h = F.tanh(self.l1(h))
        return chainerrl.action_value.DiscreteActionValue(self.l2(h))
# Network input/output sizes are read from the environment's spaces.
obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n
q_func = QFunction(obs_size, n_actions)

optimizer = chainer.optimizers.Adam(eps=1e-2)
optimizer.setup(q_func)

# Set the discount factor that discounts future rewards.
gamma = 0.95

# Use epsilon-greedy for exploration
explorer = chainerrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.3, random_action_func=env.action_space.sample)

# DQN uses Experience Replay.
# Specify a replay buffer and its capacity.
replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 6)


def phi(x):
    """Return observation ``x`` converted to ``numpy.float32``.

    Since observations from CartPole-v0 are numpy.float64 while
    Chainer only accepts numpy.float32 by default, this converter is
    passed to the agent as its feature extractor function.
    ``copy=False`` avoids a copy when ``x`` already is float32.
    """
    return x.astype(np.float32, copy=False)


# Now create an agent that will interact with the environment.
agent = chainerrl.agents.DoubleDQN(
    q_func, optimizer, replay_buffer, gamma, explorer,
    replay_start_size=500, update_interval=1,
    target_update_interval=100, phi=phi)
n_episodes = 200
max_episode_len = 200

# Close the environment even if training raises or is interrupted;
# the agent is only saved after all episodes have completed.
try:
    for i in range(1, n_episodes + 1):
        obs = env.reset()
        reward = 0  # passed to the first act_and_train() call of the episode
        done = False
        R = 0  # return (sum of rewards)
        t = 0  # time step
        while not done and t < max_episode_len:
            # Uncomment to watch the behaviour
            # env.render()
            action = agent.act_and_train(obs, reward)
            obs, reward, done, _ = env.step(action)
            R += reward
            t += 1
        # Report progress every tenth episode.
        if i % 10 == 0:
            print('episode:', i,
                  'R:', R,
                  'statistics:', agent.get_statistics())
        # Called once per episode, after the step loop has ended, with
        # the last observation, reward and done flag.
        agent.stop_episode_and_train(obs, reward, done)
    print('Finished.')
    agent.save('agent')
finally:
    env.close()
Wenn alles richtig funktioniert, sieht die Ausgabe wie folgt aus:
observation space: Box(4,)
action space: Discrete(2)
initial observation: [-0.04736688 -0.01970095 -0.00356997 0.01937746]
next observation: [-0.0477609 0.17547202 -0.00318242 -0.27442969]
reward: 1.0
done: False
info: {}
episode: 10 R: 54.0 statistics: [('average_q', 0.02431227437087797), ('average_loss', 0), ('n_updates', 0)]
episode: 20 R: 38.0 statistics: [('average_q', 0.6243046798922441), ('average_loss', 0.08867046155807262), ('n_updates', 378)]
episode: 30 R: 46.0 statistics: [('average_q', 2.2610338271644586), ('average_loss', 0.09550022600040467), ('n_updates', 784)]
episode: 40 R: 84.0 statistics: [('average_q', 5.323362298387771), ('average_loss', 0.16771472656243475), ('n_updates', 1399)]
episode: 50 R: 91.0 statistics: [('average_q', 9.851513830734694), ('average_loss', 0.19145745620246343), ('n_updates', 2351)]
episode: 60 R: 99.0 statistics: [('average_q', 14.207080180752635), ('average_loss', 0.22097823899753388), ('n_updates', 3584)]
episode: 70 R: 200.0 statistics: [('average_q', 17.49337381852232), ('average_loss', 0.18525375351216344), ('n_updates', 5285)]
episode: 80 R: 124.0 statistics: [('average_q', 18.933387631649587), ('average_loss', 0.1511605453710412), ('n_updates', 7063)]
episode: 90 R: 200.0 statistics: [('average_q', 19.55727598346719), ('average_loss', 0.167370220872378), ('n_updates', 8496)]
episode: 100 R: 200.0 statistics: [('average_q', 19.92113421424675), ('average_loss', 0.15092426599174535), ('n_updates', 10351)]
episode: 110 R: 161.0 statistics: [('average_q', 19.870179660112395), ('average_loss', 0.1369066775700466), ('n_updates', 12169)]
episode: 120 R: 200.0 statistics: [('average_q', 19.985680296882315), ('average_loss', 0.13667809001004586), ('n_updates', 13991)]
episode: 130 R: 200.0 statistics: [('average_q', 20.016279858512945), ('average_loss', 0.14053696154447365), ('n_updates', 15938)]
episode: 140 R: 180.0 statistics: [('average_q', 19.870299413261478), ('average_loss', 0.1270716956269478), ('n_updates', 17593)]
episode: 150 R: 200.0 statistics: [('average_q', 19.990808581945565), ('average_loss', 0.1228807602095278), ('n_updates', 19442)]
episode: 160 R: 130.0 statistics: [('average_q', 19.954955203815164), ('average_loss', 0.14701205384726732), ('n_updates', 21169)]
episode: 170 R: 133.0 statistics: [('average_q', 19.994069560095422), ('average_loss', 0.12502104946859763), ('n_updates', 22709)]
episode: 180 R: 200.0 statistics: [('average_q', 19.973195015705674), ('average_loss', 0.1227321977377075), ('n_updates', 24522)]
episode: 190 R: 200.0 statistics: [('average_q', 20.050942533128573), ('average_loss', 0.09264820379188309), ('n_updates', 26335)]
episode: 200 R: 191.0 statistics: [('average_q', 19.81062306392066), ('average_loss', 0.11778217212419012), ('n_updates', 28248)]
Finished.
Recommended Posts