PyTorch RNN support #2172

Draft · wants to merge 4 commits into base: master
1 change: 0 additions & 1 deletion .mdlrc

This file was deleted.

54 changes: 54 additions & 0 deletions examples/torch/ppo_pendulum_gru.py
@@ -0,0 +1,54 @@
#!/usr/bin/env python3
"""This is an example to train a task with PPO algorithm (PyTorch).

Here it runs InvertedDoublePendulum-v2 environment with 100 iterations.
"""
import torch

from garage import wrap_experiment
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.torch.algos import PPO
from garage.torch.policies import GaussianGRUPolicy
from garage.torch.value_functions import GaussianMLPValueFunction
from garage.trainer import Trainer


@wrap_experiment
def ppo_pendulum_gru(ctxt=None, seed=1):
"""Train PPO with InvertedDoublePendulum-v2 environment with GRU.

Args:
ctxt (garage.experiment.ExperimentContext): The experiment
configuration used by Trainer to create the snapshotter.
seed (int): Used to seed the random number generator to produce
determinism.

"""
set_seed(seed)
env = GymEnv('InvertedDoublePendulum-v2')

trainer = Trainer(ctxt)

    policy = GaussianGRUPolicy(env.spec,
                               hidden_dim=64,
                               hidden_nonlinearity=torch.tanh,
                               output_nonlinearity=None)

value_function = GaussianMLPValueFunction(env_spec=env.spec,
hidden_sizes=(32, 32),
hidden_nonlinearity=torch.tanh,
output_nonlinearity=None)

algo = PPO(env_spec=env.spec,
policy=policy,
value_function=value_function,
discount=0.99,
center_adv=False)

trainer.setup(algo, env)
trainer.train(n_epochs=100, batch_size=10000)


ppo_pendulum_gru(seed=1)
55 changes: 55 additions & 0 deletions examples/torch/trpo_pendulum_gru.py
@@ -0,0 +1,55 @@
#!/usr/bin/env python3
"""This is an example to train a task with TRPO algorithm (PyTorch).

Here it runs InvertedDoublePendulum-v2 environment with 100 iterations.
"""
import torch

from garage import wrap_experiment
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.torch.algos import TRPO
from garage.torch.policies import GaussianGRUPolicy
from garage.torch.value_functions import GaussianMLPValueFunction
from garage.trainer import Trainer


@wrap_experiment
def trpo_pendulum_gru(ctxt=None, seed=1):
"""Train TRPO with InvertedDoublePendulum-v2 environment.

Args:
ctxt (garage.experiment.ExperimentContext): The experiment
configuration used by Trainer to create the snapshotter.
seed (int): Used to seed the random number generator to produce
determinism.

"""
set_seed(seed)
env = GymEnv('InvertedDoublePendulum-v2')

trainer = Trainer(ctxt)

    policy = GaussianGRUPolicy(env.spec,
                               hidden_dim=32,
                               hidden_nonlinearity=torch.tanh,
                               output_nonlinearity=None)

value_function = GaussianMLPValueFunction(env_spec=env.spec,
hidden_sizes=(32, 32),
hidden_nonlinearity=torch.tanh,
output_nonlinearity=None)

    algo = TRPO(env_spec=env.spec,
                policy=policy,
                value_function=value_function,
                discount=0.99,
                center_adv=False)

trainer.setup(algo, env)
trainer.train(n_epochs=100, batch_size=1024)


trpo_pendulum_gru(seed=1)
4 changes: 4 additions & 0 deletions src/garage/torch/modules/__init__.py
@@ -3,11 +3,13 @@
# isort:skip_file
from garage.torch.modules.categorical_cnn_module import CategoricalCNNModule
from garage.torch.modules.cnn_module import CNNModule
from garage.torch.modules.gaussian_gru_module import GaussianGRUModule
from garage.torch.modules.gaussian_mlp_module import (
GaussianMLPIndependentStdModule) # noqa: E501
from garage.torch.modules.gaussian_mlp_module import (
GaussianMLPTwoHeadedModule) # noqa: E501
from garage.torch.modules.gaussian_mlp_module import GaussianMLPModule
from garage.torch.modules.gru_module import GRUModule
from garage.torch.modules.mlp_module import MLPModule
from garage.torch.modules.multi_headed_mlp_module import MultiHeadedMLPModule
# DiscreteCNNModule must go after MLPModule
@@ -23,4 +25,6 @@
'GaussianMLPModule',
'GaussianMLPIndependentStdModule',
'GaussianMLPTwoHeadedModule',
'GaussianGRUModule',
'GRUModule',
]
198 changes: 198 additions & 0 deletions src/garage/torch/modules/gaussian_gru_module.py
@@ -0,0 +1,198 @@
"""Gaussian GRU Module."""
import abc

import torch
from torch import nn
from torch.distributions import Normal
import torch.nn.functional as F

from garage.torch import global_device, set_gpu_mode
from garage.torch.modules.gru_module import GRUModule


class GaussianGRUBaseModule(nn.Module):
"""Gaussian GRU Module.

A model represented by a Gaussian distribution
which is parameterized by a Gated Recurrent Unit (GRU).
"""

def __init__(
self,
input_dim,
output_dim,
            hidden_dim=32,
hidden_nonlinearity=torch.tanh,
hidden_w_init=nn.init.xavier_uniform_,
hidden_b_init=nn.init.zeros_,
recurrent_nonlinearity=torch.sigmoid,
recurrent_w_init=nn.init.xavier_uniform_,
output_nonlinearity=None,
output_w_init=nn.init.xavier_uniform_,
output_b_init=nn.init.zeros_,
learn_std=True,
init_std=1.0,
std_parameterization='exp',
layer_normalization=False,
normal_distribution_cls=Normal):
super().__init__()
self._input_dim = input_dim
self._output_dim = output_dim
self._hidden_dim = hidden_dim
self._hidden_nonlinearity = hidden_nonlinearity
self._hidden_w_init = hidden_w_init
self._hidden_b_init = hidden_b_init
self._recurrent_nonlinearity = recurrent_nonlinearity
self._recurrent_w_init = recurrent_w_init
self._output_nonlinearity = output_nonlinearity
self._output_w_init = output_w_init
self._output_b_init = output_b_init
self._learn_std = learn_std
        self._std_parameterization = std_parameterization
self._layer_normalization = layer_normalization
self._norm_dist_class = normal_distribution_cls

        # Only the continuous (Gaussian) action path is currently used.
        self.continuous_action_space = True
        self.log_std_dev = nn.Parameter(
            init_std * torch.ones(self._output_dim, dtype=torch.float),
            requires_grad=self._learn_std)
        self.covariance_eye = torch.eye(int(self._output_dim)).unsqueeze(0)

init_std_param = torch.Tensor([init_std]).log()
if self._learn_std:
self._init_std = torch.nn.Parameter(init_std_param)
else:
self._init_std = init_std_param
self.register_buffer('init_std', self._init_std)

def to(self, *args, **kwargs):
"""Move the module to the specified device.

Args:
*args: args to pytorch to function.
**kwargs: keyword args to pytorch to function.

"""
super().to(*args, **kwargs)
buffers = dict(self.named_buffers())
if not isinstance(self._init_std, torch.nn.Parameter):
self._init_std = buffers['init_std']

@abc.abstractmethod
    def _get_mean_and_log_std(self, *inputs):
        """Compute the mean- and log-std-related outputs for the inputs."""

    def forward(self, *inputs, terminal=None):
        """Forward method.

        Args:
            *inputs: Input to the module.
            terminal: Terminal flags for the steps in the batch (currently
                unused).

        Returns:
            torch.distributions.Distribution: A MultivariateNormal
                distribution over actions for continuous action spaces, or a
                Categorical distribution for discrete ones.

        """
        if torch.cuda.is_available():
            set_gpu_mode(True)
        else:
            set_gpu_mode(False)
        device = global_device()

        _, _, _, _, policy_logits_out, _ = self._get_mean_and_log_std(*inputs)
        if self.continuous_action_space:
            cov_matrix = self.covariance_eye.to(device).expand(
                self._input_dim, self._output_dim,
                self._output_dim) * torch.exp(self._init_std.to(device))
            # The distribution is built on the CPU; building it on CUDA
            # triggered illegal memory access errors.
            policy_dist = (
                torch.distributions.multivariate_normal.MultivariateNormal(
                    policy_logits_out.to('cpu'), cov_matrix.to('cpu')))
        else:
            policy_dist = torch.distributions.Categorical(
                F.softmax(policy_logits_out, dim=1).to('cpu'))

        return policy_dist


class GaussianGRUModule(GaussianGRUBaseModule):
"""GaussianMLPModule that mean and std share the same network.

"""

def __init__(
self,
input_dim,
output_dim,
            hidden_dim=32,
hidden_nonlinearity=torch.tanh,
hidden_w_init=nn.init.xavier_uniform_,
hidden_b_init=nn.init.zeros_,
recurrent_nonlinearity=torch.sigmoid,
recurrent_w_init=nn.init.xavier_uniform_,
output_nonlinearity=None,
output_w_init=nn.init.xavier_uniform_,
output_b_init=nn.init.zeros_,
learn_std=True,
init_std=1.0,
std_parameterization='exp',
layer_normalization=False,
normal_distribution_cls=Normal):
        super().__init__(
            input_dim=input_dim,
            output_dim=output_dim,
            hidden_dim=hidden_dim,
            hidden_nonlinearity=hidden_nonlinearity,
            hidden_w_init=hidden_w_init,
            hidden_b_init=hidden_b_init,
            recurrent_nonlinearity=recurrent_nonlinearity,
            recurrent_w_init=recurrent_w_init,
            output_nonlinearity=output_nonlinearity,
            output_w_init=output_w_init,
            output_b_init=output_b_init,
            learn_std=learn_std,
            init_std=init_std,
            std_parameterization=std_parameterization,
            layer_normalization=layer_normalization,
            normal_distribution_cls=normal_distribution_cls)

self._mean_gru_module = GRUModule(input_dim=self._input_dim,
output_dim=self._output_dim,
hidden_dim=self._hidden_dim,
layer_dim=1)

def _get_mean_and_log_std(self, *inputs):
"""Get mean and std of Gaussian distribution given inputs.

Args:
*inputs: Input to the module.

        Returns:
            tuple: Mean outputs for the whole sequence, the per-step mean
                outputs, the (uncentered) log std for the sequence and for a
                single step, the step hidden state, and the initial hidden
                state.

"""
assert len(inputs) == 1
(mean_outputs, step_mean_outputs, step_hidden,
hidden_init_var) = self._mean_gru_module(*inputs)

uncentered_log_std = torch.zeros(inputs[0].size(-1)) + self._init_std

uncentered_step_log_std = torch.zeros(
inputs[0].size(-1)) + self._init_std

return (mean_outputs, step_mean_outputs, uncentered_log_std,
uncentered_step_log_std, step_hidden, hidden_init_var)
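For reference, a minimal construction sketch of the new GaussianGRUModule; the observation and action dimensions used here (8 and 2) are illustrative and not part of this diff:

import torch

from garage.torch.modules import GaussianGRUModule

# Illustrative dimensions only: 8-D observations, 2-D actions.
obs_dim, action_dim = 8, 2

module = GaussianGRUModule(input_dim=obs_dim,
                           output_dim=action_dim,
                           hidden_dim=32,
                           hidden_nonlinearity=torch.tanh)

# The log standard deviation is a learnable parameter with one entry per
# action dimension.
print(module.log_std_dev.shape)  # torch.Size([2])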
47 changes: 47 additions & 0 deletions src/garage/torch/modules/gru_module.py
@@ -0,0 +1,47 @@
"""GRU in Pytorch."""
import torch
from torch import nn
from torch.autograd import Variable


class GRUModule(nn.Module):
    """GRU model based on a single nn.GRUCell followed by a linear layer.

    Args:
        input_dim (int): Dimension of the network input.
        hidden_dim (int): Dimension of the hidden state.
        layer_dim (int): Number of hidden layers.
        output_dim (int): Dimension of the network output.
        bias (bool): Whether to use bias parameters.

    """

    def __init__(self,
                 input_dim,
                 hidden_dim,
                 layer_dim,
                 output_dim,
                 bias=True):
        super().__init__()
        self._hidden_dim = hidden_dim
        # Number of hidden layers.
        self._layer_dim = layer_dim
        self._gru_cell = nn.GRUCell(input_dim, hidden_dim, bias=bias)
        self._fc = nn.Linear(hidden_dim, output_dim, bias=bias)

    def forward(self, *inputs):
        """Forward method.

        Args:
            *inputs: Input tensor of shape (sequence_length, input_dim).

        Returns:
            torch.Tensor: Per-step outputs after the linear layer.
            torch.Tensor: Linear-layer output for the last step.
            torch.Tensor: Final hidden state.
            torch.Tensor: Initial hidden state.

        """
        # Add a leading batch dimension: (1, sequence_length, input_dim).
        x = inputs[0].view(-1, inputs[0].size(0), inputs[0].size(1))

        # Initialize the hidden state with zeros.
        if torch.cuda.is_available():
            h0 = torch.zeros(self._layer_dim, x.size(0),
                             self._hidden_dim).cuda()
        else:
            h0 = torch.zeros(self._layer_dim, x.size(0), self._hidden_dim)
        outs = []
        hn = h0[0, :, :]

        # Step the GRU cell through the sequence.
        for step in range(x.size(1)):
            hn = self._gru_cell(x[:, step, :], hn)
            outs.append(hn)
        out = self._fc(outs[-1].squeeze())
        outs = torch.stack(outs)  # Convert the list of tensors to a tensor.
        outs = self._fc(outs)
        return outs, out, hn, h0
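A quick shape check for the new GRUModule, again with illustrative numbers only (a single trajectory of 5 steps, 8-D observations, 2-D outputs):

import torch

from garage.torch.modules import GRUModule

gru = GRUModule(input_dim=8, hidden_dim=32, layer_dim=1, output_dim=2)

# One trajectory of 5 observations, processed step by step by the GRU cell.
obs = torch.randn(5, 8)
outs, out, hn, h0 = gru(obs)
print(outs.shape)  # per-step outputs: torch.Size([5, 1, 2])
print(out.shape)   # last-step output: torch.Size([2])
print(hn.shape)    # final hidden state: torch.Size([1, 32])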
2 changes: 2 additions & 0 deletions src/garage/torch/policies/__init__.py
@@ -6,6 +6,7 @@
DeterministicMLPPolicy)
from garage.torch.policies.discrete_qf_argmax_policy import (
DiscreteQFArgmaxPolicy)
from garage.torch.policies.gaussian_gru_policy import GaussianGRUPolicy
from garage.torch.policies.gaussian_mlp_policy import GaussianMLPPolicy
from garage.torch.policies.policy import Policy
from garage.torch.policies.tanh_gaussian_mlp_policy import (
@@ -15,6 +16,7 @@
'CategoricalCNNPolicy',
'DeterministicMLPPolicy',
'DiscreteQFArgmaxPolicy',
'GaussianGRUPolicy',
'GaussianMLPPolicy',
'Policy',
'TanhGaussianMLPPolicy',