[feat_dev] try to use independ_actor
Commit 5568db5 (1 parent: 2aefa9d)
Showing 10 changed files with 112 additions and 161 deletions.
@@ -5,12 +5,12 @@
 Email: [email protected]
 Date: 2023-02-20 21:53:39
 LastEditor: JiangJi
-LastEditTime: 2024-06-03 13:38:24
+LastEditTime: 2024-06-11 23:34:11
 Discription:
 '''
 class AlgoConfig(object):
     def __init__(self):
-        self.independ_actor = True # whether to use independent actor
+        self.independ_actor = False # whether to use independent actor
         # whether actor and critic share the same optimizer
         self.ppo_type = 'clip' # clip or kl
         self.eps_clip = 0.2 # clip parameter for PPO

@@ -21,6 +21,7 @@ def __init__(self):
         self.kl_beta = 1.5 # beta for KL penalty, 1.5 is the default value in the paper
         self.kl_alpha = 2 # alpha for KL penalty, 2 is the default value in the paper
         self.action_type_list = "continuous" # continuous action space
+        self.return_form = 'mc' # 'mc' or 'td' or 'gae'
         self.gamma = 0.99 # discount factor
         self.k_epochs = 4 # update policy for K epochs
         self.lr = 0.0001 # for shared optimizer
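Both knobs touched in this config are read defensively downstream with getattr, so presets that predate them keep working. A minimal sketch of that read pattern, assuming only the field names visible in this diff (everything else is illustrative):

# Illustrative sketch: how the new AlgoConfig fields are read downstream (names taken from this diff only).
class AlgoConfig:
    def __init__(self):
        self.independ_actor = False  # use separate actor/critic networks (and learning rates)
        self.return_form = 'mc'      # 'mc', 'td' or 'gae'

cfg = AlgoConfig()
use_independent_actor = getattr(cfg, 'independ_actor', False)  # same defensive read as in the policy/network code
return_form = getattr(cfg, 'return_form', 'mc').lower()
print(use_independent_actor, return_form)  # False mc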
@@ -5,7 +5,7 @@
 Email: [email protected]
 Date: 2023-05-17 01:08:36
 LastEditor: JiangJi
-LastEditTime: 2024-06-05 14:32:55
+LastEditTime: 2024-06-11 19:59:16
 Discription:
 '''
 import numpy as np

@@ -18,7 +18,7 @@ def __init__(self, cfg):
         self.gae_lambda = getattr(self.cfg, 'gae_lambda', 0.95)
         self.gamma = getattr(self.cfg, 'gamma', 0.95)
         self.batch_exps = []

     def handle_exps_after_interact(self, exps):
         exp_len = self._get_exp_len(exps)
         next_value = exps[-1].value

@@ -65,8 +65,14 @@ def _handle_exps_before_train(self, exps: list):
         log_probs = [exp.log_prob.detach().cpu().numpy().item() for exp in exps]
         # log_probs = torch.cat(log_probs, dim=0).detach() # [batch_size,1]
         # log_probs = torch.tensor(log_probs, dtype = torch.float32, device = self.cfg.device).unsqueeze(dim=1)

-        returns = np.array([exp.return_mc_normed for exp in exps])
+        if self.cfg.return_form.lower() == 'mc':
+            returns = np.array([exp.return_mc_normed for exp in exps])
+        elif self.cfg.return_form.lower() == 'td':
+            returns = np.array([exp.normed_return_td for exp in exps])
+        elif self.cfg.return_form.lower() == 'gae':
+            returns = np.array([exp.normed_return_gae for exp in exps])
+        else:
+            raise NotImplementedError("return_form not implemented")
         # returns = torch.tensor(returns, dtype = torch.float32, device = self.cfg.device).unsqueeze(dim=1)
         self.data_after_train.update({'log_probs': log_probs, 'returns': returns})
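The three return forms selected above are the standard Monte-Carlo return, the one-step TD target, and the GAE-based return. The per-experience attributes (return_mc_normed, normed_return_td, normed_return_gae) are computed elsewhere in the handler, so the function below is only a generic, self-contained sketch of how such returns are usually derived from rewards and value estimates, not the project's exact code:

import numpy as np

def compute_returns(rewards, values, next_value, dones, gamma=0.99, gae_lambda=0.95, form='mc'):
    ''' Generic sketch of 'mc' / 'td' / 'gae' returns (illustrative only). '''
    rewards, values, dones = map(np.asarray, (rewards, values, dones))
    T = len(rewards)
    returns = np.zeros(T, dtype=np.float32)
    if form == 'mc':    # Monte-Carlo: discounted sum of future rewards
        running = 0.0
        for t in reversed(range(T)):
            running = rewards[t] + gamma * running * (1.0 - dones[t])
            returns[t] = running
    elif form == 'td':  # one-step TD target: r_t + gamma * V(s_{t+1})
        next_values = np.append(values[1:], next_value)
        returns = rewards + gamma * next_values * (1.0 - dones)
    elif form == 'gae': # GAE: advantage recursion, then return = advantage + value
        next_values = np.append(values[1:], next_value)
        deltas = rewards + gamma * next_values * (1.0 - dones) - values
        adv = 0.0
        for t in reversed(range(T)):
            adv = deltas[t] + gamma * gae_lambda * adv * (1.0 - dones[t])
            returns[t] = adv + values[t]
    else:
        raise NotImplementedError("return_form not implemented")
    return returns

# Example: three steps of a toy rollout, episode not finished at the last step.
print(compute_returns([1.0, 1.0, 1.0], [0.5, 0.5, 0.5], 0.5, [0, 0, 0], form='gae'))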
@@ -5,7 +5,7 @@
 Email: [email protected]
 Date: 2023-12-22 23:02:13
 LastEditor: JiangJi
-LastEditTime: 2024-06-05 14:33:20
+LastEditTime: 2024-06-11 23:38:05
 Discription:
 '''
 import torch

@@ -63,32 +63,34 @@ def create_model(self):
         self.model = ActorCriticNetwork(self.cfg, self.state_size_list).to(self.device)

     def create_optimizer(self):
-        self.optimizer = optim.Adam(self.model.parameters(), lr = self.cfg.lr)
+        if getattr(self.cfg, 'independ_actor', False):
+            self.optimizer = optim.Adam([{'params': self.model.actor.parameters(), 'lr': self.cfg.actor_lr},
+                                         {'params': self.model.critic.parameters(), 'lr': self.cfg.critic_lr}])
+        else:
+            self.optimizer = optim.Adam(self.model.parameters(), lr = self.cfg.lr)

     def update_policy_transition(self):
         self.policy_transition = {'value': self.value.detach().cpu().numpy().item(), 'log_prob': self.log_prob}

     def sample_action(self, state, **kwargs):
         state = torch.tensor(np.array(state), device=self.device, dtype=torch.float32)
         # single state shape must be [batch_size, state_dim]
-        if state.dim() == 1:
-            state = state.unsqueeze(dim=0)
+        if state.dim() == 1: state = state.unsqueeze(dim=0)
         model_outputs = self.model(state)
         self.value = model_outputs['value']
         actor_outputs = model_outputs['actor_outputs']
-        actions, self.log_prob = self.model.action_layers.get_actions_and_log_probs(mode = 'sample', actor_outputs = actor_outputs)
+        actions, self.log_prob = self.model.get_actions_and_log_probs(mode = 'sample', actor_outputs = actor_outputs)
         self.update_policy_transition()
         return actions

     @torch.no_grad()
     def predict_action(self, state, **kwargs):
         state = torch.tensor(np.array(state), device=self.device, dtype=torch.float32)
         # single state shape must be [batch_size, state_dim]
-        if state.dim() == 1:
-            state = state.unsqueeze(dim=0)
+        if state.dim() == 1: state = state.unsqueeze(dim=0)
         model_outputs = self.model(state)
         actor_outputs = model_outputs['actor_outputs']
-        actions = self.model.action_layers.get_actions(mode = 'predict', actor_outputs = actor_outputs)
+        actions = self.model.get_actions(mode = 'predict', actor_outputs = actor_outputs)
         return actions

     def prepare_data_before_learn(self, **kwargs):

@@ -120,9 +122,9 @@ def learn(self, **kwargs):
         model_outputs = self.model(old_states)
         values = model_outputs['value']
         actor_outputs = model_outputs['actor_outputs']
-        new_log_probs = self.model.action_layers.get_log_probs_action(actor_outputs, old_actions)
+        new_log_probs = self.model.get_log_probs_action(actor_outputs, old_actions)
         # new_log_probs = self.model.action_layers.get_log_probs_action(old_actions)
-        entropy_mean = self.model.action_layers.get_mean_entropy(actor_outputs)
+        entropy_mean = self.model.get_mean_entropy(actor_outputs)
         advantages = returns - values.detach() # shape:[batch_size,1]
         # get action probabilities
         # compute ratio (pi_theta / pi_theta__old):
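The create_optimizer change above relies on PyTorch parameter groups, so a single optimizer can train the actor and critic with different learning rates. A standalone sketch of that pattern (the two modules and the rates are placeholders; only the param-group call mirrors the diff):

import torch.nn as nn
import torch.optim as optim

# Placeholder modules standing in for self.model.actor and self.model.critic.
actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)

# One optimizer, two parameter groups, two learning rates (actor_lr / critic_lr in the real config).
optimizer = optim.Adam([
    {'params': actor.parameters(), 'lr': 3e-4},
    {'params': critic.parameters(), 'lr': 1e-3},
])
print([group['lr'] for group in optimizer.param_groups])  # [0.0003, 0.001]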
@@ -5,7 +5,7 @@
 Email: [email protected]
 Date: 2023-12-02 15:02:30
 LastEditor: JiangJi
-LastEditTime: 2024-06-05 14:17:47
+LastEditTime: 2024-06-11 20:12:47
 Discription:
 '''
 import torch

@@ -32,6 +32,14 @@ def _get_exp_len(self, exps, max_step: int = 1):
             exp_len = exp_len - max_step
         return exp_len

+    def get_training_data(self):
+        ''' get training data
+        '''
+        exps = self.buffer.sample(sequential=True)
+        if exps is not None:
+            self._handle_exps_before_train(exps)
+        return self.data_after_train
+
     def handle_exps_after_interact(self, exps: list) -> list:
         ''' handle exps after interact
         '''

@@ -40,7 +48,6 @@ def handle_exps_after_interact(self, exps: list) -> list:
     def add_exps(self, exps):
         exps = self.handle_exps_after_interact(exps)
         self.buffer.push(exps)
-

     def get_training_data(self):
         ''' get training data
@@ -5,7 +5,7 @@
 Email: [email protected]
 Date: 2023-12-22 23:02:13
 LastEditor: JiangJi
-LastEditTime: 2024-06-02 10:51:15
+LastEditTime: 2024-06-11 23:41:41
 Discription:
 '''
 import copy

@@ -218,26 +218,7 @@ def reset_noise(self):
         self.branch_layers.reset_noise()
         self.merge_layer.reset_noise()

-class ActorCriticNetwork(BaseNework):
-    ''' Value network, for policy-based methods, in which the branch_layers and critic share the same network
-    '''
-    def __init__(self, cfg: MergedConfig, input_size_list: list) -> None:
-        super(ActorCriticNetwork, self).__init__(cfg, input_size_list)
-        self.action_type_list = self.cfg.action_type_list
-        self.create_graph()
-
-    def create_graph(self):
-        self.branch_layers = BranchLayers(self.cfg.branch_layers, self.input_size_list)
-        self.merge_layer = MergeLayer(self.cfg.merge_layers, self.branch_layers.output_size_list)
-        self.value_layer, _ = create_layer(self.merge_layer.output_size, LayerConfig(layer_type='linear', layer_size=[1], activation='none'))
-        self.action_layers = ActionLayers(self.cfg, self.merge_layer.output_size,)
-
-    def forward(self, x, pre_legal_actions=None):
-        x = self.branch_layers(x)
-        x = self.merge_layer(x)
-        value = self.value_layer(x)
-        actor_outputs = self.action_layers(x, pre_legal_actions = pre_legal_actions)
-        return {'value': value, 'actor_outputs': actor_outputs}
-
 class ActorNetwork(BaseNework):
     def __init__(self, cfg: MergedConfig, input_size_list) -> None:

@@ -272,48 +253,58 @@ def forward(self, x):
         x = self.merge_layer(x)
         value = self.value_layer(x)
         return value

 if __name__ == "__main__":
     # test:export PYTHONPATH=./:$PYTHONPATH
     import torch
     from joyrl.framework.config import MergedConfig
     import gymnasium as gym
     cfg = MergedConfig()
     state_size = [[None, 4], [None, 4]]
     cfg.n_actions = 2
     cfg.continuous = False
     cfg.min_policy = 0
     cfg.branch_layers = [
         [
             {'layer_type': 'linear', 'layer_size': [64], 'activation': 'ReLU'},
             {'layer_type': 'linear', 'layer_size': [64], 'activation': 'ReLU'},
         ],
         # [
         #     {'layer_type': 'linear', 'layer_size': [64], 'activation': 'ReLU'},
         #     {'layer_type': 'linear', 'layer_size': [64], 'activation': 'ReLU'},
         # ],
     ]
     cfg.merge_layers = [
         {'layer_type': 'linear', 'layer_size': [2], 'activation': 'ReLU'},
         {'layer_type': 'linear', 'layer_size': [2], 'activation': 'ReLU'},
     ]
     cfg.value_layers = [
         {'layer_type': 'embed', 'n_embeddings': 10, 'embedding_dim': 32, 'activation': 'none'},
         {'layer_type': 'Linear', 'layer_size': [64], 'activation': 'ReLU'},
         {'layer_type': 'Linear', 'layer_size': [64], 'activation': 'ReLU'},
     ]
     cfg.actor_layers = [
         {'layer_type': 'linear', 'layer_size': [256], 'activation': 'ReLU'},
         {'layer_type': 'linear', 'layer_size': [256], 'activation': 'ReLU'},
     ]
     action_space = gym.spaces.Discrete(2)
     model = QNetwork(cfg, state_size, [action_space.n])
     x = [torch.tensor([[ 0.0012, 0.0450, -0.0356, 0.0449]]), torch.tensor([[ 0.0012, 0.0450, -0.0356, 0.0449]])]
     x = model(x)
     print(x)
     # value_net = QNetwork(cfg, state_dim, cfg.n_actions)
     # print(value_net)
     # x = torch.tensor([36])
     # print(x.shape)
     # print(value_net(x))

+class ActorCriticNetwork(BaseNework):
+    ''' Value network, for policy-based methods, in which the branch_layers and critic share the same network
+    '''
+    def __init__(self, cfg: MergedConfig, input_size_list: list) -> None:
+        super(ActorCriticNetwork, self).__init__(cfg, input_size_list)
+        self.action_type_list = self.cfg.action_type_list
+        self.create_graph()
+
+    def create_graph(self):
+        if getattr(self.cfg, 'independ_actor', False):
+            self.actor = ActorNetwork(self.cfg, self.input_size_list)
+            self.critic = CriticNetwork(self.cfg, self.input_size_list)
+        else:
+            self.branch_layers = BranchLayers(self.cfg.branch_layers, self.input_size_list)
+            self.merge_layer = MergeLayer(self.cfg.merge_layers, self.branch_layers.output_size_list)
+            self.value_layer, _ = create_layer(self.merge_layer.output_size, LayerConfig(layer_type='linear', layer_size=[1], activation='none'))
+            self.action_layers = ActionLayers(self.cfg, self.merge_layer.output_size,)
+
+    def forward(self, x, pre_legal_actions=None):
+        if getattr(self.cfg, 'independ_actor', False):
+            actor_outputs = self.actor(x, pre_legal_actions)
+            value = self.critic(x)
+            return {'value': value, 'actor_outputs': actor_outputs}
+        else:
+            x = self.branch_layers(x)
+            x = self.merge_layer(x)
+            value = self.value_layer(x)
+            actor_outputs = self.action_layers(x, pre_legal_actions = pre_legal_actions)
+            return {'value': value, 'actor_outputs': actor_outputs}
+
+    def get_actions_and_log_probs(self, **kwargs):
+        if getattr(self.cfg, 'independ_actor', False):
+            return self.actor.action_layers.get_actions_and_log_probs(**kwargs)
+        else:
+            return self.action_layers.get_actions_and_log_probs(**kwargs)
+
+    def get_log_probs_action(self, actor_outputs, actions):
+        if getattr(self.cfg, 'independ_actor', False):
+            return self.actor.action_layers.get_log_probs_action(actor_outputs, actions)
+        else:
+            return self.action_layers.get_log_probs_action(actor_outputs, actions)
+
+    def get_mean_entropy(self, actor_outputs):
+        if getattr(self.cfg, 'independ_actor', False):
+            return self.actor.action_layers.get_mean_entropy(actor_outputs)
+        else:
+            return self.action_layers.get_mean_entropy(actor_outputs)
+
+    def get_actions(self, **kwargs):
+        if getattr(self.cfg, 'independ_actor', False):
+            return self.actor.action_layers.get_actions(**kwargs)
+        else:
+            return self.action_layers.get_actions(**kwargs)
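The relocated ActorCriticNetwork above now hides two layouts behind one interface: a shared branch/merge trunk with a value head, or fully independent actor and critic networks when independ_actor is set. A stripped-down PyTorch sketch of that design choice (layer sizes and module names here are invented for illustration, not taken from the project):

import torch
import torch.nn as nn

class TinyActorCritic(nn.Module):
    ''' Illustrative only: shared trunk vs. independent actor/critic behind one forward(). '''
    def __init__(self, state_dim=4, n_actions=2, independ_actor=False):
        super().__init__()
        self.independ_actor = independ_actor
        if independ_actor:
            # two fully separate networks, as in the independ_actor branch of the diff
            self.actor = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, n_actions))
            self.critic = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, 1))
        else:
            # one shared trunk feeding both an action head and a value head
            self.trunk = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU())
            self.action_head = nn.Linear(64, n_actions)
            self.value_head = nn.Linear(64, 1)

    def forward(self, x):
        if self.independ_actor:
            return {'value': self.critic(x), 'actor_outputs': self.actor(x)}
        h = self.trunk(x)
        return {'value': self.value_head(h), 'actor_outputs': self.action_head(h)}

x = torch.randn(8, 4)
for flag in (False, True):
    out = TinyActorCritic(independ_actor=flag)(x)
    print(flag, out['value'].shape, out['actor_outputs'].shape)  # [8, 1] and [8, 2] in both modes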
@@ -5,7 +5,7 @@
 Email: [email protected]
 Date: 2023-12-22 23:02:13
 LastEditor: JiangJi
-LastEditTime: 2024-06-11 13:46:02
+LastEditTime: 2024-06-11 19:52:29
 Discription:
 '''
 import time

@@ -84,7 +84,7 @@ def _get_training_data(self):
                 break
             if self.cfg.is_learner_async:
                 break
-            if time.time() - get_training_data_time >= 0.05:
+            if time.time() - get_training_data_time >= 0.02:
                 # exec_method(self.logger, 'warning', 'remote', "[Collector._get_training_data] get training data timeout!")
                 get_training_data_time = time.time()
                 break
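The collector change only tightens the polling deadline before giving up on a batch of training data, from 0.05 s to 0.02 s, so the collector loop blocks for less time when the learner has nothing ready. A generic sketch of that bounded-polling pattern (the queue is a stand-in, not the project's actual data channel):

import time
import queue

def poll_training_data(q, deadline=0.02):
    ''' Poll until data arrives or the deadline passes; illustrative pattern only. '''
    start = time.time()
    while True:
        try:
            return q.get_nowait()
        except queue.Empty:
            if time.time() - start >= deadline:
                return None  # timed out; the caller moves on and retries on the next iteration
            time.sleep(0.001)

q = queue.Queue()
print(poll_training_data(q))  # None after roughly 0.02 s
q.put({'returns': [1.0]})
print(poll_training_data(q))  # {'returns': [1.0]}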
presets/ClassControl/CartPole-v1/CartPole-v1_PPO-KL_Test.yaml: file deleted (0 additions, 32 deletions)
presets/ClassControl/CartPole-v1/CartPole-v1_PPO-KL_Train.yaml: file deleted (0 additions, 32 deletions)