forked from camall3n/visgrid
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinteract.py
42 lines (32 loc) · 1.06 KB
/
interact.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import random
from visgrid.taxi.taxi import Taxi5x5
from visgrid.agents.qlearningagent import SkilledQLearningAgent
from visgrid.taxi.skills import skills5x5, skill_policy
from visgrid.sensors import IdentitySensor
# --- Run configuration -------------------------------------------------------
total_timesteps = 50000  # the outer loop stops once `timestep` reaches this
episode_timeout = 2000  # upper bound on env steps taken before each reset
epsilon = 0.1  # forwarded to the agent as its `epsilon` argument

# Seed Python's global `random` generator with a fixed value.
random.seed(0)

# --- Environment, skills and agent -------------------------------------------
env = Taxi5x5()


def _bind_skill(name):
    """Return a zero-argument callable that evaluates ``skill_policy(env, name)``.

    Going through a factory gives every callable its own ``name`` binding, so
    the callables built in the comprehension below do not all share the last
    skill name.
    """

    def run_skill():
        return skill_policy(env, name)

    return run_skill


skill_names = list(skills5x5)
skill_fns = [_bind_skill(name) for name in skill_names]
# Map each skill name to its callable; handed to the agent as `options`.
skills = {name: fn for name, fn in zip(skill_names, skill_fns)}
agent = SkilledQLearningAgent(options=skills, epsilon=epsilon)

# NOTE(review): this binds the IdentitySensor class itself, not an instance,
# so `sensor.observe(state)` below is invoked on the class. Confirm `observe`
# is callable that way; otherwise this probably wants `IdentitySensor()`.
sensor = IdentitySensor

# --- Interaction loop --------------------------------------------------------
timestep = 0  # environment steps taken so far, across all episodes
total_reward = 0  # running sum of every reward received

# Choose the very first action from the initial state, with a reward of 0.
state = env.get_state()
observation = sensor.observe(state)
action = agent.act(observation, reward=0)

while timestep < total_timesteps:
    # One pass of this outer loop is one episode: it ends when the environment
    # reports `done`, when `episode_timeout` steps have been taken, or when
    # the global step budget is used up.
    episode_steps = 0
    while episode_steps < episode_timeout:
        episode_steps += 1

        state, reward, done = env.step(action)
        observation = sensor.observe(state)
        timestep += 1
        total_reward += reward
        print(total_reward)

        # The agent receives the new observation together with the reward
        # from the step just taken, and returns the next action.
        action = agent.act(observation, reward)

        if done:
            break
        if timestep >= total_timesteps:
            break

    # NOTE(review): `action` is not re-selected after the reset; the action
    # chosen from the previous episode's last observation is the first one
    # applied in the next episode. Confirm this is intended.
    env.reset()
    agent.end_of_episode()