[feat_dev] update DQN in novel framework
1 parent 88f0a2d · commit c824261
Showing 5 changed files with 29 additions and 38 deletions.
@@ -5,7 +5,7 @@
 Email: [email protected]
 Date: 2024-01-25 09:58:33
 LastEditor: JiangJi
-LastEditTime: 2024-06-01 17:45:42
+LastEditTime: 2024-06-02 00:19:46
 Discription:
 '''
 import torch
@@ -51,25 +51,27 @@ def predict_action(self,state, **kwargs):
         ''' predict action
         '''
         state = [torch.tensor(np.array(state), device=self.device, dtype=torch.float32).unsqueeze(dim=0)]
-        _ = self.model(state)
-        actions = self.model.action_layers.get_actions()
+        model_outputs = self.model(state)
+        actor_outputs = model_outputs['actor_outputs']
+        actions = self.model.action_layers.get_actions(mode = 'predict', actor_outputs = actor_outputs)
         return actions
 
     def learn(self, **kwargs):
         ''' learn policy
         '''
         states, actions, next_states, rewards, dones = kwargs.get('states'), kwargs.get('actions'), kwargs.get('next_states'), kwargs.get('rewards'), kwargs.get('dones')
-        # compute current Q values
-        _ = self.model(states)
-        q_values = self.model.action_layers.get_qvalues()
-        actual_qvalues = q_values.gather(1, actions.long())
-        # compute next max q value
-        _ = self.target_model(next_states)
-        next_q_values_max = self.target_model.action_layers.get_qvalues().max(1)[0].unsqueeze(dim=1)
-        # compute target Q values
-        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values_max
-        # compute loss
-        self.loss = nn.MSELoss()(actual_qvalues, target_q_values)
+        self.loss = 0
+        actor_outputs = self.model(states)['actor_outputs']
+        target_actor_outputs = self.target_model(next_states)['actor_outputs']
+        for i in range(len(self.action_size_list)):
+            actual_q_value = actor_outputs[i]['q_value'].gather(1, actions[i].long())
+            # compute next max q value
+            next_q_value_max = target_actor_outputs[i]['q_value'].max(1)[0].unsqueeze(dim=1)
+            # compute target Q values
+            target_q_value = rewards + (1 - dones) * self.gamma * next_q_value_max
+            # compute loss
+            self.loss += nn.MSELoss()(actual_q_value, target_q_value)
         self.optimizer.zero_grad()
         self.loss.backward()
         # clip to avoid gradient explosion
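To make the restructured update easier to follow outside the diff: the policy now reads per-branch outputs from model(states)['actor_outputs'] and sums one MSE term per action branch. Below is a minimal, self-contained sketch of that pattern with toy tensors; the shapes, branch sizes, and stand-in outputs are assumptions for illustration, not the repository's classes.

# Sketch of the per-branch DQN update: one Q-head per action branch,
# one MSE term per head, accumulated into a single loss.
import torch
import torch.nn as nn

gamma = 0.99
batch_size, action_size_list = 4, [3, 2]   # assumed: two discrete action branches
rewards = torch.zeros(batch_size, 1)
dones = torch.zeros(batch_size, 1)
actions = [torch.zeros(batch_size, 1, dtype=torch.long) for _ in action_size_list]

# stand-ins for model(states)['actor_outputs'] and target_model(next_states)['actor_outputs']
actor_outputs = [{'q_value': torch.randn(batch_size, n)} for n in action_size_list]
target_actor_outputs = [{'q_value': torch.randn(batch_size, n)} for n in action_size_list]

loss = 0
for i in range(len(action_size_list)):
    # Q(s, a) for the actions actually taken on branch i
    actual_q_value = actor_outputs[i]['q_value'].gather(1, actions[i])
    # max_a' Q_target(s', a') for branch i
    next_q_value_max = target_actor_outputs[i]['q_value'].max(1)[0].unsqueeze(dim=1)
    # TD target: r + (1 - done) * gamma * max_a' Q_target(s', a')
    target_q_value = rewards + (1 - dones) * gamma * next_q_value_max
    loss += nn.MSELoss()(actual_q_value, target_q_value)
print(loss.item())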
@@ -5,7 +5,7 @@
 Email: [email protected]
 Date: 2023-12-25 09:28:26
 LastEditor: JiangJi
-LastEditTime: 2024-06-01 18:23:12
+LastEditTime: 2024-06-02 00:17:16
 Discription:
 '''
 from enum import Enum
@@ -63,28 +63,15 @@ def forward(self, x, **kwargs):
         else:
             q_value = self.action_value_layer(x)
         output = {"q_value": q_value}
-        self.q_value = q_value
         return output
 
-    def get_qvalue(self):
-        return self.q_value
-
-    def get_action(self, **kwargs):
-        mode = kwargs.get("mode", "sample")
-        if mode == "sample":
-            return self.sample_action()
-        elif mode == "predict":
-            return self.predict_action()
-        else:
-            raise NotImplementedError
-
-    def sample_action(self):
-        return torch.argmax(self.q_value).detach().cpu().numpy().item()
+    def sample_action(self, **kwargs):
+        q_value = kwargs.get("q_value", None)
+        return {"action": torch.argmax(q_value).detach().cpu().numpy().item()}
 
-    def predict_action(self):
-        ''' get action
-        '''
-        return torch.argmax(self.q_value).detach().cpu().numpy().item()
+    def predict_action(self, **kwargs):
+        q_value = kwargs.get("q_value", None)
+        return {"action": torch.argmax(q_value).detach().cpu().numpy().item()}
 
 class DiscreteActionLayer(BaseActionLayer):
     def __init__(self, cfg, input_size, action_dim, id = 0, **kwargs):
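The action-layer change makes sample_action and predict_action stateless: instead of reading a cached self.q_value, they receive the q_value tensor through kwargs and return a dict. A minimal sketch of that calling convention, using a hypothetical TinyDiscreteHead rather than the repository's BaseActionLayer:

# Stateless action selection: q_value is passed in, a dict is returned.
import torch

class TinyDiscreteHead:
    def predict_action(self, **kwargs):
        # greedy action from the Q-values handed in by the caller
        q_value = kwargs.get("q_value", None)
        return {"action": torch.argmax(q_value).detach().cpu().numpy().item()}

head = TinyDiscreteHead()
q = torch.tensor([[0.1, 0.7, 0.2]])      # assumed shape: [batch=1, n_actions]
print(head.predict_action(q_value=q))    # -> {'action': 1}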
@@ -5,7 +5,7 @@
 Email: [email protected]
 Date: 2023-12-22 23:02:13
 LastEditor: JiangJi
-LastEditTime: 2024-06-01 18:30:15
+LastEditTime: 2024-06-02 00:11:14
 Discription:
 '''
 import copy
@@ -128,7 +128,7 @@ def get_qvalues(self):
         qvalues = []
         for _, action_layer in enumerate(self.action_layers):
             qvalues.append(action_layer.get_qvalue())
-        return qvalues[0]
+        return qvalues
 
     def get_actions(self, **kwargs):
         mode = kwargs.get('mode', 'train')
@@ -209,7 +209,7 @@ def forward(self, x):
         x = self.branch_layers(x)
         x = self.merge_layer(x)
         actor_outputs = self.action_layers(x)
-        return actor_outputs
+        return {"actor_outputs": actor_outputs}
 
     def reset_noise(self):
         ''' reset noise for noisy layers
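With forward now returning {"actor_outputs": actor_outputs}, callers unpack that key explicitly, matching the new predict_action and learn code in the policy. A toy end-to-end sketch (hypothetical TinyModel with assumed dimensions, not the repository's model) of that flow:

# Dict-valued forward: per-branch outputs are wrapped under 'actor_outputs'.
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self, obs_dim=4, action_dims=(3, 2)):
        super().__init__()
        self.heads = nn.ModuleList(nn.Linear(obs_dim, n) for n in action_dims)
    def forward(self, x):
        actor_outputs = [{"q_value": head(x)} for head in self.heads]
        return {"actor_outputs": actor_outputs}

model = TinyModel()
state = torch.randn(1, 4)
actor_outputs = model(state)["actor_outputs"]           # same unpacking as predict_action/learn
actions = [int(torch.argmax(out["q_value"])) for out in actor_outputs]
print(actions)

Wrapping the outputs in a dict keeps the model's return signature uniform across branches, so downstream code can index each head's 'q_value' without special cases.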