ppo_mujocov2 #1107 (Open)

Wants to merge 4 commits into base: develop
6 changes: 4 additions & 2 deletions .teamcity/build.sh
@@ -49,12 +49,14 @@ function run_example_test {
python -m pip uninstall -r ./examples/DQN_variant/requirements.txt -y

python -m pip install -r ./examples/PPO/requirements_atari.txt
python examples/PPO/train.py --train_total_steps 5000 --env PongNoFrameskip-v4
python examples/PPO/atari/train.py --train_total_steps 5000 --env PongNoFrameskip-v4
python -m pip uninstall -r ./examples/PPO/requirements_atari.txt -y

xparl start --port 8010 --cpu_num 8
python -m pip install -r ./examples/PPO/requirements_mujoco.txt
python examples/PPO/train.py --train_total_steps 5000 --env HalfCheetah-v4 --continuous_action
python examples/PPO/mujoco/train.py --env 'HalfCheetah-v2' --train_total_episodes 100 --env_num 5
python -m pip uninstall -r ./examples/PPO/requirements_mujoco.txt -y
xparl stop

python -m pip install -r ./examples/SAC/requirements.txt
python examples/SAC/train.py --train_total_steps 5000 --env HalfCheetah-v4
47 changes: 26 additions & 21 deletions examples/PPO/README.md
@@ -4,15 +4,17 @@ Based on PARL, the PPO algorithm of deep reinforcement learning has been reproduced
> Paper: PPO in [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347)

### Mujoco/Atari games introduction
PARL currently supports the open-source version of Mujoco provided by DeepMind, so users do not need to download binaries of Mujoco as well as install mujoco-py and get license. For more details, please visit [Mujoco](https://github.com/deepmind/mujoco).
PARL currently supports the open-source version of Mujoco provided by DeepMind, so users do not need to download Mujoco binaries or install [mujoco-py](https://github.com/openai/mujoco-py#install-mujoco). For more details, please visit [Mujoco](https://github.com/deepmind/mujoco).

### Benchmark result
#### 1. Mujoco games results
The horizontal axis represents the number of episodes.
<p align="center">
<img src="https://github.com/benchmarking-rl/PARL-experiments/blob/master/PPO/paddle/mujoco_result.png" alt="mujoco-result"/>
</p>

#### 2. Atari games results
The horizontal axis represents the number of steps.
<p align="center">
<img src="https://github.com/benchmarking-rl/PARL-experiments/blob/master/PPO/paddle/atari_result.png" alt="atari-result"/>
</p>
@@ -23,29 +25,21 @@ PARL currently supports the open-source version of Mujoco provided by DeepMind,
### Mujoco-Dependencies:
+ python3.7+
+ [paddle>=2.3.1](https://github.com/PaddlePaddle/Paddle)
+ [parl>=2.1.1](https://github.com/PaddlePaddle/PARL)
+ gym>=0.26.0
+ [parl>=2.2.2](https://github.com/PaddlePaddle/PARL)
+ gym==0.18.0
+ mujoco>=2.2.2
+ mujoco-py==2.1.2.14

### Atari-Dependencies:
+ [paddle>=2.3.1](https://github.com/PaddlePaddle/Paddle)
+ [parl>=2.1.1](https://github.com/PaddlePaddle/PARL)
+ [parl>=2.2.2](https://github.com/PaddlePaddle/PARL)
+ gym==0.18.0
+ atari-py==0.2.6
+ opencv-python

### Training:

```
# To train an agent for discrete action game (Atari: PongNoFrameskip-v4 by default)
python train.py

# To train an agent for continuous action game (Mujoco)
python train.py --env 'HalfCheetah-v4' --continuous_action --train_total_steps 1000000
```

### Distributed Training
Accelerate training process by setting `xparl_addr` and `env_num > 1` when environment simulation running very slow.
### Training Mujoco (distributed)
Accelerate the training process by setting `xparl_addr` and `env_num > 1` when the environment simulation runs very slowly.
First, we can start a local cluster with 8 CPUs:

```
@@ -56,14 +50,25 @@ Note that if you have started a master before, you don't have to run the above
command. For more information about the cluster, please refer to our
[documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html).

Then we can start the distributed training by running:
Then we can start the distributed training for Mujoco games by running:

```
# To train an agent distributedly
cd mujoco

# for discrete action game (Atari games)
python train.py --env "PongNoFrameskip-v4" --env_num 8 --xparl_addr 'localhost:8010'
python train.py --env 'HalfCheetah-v2' --train_total_episodes 100000 --env_num 5
```

# for continuous action game (Mujoco games)
python train.py --env 'HalfCheetah-v4' --continuous_action --train_total_steps 1000000 --env_num 5 --xparl_addr 'localhost:8010'

### Training Atari
To train an agent for a discrete-action game (Atari: PongNoFrameskip-v4 by default):

```
cd atari

# Local training
python train.py

# Distributed training
xparl start --port 8010 --cpu_num 8
python train.py --env "PongNoFrameskip-v4" --env_num 8 --xparl_addr 'localhost:8010'
```
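For reference, below is a minimal, illustrative sketch (not taken from this PR) of the xparl pattern that the README and the new `env_utils.py` rely on: start a cluster with `xparl start --port 8010 --cpu_num 8`, connect to it, and step several remote environments in parallel by dispatching calls first and fetching the futures afterwards. The `RemoteGymEnv` class and the random actions are hypothetical stand-ins, and the sketch assumes that `@parl.remote_class(wait=False)` makes remote method calls return future objects, matching the `.get()` calls in the PR's `RemoteEnv` wrapper.

```python
import gym
import numpy as np
import parl


@parl.remote_class(wait=False)  # method calls return future objects; fetch results with .get()
class RemoteGymEnv(object):
    def __init__(self, env_name, seed=None):
        self.env = gym.make(env_name)
        if seed is not None:
            self.env.seed(seed)

    def reset(self):
        return self.env.reset()

    def step(self, action):
        return self.env.step(action)


if __name__ == '__main__':
    parl.connect('localhost:8010')  # the address passed as --xparl_addr
    env_num = 5
    envs = [RemoteGymEnv('HalfCheetah-v2', seed=i) for i in range(env_num)]

    # dispatch all resets first, then collect the futures
    obs_list = [f.get() for f in [env.reset() for env in envs]]

    # one random action per environment (HalfCheetah-v2 has a 6-dim action space)
    actions = [np.random.uniform(-1, 1, size=6).astype('float32') for _ in range(env_num)]
    futures = [env.step(a) for env, a in zip(envs, actions)]
    results = [f.get() for f in futures]  # each item is (next_obs, reward, done, info)
```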
15 changes: 6 additions & 9 deletions examples/PPO/agent.py → examples/PPO/atari/atari_agent.py
@@ -18,7 +18,7 @@
from parl.utils.scheduler import LinearDecayScheduler


class PPOAgent(parl.Agent):
class AtariAgent(parl.Agent):
""" Agent of PPO env

Args:
@@ -27,12 +27,11 @@ class PPOAgent(parl.Agent):
"""

def __init__(self, algorithm, config):
super(PPOAgent, self).__init__(algorithm)
super(AtariAgent, self).__init__(algorithm)

self.config = config
if self.config['lr_decay']:
self.lr_scheduler = LinearDecayScheduler(
self.config['initial_lr'], self.config['num_updates'])
self.lr_scheduler = LinearDecayScheduler(self.config['initial_lr'], self.config['num_updates'])

def predict(self, obs):
""" Predict action from current policy given observation
@@ -85,8 +84,7 @@ def learn(self, rollout):
else:
lr = None

minibatch_size = int(
self.config['batch_size'] // self.config['num_minibatches'])
minibatch_size = int(self.config['batch_size'] // self.config['num_minibatches'])

indexes = np.arange(self.config['batch_size'])
for epoch in range(self.config['update_epochs']):
@@ -105,9 +103,8 @@ def learn(self, rollout):
batch_return = paddle.to_tensor(batch_return)
batch_value = paddle.to_tensor(batch_value)

value_loss, action_loss, entropy_loss = self.alg.learn(
batch_obs, batch_action, batch_value, batch_return,
batch_logprob, batch_adv, lr)
value_loss, action_loss, entropy_loss = self.alg.learn(batch_obs, batch_action, batch_value,
batch_return, batch_logprob, batch_adv, lr)

value_loss_epoch += value_loss
action_loss_epoch += action_loss
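As an aside, the `learn()` hunk above updates the policy on shuffled mini-batches of the flattened rollout, once per update epoch. A standalone sketch of that indexing pattern follows, with hypothetical sizes rather than the PR's actual configuration.

```python
import numpy as np

# hypothetical sizes (e.g. a 128-step rollout across 8 environments)
batch_size, num_minibatches, update_epochs = 1024, 4, 4
minibatch_size = batch_size // num_minibatches

indexes = np.arange(batch_size)
for epoch in range(update_epochs):
    np.random.shuffle(indexes)  # reshuffle the flattened rollout every epoch
    for start in range(0, batch_size, minibatch_size):
        minibatch_idx = indexes[start:start + minibatch_size]
        # in the real agent, rollout.sample_batch(minibatch_idx) yields the
        # obs/action/logprob/advantage/return/value slices that are converted
        # to paddle tensors and fed to self.alg.learn(...)
        pass
```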
File renamed without changes.
File renamed without changes.
84 changes: 15 additions & 69 deletions examples/PPO/env_utils.py → examples/PPO/atari/env_utils.py
@@ -16,13 +16,11 @@
import gym
import numpy as np
from parl.utils import logger
from parl.env.atari_wrappers import wrap_deepmind

TEST_EPISODE = 3
# wrapper parameters for atari env
ENV_DIM = 84
OBS_FORMAT = 'NCHW'
# wrapper parameters for mujoco env
GAMMA = 0.99


class ParallelEnv(object):
@@ -39,14 +37,9 @@ def __init__(self, config=None):
base_env = LocalEnv

if config['seed']:
self.env_list = [
base_env(config['env'], config['seed'] + i)
for i in range(self.env_num)
]
self.env_list = [base_env(config['env'], config['seed'] + i) for i in range(self.env_num)]
else:
self.env_list = [
base_env(config['env']) for _ in range(self.env_num)
]
self.env_list = [base_env(config['env']) for _ in range(self.env_num)]
if hasattr(self.env_list[0], '_max_episode_steps'):
self._max_episode_steps = self.env_list[0]._max_episode_steps
else:
@@ -68,10 +61,7 @@ def reset(self):
def step(self, action_list):
next_obs_list, reward_list, done_list, info_list = [], [], [], []
if self.use_xparl:
return_list = [
self.env_list[i].step(action_list[i])
for i in range(self.env_num)
]
return_list = [self.env_list[i].step(action_list[i]) for i in range(self.env_num)]
return_list = [return_.get() for return_ in return_list]
return_list = np.array(return_list, dtype=object)

@@ -89,8 +79,7 @@ def step(self, action_list):
done = done_[i]
info = info_[i]
else:
next_obs, reward, done, info = self.env_list[i].step(
action_list[i])
next_obs, reward, done, info = self.env_list[i].step(action_list[i])

self.episode_steps_list[i] += 1
self.episode_reward_list[i] += reward
@@ -104,49 +93,26 @@ def step(self, action_list):
next_obs = self.env_list[i].reset()
self.episode_steps_list[i] = 0
self.episode_reward_list[i] = 0
if self.env_list[i].continuous_action:
# get running mean and variance of obs
self.eval_ob_rms = self.env_list[i].env.get_ob_rms()

next_obs_list.append(next_obs)
reward_list.append(reward)
done_list.append(done)
info_list.append(info)
return np.array(next_obs_list), np.array(reward_list), np.array(
done_list), np.array(info_list)
return np.array(next_obs_list), np.array(reward_list), np.array(done_list), np.array(info_list)


class LocalEnv(object):
def __init__(self, env_name, env_seed=None, test=False, ob_rms=None):
env = gym.make(env_name)

# is instance of gym.spaces.Box
if hasattr(env.action_space, 'high'):
from parl.env.mujoco_wrappers import wrap_rms
self._max_episode_steps = env._max_episode_steps
self.continuous_action = True
if test:
self.env = wrap_rms(env, GAMMA, test=True, ob_rms=ob_rms)
else:
self.env = wrap_rms(env, gamma=GAMMA)
# is instance of gym.spaces.Discrete
elif hasattr(env.action_space, 'n'):
from parl.env.atari_wrappers import wrap_deepmind
self.continuous_action = False
if hasattr(env.action_space, 'n'):
if test:
self.env = wrap_deepmind(
env,
dim=ENV_DIM,
obs_format=OBS_FORMAT,
test=True,
test_episodes=1)
self.env = wrap_deepmind(env, dim=ENV_DIM, obs_format=OBS_FORMAT, test=True, test_episodes=1)
else:
self.env = wrap_deepmind(
env, dim=ENV_DIM, obs_format=OBS_FORMAT)
self.env = wrap_deepmind(env, dim=ENV_DIM, obs_format=OBS_FORMAT)
else:
raise AssertionError(
'act_space must be instance of gym.spaces.Box or gym.spaces.Discrete'
)
raise AssertionError('act_space must be instance of gym.spaces.Discrete')

self.obs_space = self.env.observation_space
self.act_space = self.env.action_space
@@ -166,31 +132,13 @@ class RemoteEnv(object):
def __init__(self, env_name, env_seed=None, test=False, ob_rms=None):
env = gym.make(env_name)

if hasattr(env.action_space, 'high'):
from parl.env.mujoco_wrappers import wrap_rms
self._max_episode_steps = env._max_episode_steps
self.continuous_action = True
if test:
self.env = wrap_rms(env, GAMMA, test=True, ob_rms=ob_rms)
else:
self.env = wrap_rms(env, gamma=GAMMA)
elif hasattr(env.action_space, 'n'):
from parl.env.atari_wrappers import wrap_deepmind
self.continuous_action = False
if hasattr(env.action_space, 'n'):
if test:
self.env = wrap_deepmind(
env,
dim=ENV_DIM,
obs_format=OBS_FORMAT,
test=True,
test_episodes=1)
self.env = wrap_deepmind(env, dim=ENV_DIM, obs_format=OBS_FORMAT, test=True, test_episodes=1)
else:
self.env = wrap_deepmind(
env, dim=ENV_DIM, obs_format=OBS_FORMAT)
self.env = wrap_deepmind(env, dim=ENV_DIM, obs_format=OBS_FORMAT)
else:
raise AssertionError(
'act_space must be instance of gym.spaces.Box or gym.spaces.Discrete'
)
raise AssertionError('act_space must be instance of gym.spaces.Discrete')
if env_seed:
self.env.seed(env_seed)

@@ -201,6 +149,4 @@ def step(self, action):
return self.env.step(action)

def render(self):
return logger.warning(
'Can not render in remote environment, render() have been skipped.'
)
return logger.warning('Can not render in remote environment, render() have been skipped.')
15 changes: 5 additions & 10 deletions examples/PPO/storage.py → examples/PPO/atari/storage.py
@@ -17,10 +17,8 @@

class RolloutStorage():
def __init__(self, step_nums, env_num, obs_space, act_space):
self.obs = np.zeros(
(step_nums, env_num) + obs_space.shape, dtype='float32')
self.actions = np.zeros(
(step_nums, env_num) + act_space.shape, dtype='float32')
self.obs = np.zeros((step_nums, env_num) + obs_space.shape, dtype='float32')
self.actions = np.zeros((step_nums, env_num) + act_space.shape, dtype='float32')
self.logprobs = np.zeros((step_nums, env_num), dtype='float32')
self.rewards = np.zeros((step_nums, env_num), dtype='float32')
self.dones = np.zeros((step_nums, env_num), dtype='float32')
@@ -54,10 +52,8 @@ def compute_returns(self, value, done, gamma=0.99, gae_lambda=0.95):
else:
nextnonterminal = 1.0 - self.dones[t + 1]
nextvalues = self.values[t + 1]
delta = self.rewards[
t] + gamma * nextvalues * nextnonterminal - self.values[t]
advantages[
t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
delta = self.rewards[t] + gamma * nextvalues * nextnonterminal - self.values[t]
advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
returns = advantages + self.values
self.returns = returns
self.advantages = advantages
@@ -72,5 +68,4 @@ def sample_batch(self, idx):
b_returns = self.returns.reshape(-1)
b_values = self.values.reshape(-1)

return b_obs[idx], b_actions[idx], b_logprobs[idx], b_advantages[
idx], b_returns[idx], b_values[idx]
return b_obs[idx], b_actions[idx], b_logprobs[idx], b_advantages[idx], b_returns[idx], b_values[idx]
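For clarity, the `compute_returns()` hunk above implements generalized advantage estimation (GAE). Below is a self-contained sketch of the same recursion with the formulas spelled out in comments; the function name and argument layout are illustrative, not the PR's API.

```python
import numpy as np


def compute_gae(rewards, values, dones, last_value, last_done,
                gamma=0.99, gae_lambda=0.95):
    """rewards, values, dones: arrays of shape (step_nums, env_num);
    last_value, last_done: bootstrap value/done flags for the step after the rollout."""
    step_nums = rewards.shape[0]
    advantages = np.zeros_like(rewards)
    lastgaelam = 0.0
    for t in reversed(range(step_nums)):
        if t == step_nums - 1:
            nextnonterminal = 1.0 - last_done
            nextvalues = last_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            nextvalues = values[t + 1]
        # TD error: delta_t = r_t + gamma * V_{t+1} * (1 - done_{t+1}) - V_t
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        # GAE recursion: A_t = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}
        lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
        advantages[t] = lastgaelam
    # lambda-returns used as value-function targets
    returns = advantages + values
    return advantages, returns
```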