
Commit

[feat_dev] try to use independ_actor
johnjim0816 committed Jun 11, 2024
1 parent 2aefa9d commit 5568db5
Showing 10 changed files with 112 additions and 161 deletions.
5 changes: 3 additions & 2 deletions joyrl/algos/PPO/config.py
@@ -5,12 +5,12 @@
Email: [email protected]
Date: 2023-02-20 21:53:39
LastEditor: JiangJi
-LastEditTime: 2024-06-03 13:38:24
+LastEditTime: 2024-06-11 23:34:11
Discription:
'''
class AlgoConfig(object):
    def __init__(self):
-        self.independ_actor = True # whether to use independent actor
+        self.independ_actor = False # whether to use independent actor
        # whether actor and critic share the same optimizer
        self.ppo_type = 'clip' # clip or kl
        self.eps_clip = 0.2 # clip parameter for PPO
@@ -21,6 +21,7 @@ def __init__(self):
        self.kl_beta = 1.5 # beta for KL penalty, 1.5 is the default value in the paper
        self.kl_alpha = 2 # alpha for KL penalty, 2 is the default value in the paper
        self.action_type_list = "continuous" # continuous action space
+        self.return_form = 'mc' # 'mc' or 'td' or 'gae'
        self.gamma = 0.99 # discount factor
        self.k_epochs = 4 # update policy for K epochs
        self.lr = 0.0001 # for shared optimizer
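
A minimal usage sketch (hypothetical, not part of this commit) showing how the two new config fields could be overridden in code, the same way a YAML preset would set them:

from joyrl.algos.PPO.config import AlgoConfig

cfg = AlgoConfig()
cfg.independ_actor = True   # use separate actor/critic sub-networks with per-network learning rates
cfg.return_form = 'gae'     # compute training returns with GAE instead of Monte Carlo
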
14 changes: 10 additions & 4 deletions joyrl/algos/PPO/data_handler.py
@@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-05-17 01:08:36
LastEditor: JiangJi
-LastEditTime: 2024-06-05 14:32:55
+LastEditTime: 2024-06-11 19:59:16
Discription:
'''
import numpy as np
@@ -18,7 +18,7 @@ def __init__(self, cfg):
        self.gae_lambda = getattr(self.cfg, 'gae_lambda', 0.95)
        self.gamma = getattr(self.cfg, 'gamma', 0.95)
        self.batch_exps = []

    def handle_exps_after_interact(self, exps):
        exp_len = self._get_exp_len(exps)
        next_value = exps[-1].value
@@ -65,8 +65,14 @@ def _handle_exps_before_train(self, exps: list):
        log_probs = [exp.log_prob.detach().cpu().numpy().item() for exp in exps]
        # log_probs = torch.cat(log_probs, dim=0).detach() # [batch_size,1]
        # log_probs = torch.tensor(log_probs, dtype = torch.float32, device = self.cfg.device).unsqueeze(dim=1)

-        returns = np.array([exp.return_mc_normed for exp in exps])
+        if self.cfg.return_form.lower() == 'mc':
+            returns = np.array([exp.return_mc_normed for exp in exps])
+        elif self.cfg.return_form.lower() == 'td':
+            returns = np.array([exp.normed_return_td for exp in exps])
+        elif self.cfg.return_form.lower() == 'gae':
+            returns = np.array([exp.normed_return_gae for exp in exps])
+        else:
+            raise NotImplementedError("return_form not implemented")
        # returns = torch.tensor(returns, dtype = torch.float32, device = self.cfg.device).unsqueeze(dim=1)
        self.data_after_train.update({'log_probs': log_probs, 'returns': returns})

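
For reference, a generic sketch of how a normalized GAE-style return (the 'gae' branch above) is typically computed. This is a standalone illustration, not joyrl's exact implementation; the attributes exp.normed_return_td / exp.normed_return_gae are assumed to be filled elsewhere (e.g. in handle_exps_after_interact).

import numpy as np

def normed_gae_returns(rewards, values, dones, next_value, gamma=0.99, gae_lambda=0.95):
    ''' Generic GAE: return[t] = advantage[t] + value[t], then normalize. '''
    advantages = np.zeros(len(rewards), dtype=np.float32)
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        next_v = next_value if t == len(rewards) - 1 else values[t + 1]
        not_done = 1.0 - float(dones[t])
        delta = rewards[t] + gamma * next_v * not_done - values[t]
        last_gae = delta + gamma * gae_lambda * not_done * last_gae
        advantages[t] = last_gae
    returns = advantages + np.asarray(values, dtype=np.float32)
    return (returns - returns.mean()) / (returns.std() + 1e-8)
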
22 changes: 12 additions & 10 deletions joyrl/algos/PPO/policy.py
@@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-22 23:02:13
LastEditor: JiangJi
-LastEditTime: 2024-06-05 14:33:20
+LastEditTime: 2024-06-11 23:38:05
Discription:
'''
import torch
@@ -63,32 +63,34 @@ def create_model(self):
        self.model = ActorCriticNetwork(self.cfg, self.state_size_list).to(self.device)

    def create_optimizer(self):
-        self.optimizer = optim.Adam(self.model.parameters(), lr = self.cfg.lr)
+        if getattr(self.cfg, 'independ_actor', False):
+            self.optimizer = optim.Adam([{'params': self.model.actor.parameters(), 'lr': self.cfg.actor_lr},
+                                         {'params': self.model.critic.parameters(), 'lr': self.cfg.critic_lr}])
+        else:
+            self.optimizer = optim.Adam(self.model.parameters(), lr = self.cfg.lr)

    def update_policy_transition(self):
        self.policy_transition = {'value': self.value.detach().cpu().numpy().item(), 'log_prob': self.log_prob}

    def sample_action(self, state, **kwargs):
        state = torch.tensor(np.array(state), device=self.device, dtype=torch.float32)
        # single state shape must be [batch_size, state_dim]
-        if state.dim() == 1:
-            state = state.unsqueeze(dim=0)
+        if state.dim() == 1: state = state.unsqueeze(dim=0)
        model_outputs = self.model(state)
        self.value = model_outputs['value']
        actor_outputs = model_outputs['actor_outputs']
-        actions, self.log_prob = self.model.action_layers.get_actions_and_log_probs(mode = 'sample', actor_outputs = actor_outputs)
+        actions, self.log_prob = self.model.get_actions_and_log_probs(mode = 'sample', actor_outputs = actor_outputs)
        self.update_policy_transition()
        return actions

    @torch.no_grad()
    def predict_action(self, state, **kwargs):
        state = torch.tensor(np.array(state), device=self.device, dtype=torch.float32)
        # single state shape must be [batch_size, state_dim]
-        if state.dim() == 1:
-            state = state.unsqueeze(dim=0)
+        if state.dim() == 1: state = state.unsqueeze(dim=0)
        model_outputs = self.model(state)
        actor_outputs = model_outputs['actor_outputs']
-        actions = self.model.action_layers.get_actions(mode = 'predict', actor_outputs = actor_outputs)
+        actions = self.model.get_actions(mode = 'predict', actor_outputs = actor_outputs)
        return actions

    def prepare_data_before_learn(self, **kwargs):
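
The create_optimizer change above uses PyTorch's standard parameter-group API to give the actor and critic different learning rates. A self-contained sketch with stand-in modules (the modules and learning-rate values are illustrative, not joyrl's):

import torch.nn as nn
import torch.optim as optim

actor = nn.Linear(4, 2)    # stand-in for self.model.actor
critic = nn.Linear(4, 1)   # stand-in for self.model.critic
optimizer = optim.Adam([
    {'params': actor.parameters(), 'lr': 3e-4},   # cfg.actor_lr
    {'params': critic.parameters(), 'lr': 1e-3},  # cfg.critic_lr
])
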
@@ -120,9 +122,9 @@ def learn(self, **kwargs):
        model_outputs = self.model(old_states)
        values = model_outputs['value']
        actor_outputs = model_outputs['actor_outputs']
-        new_log_probs = self.model.action_layers.get_log_probs_action(actor_outputs, old_actions)
+        new_log_probs = self.model.get_log_probs_action(actor_outputs, old_actions)
        # new_log_probs = self.model.action_layers.get_log_probs_action(old_actions)
-        entropy_mean = self.model.action_layers.get_mean_entropy(actor_outputs)
+        entropy_mean = self.model.get_mean_entropy(actor_outputs)
        advantages = returns - values.detach() # shape:[batch_size,1]
        # get action probabilities
        # compute ratio (pi_theta / pi_theta__old):
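
The trailing comment above points at PPO's ratio and clipped-surrogate step, which the collapsed remainder of learn() computes. A generic sketch of that objective (not necessarily joyrl's exact code):

import torch

def ppo_clip_loss(new_log_probs, old_log_probs, advantages, eps_clip=0.2):
    ratio = torch.exp(new_log_probs - old_log_probs)   # pi_theta / pi_theta_old, computed in log space
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) * advantages
    return -torch.min(surr1, surr2).mean()             # maximize the clipped surrogate
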
11 changes: 9 additions & 2 deletions joyrl/algos/base/data_handler.py
@@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-02 15:02:30
LastEditor: JiangJi
-LastEditTime: 2024-06-05 14:17:47
+LastEditTime: 2024-06-11 20:12:47
Discription:
'''
import torch
@@ -32,6 +32,14 @@ def _get_exp_len(self, exps, max_step: int = 1):
            exp_len = exp_len - max_step
        return exp_len

+    def get_training_data(self):
+        ''' get training data
+        '''
+        exps = self.buffer.sample(sequential=True)
+        if exps is not None:
+            self._handle_exps_before_train(exps)
+        return self.data_after_train
+
    def handle_exps_after_interact(self, exps: list) -> list:
        ''' handle exps after interact
        '''
@@ -40,7 +48,6 @@ def handle_exps_after_interact(self, exps: list) -> list:
    def add_exps(self, exps):
        exps = self.handle_exps_after_interact(exps)
        self.buffer.push(exps)
-

    def get_training_data(self):
        ''' get training data
123 changes: 57 additions & 66 deletions joyrl/algos/base/network.py
@@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-22 23:02:13
LastEditor: JiangJi
-LastEditTime: 2024-06-02 10:51:15
+LastEditTime: 2024-06-11 23:41:41
Discription:
'''
import copy
@@ -218,26 +218,7 @@ def reset_noise(self):
        self.branch_layers.reset_noise()
        self.merge_layer.reset_noise()

-class ActorCriticNetwork(BaseNework):
-    ''' Value network, for policy-based methods, in which the branch_layers and critic share the same network
-    '''
-    def __init__(self, cfg: MergedConfig, input_size_list: list) -> None:
-        super(ActorCriticNetwork, self).__init__(cfg, input_size_list)
-        self.action_type_list = self.cfg.action_type_list
-        self.create_graph()
-
-    def create_graph(self):
-        self.branch_layers = BranchLayers(self.cfg.branch_layers, self.input_size_list)
-        self.merge_layer = MergeLayer(self.cfg.merge_layers, self.branch_layers.output_size_list)
-        self.value_layer, _ = create_layer(self.merge_layer.output_size, LayerConfig(layer_type='linear', layer_size=[1], activation='none'))
-        self.action_layers = ActionLayers(self.cfg, self.merge_layer.output_size,)
-
-    def forward(self, x, pre_legal_actions=None):
-        x = self.branch_layers(x)
-        x = self.merge_layer(x)
-        value = self.value_layer(x)
-        actor_outputs = self.action_layers(x, pre_legal_actions = pre_legal_actions)
-        return {'value': value, 'actor_outputs': actor_outputs}


class ActorNetwork(BaseNework):
    def __init__(self, cfg: MergedConfig, input_size_list) -> None:
@@ -272,48 +253,58 @@ def forward(self, x):
        x = self.merge_layer(x)
        value = self.value_layer(x)
        return value
-
-if __name__ == "__main__":
-    # test:export PYTHONPATH=./:$PYTHONPATH
-    import torch
-    from joyrl.framework.config import MergedConfig
-    import gymnasium as gym
-    cfg = MergedConfig()
-    state_size = [[None, 4], [None, 4]]
-    cfg.n_actions = 2
-    cfg.continuous = False
-    cfg.min_policy = 0
-    cfg.branch_layers = [
-
-        [
-            {'layer_type': 'linear', 'layer_size': [64], 'activation': 'ReLU'},
-            {'layer_type': 'linear', 'layer_size': [64], 'activation': 'ReLU'},
-        ],
-        # [
-        #     {'layer_type': 'linear', 'layer_size': [64], 'activation': 'ReLU'},
-        #     {'layer_type': 'linear', 'layer_size': [64], 'activation': 'ReLU'},
-        # ],
-    ]
-    cfg.merge_layers = [
-        {'layer_type': 'linear', 'layer_size': [2], 'activation': 'ReLU'},
-        {'layer_type': 'linear', 'layer_size': [2], 'activation': 'ReLU'},
-    ]
-    cfg.value_layers = [
-        {'layer_type': 'embed', 'n_embeddings': 10, 'embedding_dim': 32, 'activation': 'none'},
-        {'layer_type': 'Linear', 'layer_size': [64], 'activation': 'ReLU'},
-        {'layer_type': 'Linear', 'layer_size': [64], 'activation': 'ReLU'},
-    ]
-    cfg.actor_layers = [
-        {'layer_type': 'linear', 'layer_size': [256], 'activation': 'ReLU'},
-        {'layer_type': 'linear', 'layer_size': [256], 'activation': 'ReLU'},
-    ]
-    action_space = gym.spaces.Discrete(2)
-    model = QNetwork(cfg, state_size, [action_space.n])
-    x = [torch.tensor([[ 0.0012, 0.0450, -0.0356, 0.0449]]), torch.tensor([[ 0.0012, 0.0450, -0.0356, 0.0449]])]
-    x = model(x)
-    print(x)
-    # value_net = QNetwork(cfg, state_dim, cfg.n_actions)
-    # print(value_net)
-    # x = torch.tensor([36])
-    # print(x.shape)
-    # print(value_net(x))
+
+class ActorCriticNetwork(BaseNework):
+    ''' Value network, for policy-based methods, in which the branch_layers and critic share the same network
+    '''
+    def __init__(self, cfg: MergedConfig, input_size_list: list) -> None:
+        super(ActorCriticNetwork, self).__init__(cfg, input_size_list)
+        self.action_type_list = self.cfg.action_type_list
+        self.create_graph()
+
+    def create_graph(self):
+        if getattr(self.cfg, 'independ_actor', False):
+            self.actor = ActorNetwork(self.cfg, self.input_size_list)
+            self.critic = CriticNetwork(self.cfg, self.input_size_list)
+        else:
+            self.branch_layers = BranchLayers(self.cfg.branch_layers, self.input_size_list)
+            self.merge_layer = MergeLayer(self.cfg.merge_layers, self.branch_layers.output_size_list)
+            self.value_layer, _ = create_layer(self.merge_layer.output_size, LayerConfig(layer_type='linear', layer_size=[1], activation='none'))
+            self.action_layers = ActionLayers(self.cfg, self.merge_layer.output_size,)
+
+    def forward(self, x, pre_legal_actions=None):
+        if getattr(self.cfg, 'independ_actor', False):
+            actor_outputs = self.actor(x, pre_legal_actions)
+            value = self.critic(x)
+            return {'value': value, 'actor_outputs': actor_outputs}
+        else:
+            x = self.branch_layers(x)
+            x = self.merge_layer(x)
+            value = self.value_layer(x)
+            actor_outputs = self.action_layers(x, pre_legal_actions = pre_legal_actions)
+            return {'value': value, 'actor_outputs': actor_outputs}
+
+    def get_actions_and_log_probs(self, **kwargs):
+        if getattr(self.cfg, 'independ_actor', False):
+            return self.actor.action_layers.get_actions_and_log_probs(**kwargs)
+        else:
+            return self.action_layers.get_actions_and_log_probs(**kwargs)
+
+    def get_log_probs_action(self, actor_outputs, actions):
+        if getattr(self.cfg, 'independ_actor', False):
+            return self.actor.action_layers.get_log_probs_action(actor_outputs, actions)
+        else:
+            return self.action_layers.get_log_probs_action(actor_outputs, actions)
+
+    def get_mean_entropy(self, actor_outputs):
+        if getattr(self.cfg, 'independ_actor', False):
+            return self.actor.action_layers.get_mean_entropy(actor_outputs)
+        else:
+            return self.action_layers.get_mean_entropy(actor_outputs)
+
+    def get_actions(self, **kwargs):
+        if getattr(self.cfg, 'independ_actor', False):
+            return self.actor.action_layers.get_actions(**kwargs)
+        else:
+            return self.action_layers.get_actions(**kwargs)
+
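
A minimal, generic sketch of the flag-driven dispatch pattern the new ActorCriticNetwork uses: either two independent sub-networks or a shared trunk, selected by the independ_actor flag. Class name and layer sizes below are illustrative, not joyrl's API:

import torch
import torch.nn as nn

class ToyActorCritic(nn.Module):
    def __init__(self, obs_dim: int, n_actions: int, independ_actor: bool = True):
        super().__init__()
        self.independ_actor = independ_actor
        if independ_actor:
            # fully separate actor and critic networks
            self.actor = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, n_actions))
            self.critic = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, 1))
        else:
            # shared trunk with two heads
            self.trunk = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU())
            self.policy_head = nn.Linear(64, n_actions)
            self.value_head = nn.Linear(64, 1)

    def forward(self, x):
        if self.independ_actor:
            return {'value': self.critic(x), 'actor_outputs': self.actor(x)}
        h = self.trunk(x)
        return {'value': self.value_head(h), 'actor_outputs': self.policy_head(h)}

model = ToyActorCritic(obs_dim=4, n_actions=2, independ_actor=True)
out = model(torch.zeros(1, 4))
print(out['value'].shape, out['actor_outputs'].shape)  # torch.Size([1, 1]) torch.Size([1, 2])
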
4 changes: 2 additions & 2 deletions joyrl/framework/collector.py
@@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-22 23:02:13
LastEditor: JiangJi
-LastEditTime: 2024-06-11 13:46:02
+LastEditTime: 2024-06-11 19:52:29
Discription:
'''
import time
@@ -84,7 +84,7 @@ def _get_training_data(self):
                break
            if self.cfg.is_learner_async:
                break
-            if time.time() - get_training_data_time >= 0.05:
+            if time.time() - get_training_data_time >= 0.02:
                # exec_method(self.logger, 'warning', 'remote', "[Collector._get_training_data] get training data timeout!")
                get_training_data_time = time.time()
                break
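
The tighter 0.02 s cut-off above follows a simple poll-with-deadline pattern; a generic standalone sketch (not joyrl's collector code):

import time

def poll_for_batch(fetch, timeout: float = 0.02):
    ''' Keep trying to pull a batch, but give up after a short deadline so the caller never blocks. '''
    start = time.time()
    while True:
        batch = fetch()
        if batch is not None:
            return batch
        if time.time() - start >= timeout:
            return None  # timed out; the caller retries on its next tick

# usage: batch = poll_for_batch(lambda: None)  # returns None after roughly 0.02 s
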
32 changes: 0 additions & 32 deletions presets/ClassControl/CartPole-v1/CartPole-v1_PPO-KL_Test.yaml

This file was deleted.

32 changes: 0 additions & 32 deletions presets/ClassControl/CartPole-v1/CartPole-v1_PPO-KL_Train.yaml

This file was deleted.
