[0.4.6.1] fix and add custom env example
johnjim0816 committed Dec 24, 2023
1 parent 431a5e4 commit acfa10a
Showing 9 changed files with 182 additions and 13 deletions.
115 changes: 115 additions & 0 deletions examples/custom_env.py
@@ -0,0 +1,115 @@
import gymnasium as gym
from gymnasium import Env, spaces
from gymnasium.envs.toy_text.utils import categorical_sample
import numpy as np
from typing import Optional
import joyrl

UP = 0
RIGHT = 1
DOWN = 2
LEFT = 3

def one_hot(x, n):
"""Create a one-hot encoding of x for n classes.
Args:
x: An integer.
n: Number of classes.
Returns:
A numpy array of shape (n,).
"""
return np.identity(n)[x:x + 1].squeeze()
class CustomCliffWalkingEnv(Env):

    def __init__(self):
        self.shape = (4, 12)
        self.start_state_index = np.ravel_multi_index((3, 0), self.shape)

        self.nS = np.prod(self.shape)
        self.nA = 4

        # Cliff Location
        self._cliff = np.zeros(self.shape, dtype=bool)
        self._cliff[3, 1:-1] = True

        # Calculate transition probabilities and rewards
        self.P = {}
        for s in range(self.nS):
            position = np.unravel_index(s, self.shape)
            self.P[s] = {a: [] for a in range(self.nA)}
            self.P[s][UP] = self._calculate_transition_prob(position, [-1, 0])
            self.P[s][RIGHT] = self._calculate_transition_prob(position, [0, 1])
            self.P[s][DOWN] = self._calculate_transition_prob(position, [1, 0])
            self.P[s][LEFT] = self._calculate_transition_prob(position, [0, -1])

        # Calculate initial state distribution
        # We always start in state (3, 0)
        self.initial_state_distrib = np.zeros(self.nS)
        self.initial_state_distrib[self.start_state_index] = 1.0

        self.observation_space = spaces.Discrete(self.nS)
        self.action_space = spaces.Discrete(self.nA)

        # pygame utils
        self.cell_size = (60, 60)
        self.window_size = (
            self.shape[1] * self.cell_size[1],
            self.shape[0] * self.cell_size[0],
        )
        self.window_surface = None
        self.clock = None
        self.elf_images = None
        self.start_img = None
        self.goal_img = None
        self.cliff_img = None
        self.mountain_bg_img = None
        self.near_cliff_img = None
        self.tree_img = None

    def _limit_coordinates(self, coord: np.ndarray) -> np.ndarray:
        """Prevent the agent from falling out of the grid world."""
        coord[0] = min(coord[0], self.shape[0] - 1)
        coord[0] = max(coord[0], 0)
        coord[1] = min(coord[1], self.shape[1] - 1)
        coord[1] = max(coord[1], 0)
        return coord

    def _calculate_transition_prob(self, current, delta):
        new_position = np.array(current) + np.array(delta)
        new_position = self._limit_coordinates(new_position).astype(int)
        new_state = np.ravel_multi_index(tuple(new_position), self.shape)
        # stepping into the cliff sends the agent back to the start with a -100 penalty
        if self._cliff[tuple(new_position)]:
            return [(1.0, self.start_state_index, -100, False)]

        terminal_state = (self.shape[0] - 1, self.shape[1] - 1)
        # use the negative euclidean distance to the goal as the reward, instead of a constant -1
        dist = np.linalg.norm(np.array(new_position) - np.array(terminal_state))
        is_terminated = tuple(new_position) == terminal_state
        if is_terminated:
            return [(1.0, new_state, 10, is_terminated)]
        return [(1.0, new_state, -dist, is_terminated)]

    def step(self, a):
        transitions = self.P[self.s][a]
        i = categorical_sample([t[0] for t in transitions], self.np_random)
        p, s, r, t = transitions[i]
        self.s = s
        self.lastaction = a
        # truncation=False as the time limit is handled by the `TimeLimit` wrapper added during `make`
        return int(s), r, t, False, {"prob": p}
        # return one_hot(int(s), self.nS), r, t, False, {"prob": p}  # use this instead if you want a one-hot state representation


    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        super().reset(seed=seed)
        self.s = categorical_sample(self.initial_state_distrib, self.np_random)
        self.lastaction = None
        return int(self.s), {"prob": 1.0}
        # return one_hot(int(self.s), self.nS), {"prob": 1.0}  # use this instead if you want a one-hot state representation

if __name__ == "__main__":
    my_env = CustomCliffWalkingEnv()
    yaml_path = "../presets/ToyText/CliffWalking-v0/CustomCliffWalking-v0_DQN.yaml"
    joyrl.run(yaml_path = yaml_path, env = my_env)
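
Not part of this commit: before wiring a custom environment into joyrl.run, it can help to sanity-check it against the plain Gymnasium reset/step API it implements above. The following random-action rollout is a minimal sketch; the seed, the 20-step budget, and the print format are illustrative choices only.

# Illustrative sanity check, not in the commit: roll out CustomCliffWalkingEnv with random actions.
env = CustomCliffWalkingEnv()
obs, info = env.reset(seed=0)               # obs is an int state index, info carries the start probability
total_reward = 0.0
for _ in range(20):                         # 20 matches max_step in the preset below
    action = env.action_space.sample()      # random action from Discrete(4)
    obs, reward, terminated, truncated, info = env.step(action)
    total_reward += reward
    if terminated or truncated:
        break
print(f"sanity rollout: return={total_reward:.2f}, final state={obs}")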
4 changes: 2 additions & 2 deletions joyrl/__init__.py
@@ -5,13 +5,13 @@
Email: [email protected]
Date: 2023-01-01 16:20:49
LastEditor: JiangJi
LastEditTime: 2023-12-24 20:37:23
LastEditTime: 2023-12-24 22:28:03
Discription:
'''
from joyrl import algos, framework, envs, utils
from joyrl.run import run

__version__ = "0.4.6"
__version__ = "0.4.6.1"

__all__ = [
"algos",
4 changes: 2 additions & 2 deletions joyrl/algos/base/action_layers.py
@@ -2,8 +2,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from joyrl.algos.base.base_layers import LayerConfig
from joyrl.algos.base.base_layers import create_layer
from joyrl.algos.base.base_layer import LayerConfig
from joyrl.algos.base.base_layer import create_layer

class ActionLayerType(Enum):
    ''' Action layer type
joyrl/algos/base/{base_layers.py → base_layer.py}: File renamed without changes.
12 changes: 11 additions & 1 deletion joyrl/algos/base/network.py
@@ -1,5 +1,15 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: JiangJi
Email: [email protected]
Date: 2023-12-22 23:02:13
LastEditor: JiangJi
LastEditTime: 2023-12-24 21:43:29
Discription:
'''
import torch.nn as nn
from joyrl.algos.base.base_layers import create_layer, LayerConfig
from joyrl.algos.base.base_layer import create_layer, LayerConfig
from joyrl.algos.base.action_layers import ActionLayerType, DiscreteActionLayer, ContinuousActionLayer, DPGActionLayer

class BaseNework(nn.Module):
2 changes: 1 addition & 1 deletion joyrl/framework/interactor.py
@@ -58,7 +58,7 @@ def run(self):
global_episode = self.tracker.pub_msg(Msg(MsgType.TRACKER_GET_EPISODE))
self.tracker.pub_msg(Msg(MsgType.TRACKER_INCREASE_EPISODE))
if global_episode % self.cfg.interact_summary_fre == 0:
    self.logger.info(f"Interactor {self.id} finished episode {global_episode} with reward {self.ep_reward:.3f} in {self.ep_step} steps")
    self.logger.info(f"Interactor {self.id} finished episode {global_episode} with reward {self.ep_reward:.3f} in {self.ep_step} steps, truncated: {truncated}, terminated: {terminated}")
# put summary to recorder
interact_summary = {'reward':self.ep_reward,'step':self.ep_step}
self.summary.append((global_episode, interact_summary))
5 changes: 3 additions & 2 deletions joyrl/framework/trainer.py
@@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-02 15:02:30
LastEditor: JiangJi
LastEditTime: 2023-12-20 23:39:13
LastEditTime: 2023-12-24 21:36:10
Discription:
'''
import time
@@ -77,7 +77,8 @@ def run(self):
s_t = time.time()
while True:
    self.interactor_mgr.run()
    self.learner_mgr.run()
    if self.cfg.mode.lower() == 'train':
        self.learner_mgr.run()
    if self.tracker.pub_msg(Msg(type = MsgType.TRACKER_CHECK_TASK_END)):
        e_t = time.time()
        self.logger.info(f"[Trainer.run] Finish {self.cfg.mode}ing! Time cost: {e_t - s_t:.3f} s")
13 changes: 8 additions & 5 deletions offline_run.py
@@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-22 13:16:59
LastEditor: JiangJi
LastEditTime: 2023-12-24 17:44:02
LastEditTime: 2023-12-24 22:27:05
Discription:
'''
import sys,os
@@ -149,9 +149,12 @@ def config_dir(dir,name = None):
def env_config(self):
    ''' create single env
    '''
    env_cfg_dic = self.env_cfg.__dict__
    kwargs = {k: v for k, v in env_cfg_dic.items() if k not in env_cfg_dic['ignore_params']}
    env = gym.make(**kwargs)
    if self.custom_env is not None:
        env = self.custom_env
    else:
        env_cfg_dic = self.env_cfg.__dict__
        kwargs = {k: v for k, v in env_cfg_dic.items() if k not in env_cfg_dic['ignore_params']}
        env = gym.make(**kwargs)
    setattr(self.cfg, 'obs_space', env.observation_space)
    setattr(self.cfg, 'action_space', env.action_space)
    if self.env_cfg.wrapper is not None:
@@ -254,5 +257,5 @@ def run(**kwargs):
    launcher.run()

if __name__ == "__main__":
    launcher = Launcher()
    launcher = Launcher(**kwargs)
    launcher.run()
40 changes: 40 additions & 0 deletions presets/ToyText/CliffWalking-v0/CustomCliffWalking-v0_DQN.yaml
@@ -0,0 +1,40 @@
general_cfg:
  algo_name: DQN
  env_name: gym
  device: cpu
  mode: train
  collect_traj: false
  n_interactors: 1
  load_checkpoint: false
  load_path: Train_single_CartPole-v1_DQN_20230515-211721
  load_model_step: best
  max_episode: -1
  max_step: 20
  seed: 1
  online_eval: true
  online_eval_episode: 10
  model_save_fre: 500

algo_cfg:
  value_layers:
    - layer_type: embed
      n_embeddings: 48
      embedding_dim: 4
    - layer_type: linear
      layer_size: [256]
      activation: relu
    - layer_type: linear
      layer_size: [256]
      activation: relu
  batch_size: 128
  buffer_type: REPLAY_QUE
  buffer_size: 10000
  epsilon_decay: 1000
  epsilon_end: 0.01
  epsilon_start: 0.99
  gamma: 0.95
  lr: 0.001
  target_update: 4
env_cfg:
  id: CustomCliffWalking-v0
  render_mode: null
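
Not part of this commit: the value_layers block above describes a small embedding-plus-MLP Q-network for the 48 discrete states and 4 actions. As a rough illustration only, and assuming layer_type: embed maps to a torch.nn.Embedding and layer_type: linear to a Linear layer followed by its activation (the actual builder lives in joyrl/algos/base/base_layer.py and may differ), an equivalent network could look like this:

import torch
import torch.nn as nn

# Illustrative only, NOT joyrl's actual layer builder.
value_net = nn.Sequential(
    nn.Embedding(48, 4),    # layer_type: embed, n_embeddings: 48, embedding_dim: 4
    nn.Linear(4, 256),      # layer_type: linear, layer_size: [256]
    nn.ReLU(),              # activation: relu
    nn.Linear(256, 256),    # layer_type: linear, layer_size: [256]
    nn.ReLU(),              # activation: relu
    nn.Linear(256, 4),      # output head: one Q-value per action (UP/RIGHT/DOWN/LEFT)
)
q_values = value_net(torch.tensor([36]))  # 36 = start cell (3, 0) flattened on the 4x12 grid
print(q_values.shape)                     # torch.Size([1, 4])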
