-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_MountainCar.py
76 lines (55 loc) · 1.98 KB
/
run_MountainCar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""
Policy Gradient, Reinforcement Learning.
The mountain car example
View more on my tutorial page: https://morvanzhou.github.io/tutorials/
Using:
Tensorflow: 1.0
gym: 0.8.0
"""
import gym
from RL_brain import PolicyGradient
import matplotlib.pyplot as plt
# Rendering is switched on once the running reward rises above this threshold.
# Sample log lines kept from earlier runs:
#   episode: 154   reward: -10667
#   episode: 387   reward: -2009
#   episode: 489   reward: -1006
#   episode: 628   reward: -502
DISPLAY_REWARD_THRESHOLD = -2000
RENDER = False  # rendering wastes time, so it starts disabled

env = gym.make('MountainCar-v0')
env.seed(1)  # reproducible; general policy gradient has high variance
# Use the environment beneath any gym wrappers
# (presumably to drop the wrapper's episode step limit -- TODO confirm).
env = env.unwrapped

# Report the action space, the observation space, and the observation bounds,
# one per line.
print(
    env.action_space,
    env.observation_space,
    env.observation_space.high,
    env.observation_space.low,
    sep="\n",
)

# Agent sized from the environment's action count and observation dimension.
RL = PolicyGradient(
    n_actions=env.action_space.n,
    n_features=env.observation_space.shape[0],
    learning_rate=0.02,
    reward_decay=0.995,
    # output_graph=True,
)
# Exponentially smoothed episode return. It stays None until the first episode
# finishes, which replaces the former `'running_reward' not in globals()` check.
running_reward = None

# Train for 1000 episodes; each episode runs until the environment reports done.
for i_episode in range(1000):
    observation = env.reset()

    while True:
        if RENDER:
            env.render()

        action = RL.choose_action(observation)
        observation_, reward, done, _ = env.step(action)  # reward = -1 in all cases
        RL.store_transition(observation, action, reward)

        if done:
            # Sum the episode's rewards before calling RL.learn()
            # (presumably learn() clears the stored episode -- confirm in RL_brain).
            ep_rs_sum = sum(RL.ep_rs)

            # Running reward: first episode seeds it, later episodes are
            # blended in with weight 0.01.
            if running_reward is None:
                running_reward = ep_rs_sum
            else:
                running_reward = running_reward * 0.99 + ep_rs_sum * 0.01

            # Once the agent is doing well enough, start rendering.
            if running_reward > DISPLAY_REWARD_THRESHOLD:
                RENDER = True
            print("episode:", i_episode, " reward:", int(running_reward))

            vt = RL.learn()  # train on the finished episode

            # Plot the values returned by learn() for episode 30 only.
            if i_episode == 30:
                plt.plot(vt)
                plt.xlabel('episode steps')
                plt.ylabel('normalized state-action value')
                plt.show()

            break

        # Episode continues: the next observation becomes the current one.
        observation = observation_