-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[0.4.6.1] fix and add custom env example
- Loading branch information
1 parent
431a5e4
commit acfa10a
Showing
9 changed files
with
182 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import gymnasium as gym | ||
from gymnasium import Env, spaces | ||
from gymnasium.envs.toy_text.utils import categorical_sample | ||
import numpy as np | ||
from typing import Optional | ||
import joyrl | ||
|
||
UP = 0 | ||
RIGHT = 1 | ||
DOWN = 2 | ||
LEFT = 3 | ||
|
||
def one_hot(x, n):
    """Create a one-hot encoding of x for n classes.

    Args:
        x: An integer class index with 0 <= x < n.
        n: Number of classes.

    Returns:
        A float64 numpy array of shape (n,) that is 1.0 at position x and
        0.0 everywhere else.

    Raises:
        IndexError: If x is outside the range [0, n).
    """
    # Reject bad indices explicitly: slicing an identity matrix (the previous
    # approach) silently produced an empty array for them.
    if not 0 <= x < n:
        raise IndexError(f"class index {x} is out of range for {n} classes")
    # Allocate only the n-element vector instead of an n-by-n identity matrix;
    # this also keeps the shape (n,) when n == 1, where squeeze() used to
    # collapse the result to a 0-d array.
    encoding = np.zeros(n)
    encoding[x] = 1.0
    return encoding
