Commit
[feat_dev] update DQN in new framework
johnjim0816 committed Jun 1, 2024
1 parent 88f0a2d commit c824261
Showing 5 changed files with 29 additions and 38 deletions.
28 changes: 15 additions & 13 deletions joyrl/algos/DQN/policy.py
@@ -5,7 +5,7 @@
Email: [email protected]
Date: 2024-01-25 09:58:33
LastEditor: JiangJi
LastEditTime: 2024-06-01 17:45:42
LastEditTime: 2024-06-02 00:19:46
Description:
'''
import torch
@@ -51,25 +51,27 @@ def predict_action(self,state, **kwargs):
''' predict action
'''
state = [torch.tensor(np.array(state), device=self.device, dtype=torch.float32).unsqueeze(dim=0)]
_ = self.model(state)
actions = self.model.action_layers.get_actions()
model_outputs = self.model(state)
actor_outputs = model_outputs['actor_outputs']
actions = self.model.action_layers.get_actions(mode = 'predict', actor_outputs = actor_outputs)
return actions
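
With this change the shared network's forward pass returns a dict and action selection is delegated to the action layers through keyword arguments, instead of reading cached state off the model. The standalone sketch below mirrors that flow with toy stand-ins (ToyActionLayer and ToyQNet are illustrative names, not joyrl classes), assuming a single discrete action head:

import numpy as np
import torch
import torch.nn as nn

class ToyActionLayer(nn.Module):
    # Mirrors the refactored action layer: forward() returns {"q_value": ...}
    # and predict_action() reads the q_value it is handed via kwargs.
    def __init__(self, input_size, action_dim):
        super().__init__()
        self.action_value_layer = nn.Linear(input_size, action_dim)

    def forward(self, x):
        return {"q_value": self.action_value_layer(x)}

    def predict_action(self, **kwargs):
        q_value = kwargs.get("q_value", None)
        return {"action": torch.argmax(q_value).detach().cpu().numpy().item()}

class ToyQNet(nn.Module):
    # Mirrors the refactored network: forward() returns {"actor_outputs": [...]},
    # one entry per action head.
    def __init__(self, state_dim=4, action_dim=2):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU())
        self.action_layers = nn.ModuleList([ToyActionLayer(64, action_dim)])

    def forward(self, x):
        x = self.body(x)
        return {"actor_outputs": [layer(x) for layer in self.action_layers]}

model = ToyQNet()
state = torch.tensor(np.random.rand(4), dtype=torch.float32).unsqueeze(dim=0)
actor_outputs = model(state)["actor_outputs"]
actions = [layer.predict_action(**out)
           for layer, out in zip(model.action_layers, actor_outputs)]
print(actions)  # e.g. [{'action': 1}]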

def learn(self, **kwargs):
''' learn policy
'''
states, actions, next_states, rewards, dones = kwargs.get('states'), kwargs.get('actions'), kwargs.get('next_states'), kwargs.get('rewards'), kwargs.get('dones')
# compute current Q values
_ = self.model(states)
q_values = self.model.action_layers.get_qvalues()
actual_qvalues = q_values.gather(1, actions.long())
# compute next max q value
_ = self.target_model(next_states)
next_q_values_max = self.target_model.action_layers.get_qvalues().max(1)[0].unsqueeze(dim=1)
# compute target Q values
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values_max
# compute loss
self.loss = nn.MSELoss()(actual_qvalues, target_q_values)
self.loss = 0
actor_outputs = self.model(states)['actor_outputs']
target_actor_outputs = self.target_model(next_states)['actor_outputs']
for i in range(len(self.action_size_list)):
actual_q_value = actor_outputs[i]['q_value'].gather(1, actions[i].long())
# compute next max q value
next_q_value_max = target_actor_outputs[i]['q_value'].max(1)[0].unsqueeze(dim=1)
# compute target Q values
target_q_value = rewards + (1 - dones) * self.gamma * next_q_value_max
# compute loss
self.loss += nn.MSELoss()(actual_q_value, target_q_value)
self.optimizer.zero_grad()
self.loss.backward()
# clip to avoid gradient explosion
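
The loss is now accumulated per action head: for each head the TD target is r + (1 - done) * gamma * max_a' Q_target(s', a'), and the MSE against the gathered online Q-value is summed across heads before a single backward pass. Below is a minimal sketch of that computation on dummy tensors (batch size, head count, and action dimension are illustrative assumptions, not values from this commit):

import torch
import torch.nn as nn

gamma = 0.99
batch_size, n_heads, action_dim = 8, 2, 3

# Dummy batch: per-head Q-values from the online and target networks,
# per-head chosen actions, plus shared rewards and done flags.
actor_outputs = [{"q_value": torch.randn(batch_size, action_dim, requires_grad=True)}
                 for _ in range(n_heads)]
target_actor_outputs = [{"q_value": torch.randn(batch_size, action_dim)}
                        for _ in range(n_heads)]
actions = [torch.randint(0, action_dim, (batch_size, 1)) for _ in range(n_heads)]
rewards = torch.randn(batch_size, 1)
dones = torch.zeros(batch_size, 1)

loss = 0
for i in range(n_heads):
    # Q(s, a) for the action actually taken by head i
    actual_q_value = actor_outputs[i]["q_value"].gather(1, actions[i].long())
    # max_a' Q_target(s', a') from the target network for head i
    next_q_value_max = target_actor_outputs[i]["q_value"].max(1)[0].unsqueeze(dim=1)
    # Bellman target, masked on terminal transitions
    target_q_value = rewards + (1 - dones) * gamma * next_q_value_max
    loss += nn.MSELoss()(actual_q_value, target_q_value)

print(loss)  # scalar tensor; loss.backward() would then drive the optimizer step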
27 changes: 7 additions & 20 deletions joyrl/algos/base/action_layer.py
@@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-25 09:28:26
LastEditor: JiangJi
LastEditTime: 2024-06-01 18:23:12
LastEditTime: 2024-06-02 00:17:16
Description:
'''
from enum import Enum
@@ -63,28 +63,15 @@ def forward(self, x, **kwargs):
else:
q_value = self.action_value_layer(x)
output = {"q_value": q_value}
self.q_value = q_value
return output

def get_qvalue(self):
return self.q_value

def get_action(self, **kwargs):
mode = kwargs.get("mode", "sample")
if mode == "sample":
return self.sample_action()
elif mode == "predict":
return self.predict_action()
else:
raise NotImplementedError

def sample_action(self):
return torch.argmax(self.q_value).detach().cpu().numpy().item()
def sample_action(self, **kwargs):
q_value = kwargs.get("q_value", None)
return {"action": torch.argmax(q_value).detach().cpu().numpy().item()}

def predict_action(self):
''' get action
'''
return torch.argmax(self.q_value).detach().cpu().numpy().item()
def predict_action(self, **kwargs):
q_value = kwargs.get("q_value", None)
return {"action": torch.argmax(q_value).detach().cpu().numpy().item()}

class DiscreteActionLayer(BaseActionLayer):
def __init__(self, cfg, input_size, action_dim, id = 0, **kwargs):
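
Both selection methods now share one contract: the caller passes the already-computed q_value through keyword arguments and gets back a dict keyed by "action", rather than the layer caching self.q_value and returning a bare integer. A quick, self-contained illustration of that contract with a dummy tensor (the free function here is a stand-in for the layer method):

import torch

def predict_action(**kwargs):
    # Same shape of logic as the refactored layer method: no cached state,
    # just the q_value handed in by the caller.
    q_value = kwargs.get("q_value", None)
    return {"action": torch.argmax(q_value).detach().cpu().numpy().item()}

q_value = torch.tensor([[0.1, 0.7, 0.2]])
print(predict_action(q_value=q_value))  # {'action': 1}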
6 changes: 3 additions & 3 deletions joyrl/algos/base/network.py
@@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-22 23:02:13
LastEditor: JiangJi
LastEditTime: 2024-06-01 18:30:15
LastEditTime: 2024-06-02 00:11:14
Description:
'''
import copy
@@ -128,7 +128,7 @@ def get_qvalues(self):
qvalues = []
for _, action_layer in enumerate(self.action_layers):
qvalues.append(action_layer.get_qvalue())
return qvalues[0]
return qvalues

def get_actions(self, **kwargs):
mode = kwargs.get('mode', 'train')
@@ -209,7 +209,7 @@ def forward(self, x):
x = self.branch_layers(x)
x = self.merge_layer(x)
actor_outputs = self.action_layers(x)
return actor_outputs
return {"actor_outputs": actor_outputs}

def reset_noise(self):
''' reset noise for noisy layers
1 change: 1 addition & 0 deletions joyrl/algos/base/policy.py
@@ -22,6 +22,7 @@ def __init__(self, cfg : MergedConfig) -> None:
self.get_action_size()
self.create_model()
self.create_optimizer()
self.create_summary()

def get_state_size(self):
''' get state size
5 changes: 3 additions & 2 deletions presets/ClassControl/CartPole-v1/CartPole-v1_DQN.yaml
@@ -1,9 +1,10 @@
general_cfg:
joyrl_version: 0.5.0
joyrl_version: 0.6.0.2
algo_name: DQN
env_name: gym
device: cpu
mode: train
is_learner_async: false
collect_traj: false
n_interactors: 1
load_checkpoint: false
@@ -15,7 +16,7 @@ general_cfg:
online_eval: true
online_eval_episode: 10
model_save_fre: 500
policy_summary_fre: 2
policy_summary_fre: 10

algo_cfg:
learn_frequency: 1