class CustomCliffWalkingEnv(Env):
    """Cliff-walking grid world with a distance-shaped reward.

    The grid is 4 rows by 12 columns. The agent starts at (3, 0), the goal is
    the bottom-right cell (3, 11), and the cells between them on the bottom
    row form the cliff.

    Rewards, as encoded in the transition table ``P``:
        * moving onto a cliff cell: -100, and the agent is sent back to the
          start state (the episode does not terminate);
        * reaching the goal: +10, and the episode terminates;
        * any other move: minus the Euclidean distance from the new cell to
          the goal.

    Args:
        one_hot_obs: When False (the default) observations are integer state
            indices and ``observation_space`` is ``Discrete(nS)``. When True
            observations are one-hot vectors of length ``nS`` and
            ``observation_space`` is a matching ``Box``.
    """

    def __init__(self, *, one_hot_obs: bool = False):
        self.shape = (4, 12)
        self.start_state_index = np.ravel_multi_index((3, 0), self.shape)

        # Plain Python int so it can be used directly as a size/count.
        self.nS = int(np.prod(self.shape))
        self.nA = 4
        self.one_hot_obs = one_hot_obs

        # Cliff location: the bottom row except its first and last cell.
        self._cliff = np.zeros(self.shape, dtype=bool)
        self._cliff[3, 1:-1] = True

        # Transition table: P[state][action] -> [(prob, next_state, reward, terminated)].
        moves = {UP: [-1, 0], RIGHT: [0, 1], DOWN: [1, 0], LEFT: [0, -1]}
        self.P = {}
        for s in range(self.nS):
            position = np.unravel_index(s, self.shape)
            self.P[s] = {
                action: self._calculate_transition_prob(position, delta)
                for action, delta in moves.items()
            }

        # Initial state distribution: we always start in state (3, 0).
        self.initial_state_distrib = np.zeros(self.nS)
        self.initial_state_distrib[self.start_state_index] = 1.0

        if self.one_hot_obs:
            self.observation_space = spaces.Box(
                low=0.0, high=1.0, shape=(self.nS,), dtype=np.float64
            )
        else:
            self.observation_space = spaces.Discrete(self.nS)
        self.action_space = spaces.Discrete(self.nA)

        # pygame utils: placeholders only, this class defines no render method.
        self.cell_size = (60, 60)
        self.window_size = (
            self.shape[1] * self.cell_size[1],
            self.shape[0] * self.cell_size[0],
        )
        self.window_surface = None
        self.clock = None
        self.elf_images = None
        self.start_img = None
        self.goal_img = None
        self.cliff_img = None
        self.mountain_bg_img = None
        self.near_cliff_img = None
        self.tree_img = None

    def _limit_coordinates(self, coord: np.ndarray) -> np.ndarray:
        """Prevent the agent from falling out of the grid world.

        Clamps ``coord`` in place to the grid bounds and returns it.
        """
        coord[0] = min(coord[0], self.shape[0] - 1)
        coord[0] = max(coord[0], 0)
        coord[1] = min(coord[1], self.shape[1] - 1)
        coord[1] = max(coord[1], 0)
        return coord

    def _calculate_transition_prob(self, current, delta):
        """Return the single deterministic outcome of moving by ``delta``.

        Args:
            current: (row, col) of the current cell.
            delta: [d_row, d_col] of the attempted move.

        Returns:
            A one-element list ``[(1.0, next_state, reward, terminated)]``.
        """
        new_position = np.array(current) + np.array(delta)
        new_position = self._limit_coordinates(new_position).astype(int)
        new_state = np.ravel_multi_index(tuple(new_position), self.shape)
        if self._cliff[tuple(new_position)]:
            # Falling off the cliff: large penalty and back to the start.
            return [(1.0, self.start_state_index, -100, False)]

        terminal_state = (self.shape[0] - 1, self.shape[1] - 1)
        if tuple(new_position) == terminal_state:
            return [(1.0, new_state, 10, True)]
        # Use the Euclidean distance to the goal as the step cost instead of a
        # constant -1.
        dist = np.linalg.norm(np.array(new_position) - np.array(terminal_state))
        return [(1.0, new_state, -dist, False)]

    def _get_obs(self):
        """Encode the current state as an int index or a one-hot vector."""
        if self.one_hot_obs:
            return one_hot(int(self.s), self.nS)
        return int(self.s)

    def step(self, a):
        """Apply action ``a`` and return (obs, reward, terminated, truncated, info)."""
        transitions = self.P[self.s][a]
        i = categorical_sample([t[0] for t in transitions], self.np_random)
        p, s, r, t = transitions[i]
        self.s = s
        self.lastaction = a
        # This environment never truncates by itself; any step limit has to be
        # imposed from outside (e.g. a wrapper or the training loop).
        return self._get_obs(), r, t, False, {"prob": p}

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Reset to the start state and return (obs, info)."""
        super().reset(seed=seed)
        self.s = categorical_sample(self.initial_state_distrib, self.np_random)
        self.lastaction = None
        return self._get_obs(), {"prob": 1.0}

||
if __name__ == "__main__":
    # Train with the DQN preset, handing JoyRL an already-constructed
    # environment instance through the ``env`` keyword.
    preset_path = "../presets/ToyText/CliffWalking-v0/CustomCliffWalking-v0_DQN.yaml"
    joyrl.run(yaml_path=preset_path, env=CustomCliffWalkingEnv())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,13 +5,13 @@ | |
Email: [email protected] | ||
Date: 2023-01-01 16:20:49 | ||
LastEditor: JiangJi | ||
LastEditTime: 2023-12-24 20:37:23 | ||
LastEditTime: 2023-12-24 22:28:03 | ||
Description: | ||
''' | ||
from joyrl import algos, framework, envs, utils | ||
from joyrl.run import run | ||
|
||
__version__ = "0.4.6" | ||
__version__ = "0.4.6.1" | ||
|
||
__all__ = [ | ||
"algos", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,15 @@ | ||
#!/usr/bin/env python | ||
# coding=utf-8 | ||
''' | ||
Author: JiangJi | ||
Email: [email protected] | ||
Date: 2023-12-22 23:02:13 | ||
LastEditor: JiangJi | ||
LastEditTime: 2023-12-24 21:43:29 | ||
Description: | ||
''' | ||
import torch.nn as nn | ||
from joyrl.algos.base.base_layers import create_layer, LayerConfig | ||
from joyrl.algos.base.base_layer import create_layer, LayerConfig | ||
from joyrl.algos.base.action_layers import ActionLayerType, DiscreteActionLayer, ContinuousActionLayer, DPGActionLayer | ||
|
||
class BaseNework(nn.Module): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
Email: [email protected] | ||
Date: 2023-12-02 15:02:30 | ||
LastEditor: JiangJi | ||
LastEditTime: 2023-12-20 23:39:13 | ||
LastEditTime: 2023-12-24 21:36:10 | ||
Description: | ||
''' | ||
import time | ||
|
@@ -77,7 +77,8 @@ def run(self): | |
s_t = time.time() | ||
while True: | ||
self.interactor_mgr.run() | ||
self.learner_mgr.run() | ||
if self.cfg.mode.lower() == 'train': | ||
self.learner_mgr.run() | ||
if self.tracker.pub_msg(Msg(type = MsgType.TRACKER_CHECK_TASK_END)): | ||
e_t = time.time() | ||
self.logger.info(f"[Trainer.run] Finish {self.cfg.mode}ing! Time cost: {e_t - s_t:.3f} s") | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
Email: [email protected] | ||
Date: 2023-12-22 13:16:59 | ||
LastEditor: JiangJi | ||
LastEditTime: 2023-12-24 17:44:02 | ||
LastEditTime: 2023-12-24 22:27:05 | ||
Description: | ||
''' | ||
import sys,os | ||
|
@@ -149,9 +149,12 @@ def config_dir(dir,name = None): | |
def env_config(self): | ||
''' create single env | ||
''' | ||
env_cfg_dic = self.env_cfg.__dict__ | ||
kwargs = {k: v for k, v in env_cfg_dic.items() if k not in env_cfg_dic['ignore_params']} | ||
env = gym.make(**kwargs) | ||
if self.custom_env is not None: | ||
env = self.custom_env | ||
else: | ||
env_cfg_dic = self.env_cfg.__dict__ | ||
kwargs = {k: v for k, v in env_cfg_dic.items() if k not in env_cfg_dic['ignore_params']} | ||
env = gym.make(**kwargs) | ||
setattr(self.cfg, 'obs_space', env.observation_space) | ||
setattr(self.cfg, 'action_space', env.action_space) | ||
if self.env_cfg.wrapper is not None: | ||
|
@@ -254,5 +257,5 @@ def run(**kwargs): | |
launcher.run() | ||
|
||
if __name__ == "__main__": | ||
launcher = Launcher() | ||
launcher = Launcher(**kwargs) | ||
launcher.run() |
40 changes: 40 additions & 0 deletions
40
presets/ToyText/CliffWalking-v0/CustomCliffWalking-v0_DQN.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
general_cfg: | ||
algo_name: DQN | ||
env_name: gym | ||
device: cpu | ||
mode: train | ||
collect_traj: false | ||
n_interactors: 1 | ||
load_checkpoint: false | ||
load_path: Train_single_CartPole-v1_DQN_20230515-211721 | ||
load_model_step: best | ||
max_episode: -1 | ||
max_step: 20 | ||
seed: 1 | ||
online_eval: true | ||
online_eval_episode: 10 | ||
model_save_fre: 500 | ||
|
||
algo_cfg: | ||
value_layers: | ||
- layer_type: embed | ||
n_embeddings: 48 | ||
embedding_dim: 4 | ||
- layer_type: linear | ||
layer_size: [256] | ||
activation: relu | ||
- layer_type: linear | ||
layer_size: [256] | ||
activation: relu | ||
batch_size: 128 | ||
buffer_type: REPLAY_QUE | ||
buffer_size: 10000 | ||
epsilon_decay: 1000 | ||
epsilon_end: 0.01 | ||
epsilon_start: 0.99 | ||
gamma: 0.95 | ||
lr: 0.001 | ||
target_update: 4 | ||
env_cfg: | ||
id: CustomCliffWalking-v0 | ||
render_mode: null |