First draft for modular Hindsight Experience Replay Transform #2667

Draft: wants to merge 3 commits into base: main

Changes from 1 commit
torchrl/envs/transforms/transforms.py: 163 additions & 0 deletions
Contributor:
maybe let's create a dedicated file for these?

Author:
Just tell me where you would like me to put these and I will do it.

@@ -40,6 +40,7 @@
    TensorDictBase,
    unravel_key,
    unravel_key_list,
    pad_sequence,
)
from tensordict.nn import dispatch, TensorDictModuleBase
from tensordict.utils import (
@@ -9264,3 +9265,165 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
            high=torch.iinfo(torch.int64).max,
        )
        return super().transform_observation_spec(observation_spec)


class HERSubGoalSampler(Transform):
    """Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] representing the subgoal indices.

    Available strategies are `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoals from the future states of the same trajectory.

    Args:
        num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4.
        subgoal_idx_key (str): The key under which the subgoal indices are stored. Defaults to "subgoal_idx".
        strategy (str): The subgoal sampling strategy, either "last" or "future". Defaults to "future".
    """
    def __init__(
        self,
        num_samples: int = 4,
        subgoal_idx_key: str = "subgoal_idx",
        strategy: str = "future",
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.num_samples = num_samples
        self.subgoal_idx_key = subgoal_idx_key
        self.strategy = strategy

    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
        if len(trajectories.shape) == 1:
            trajectories = trajectories.unsqueeze(0)

        batch_size, trajectory_len = trajectories.shape
Contributor:
maybe

Suggested change
batch_size, trajectory_len = trajectories.shape
*batch_size, trajectory_len = trajectories.shape

to account for batch sizes with more than one dim

Author:

At the moment I assume that we have a single trajectory or a batch of trajectories [b, t]. I am not sure what other cases there may be, but we can think about it.


        if self.strategy == "last":
            return TensorDict({"subgoal_idx": torch.full((batch_size, 1), -1)}, batch_size=batch_size)

        else:
            subgoal_idxs = []
            for i in range(batch_size):
Contributor:
I guess

Suggested change
for i in range(batch_size):
for i in range(batch_size.numel()):

for batch_size with more than one dim

                subgoal_idxs.append(
                    TensorDict(
                        {"subgoal_idx": (torch.randperm(trajectory_len - 2) + 1)[: self.num_samples]},
                        batch_size=torch.Size(),
                    )
                )
            return pad_sequence(subgoal_idxs, pad_dim=0, return_mask=True)
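
For illustration, here is roughly what the `future` strategy above produces (a hypothetical run; the index values are made up and assume a TensorDict batch of 2 trajectories of length 6, so valid subgoal indices lie in [1, 4]):

sampler = HERSubGoalSampler(num_samples=4, strategy="future")
out = sampler(trajectories)        # trajectories: TensorDict with batch_size [2, 6]
out["subgoal_idx"]                 # e.g. tensor([[3, 1, 4, 2], [2, 4, 1, 3]])
out["masks", "subgoal_idx"]        # boolean mask marking the valid (non-padded) samples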


class HERSubGoalAssigner(Transform):
    """This module assigns the subgoal to the trajectory according to a given subgoal index.

    Args:
        subgoal_idx_name (str): The key to the subgoal index. Defaults to "subgoal_idx".
        subgoal_name (str): The key to assign the observation of the subgoal to the goal. Defaults to "goal".
    """
Contributor:
Suggested change
"""This module assigns the subgoal to the trajectory according to a given subgoal index.
Args:
subgoal_idx_name (str): The key to the subgoal index. Defaults to "subgoal_idx".
subgoal_name (str): The key to assign the observation of the subgoal to the goal. Defaults to "goal".
"""
"""This module assigns the subgoal to the trajectory according to a given subgoal index.
Args:
SHOULD BE achieved_goal_key??? ===> subgoal_idx_name (str): The key to the subgoal index. Defaults to "subgoal_idx".
SHOULD BE desired_goal_key?? ===> subgoal_name (str): The key to assign the observation of the subgoal to the goal. Defaults to "goal".
"""

Contributor:
I would add a .. seealso:: with other related classes.
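
For example, something along these lines could be appended to the docstring (a sketch of the reST directive; the class names are the ones introduced in this diff):

    .. seealso::
        :class:`HERSubGoalSampler`, :class:`HERSubGoalAssigner`,
        :class:`HERRewardTransform` and :class:`HindsightExperienceReplayTransform`.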

    def __init__(
        self,
        achieved_goal_key: str = "achieved_goal",
        desired_goal_key: str = "desired_goal",
    ):
        super().__init__()  # initialize the parent Transform before registering attributes
        self.achieved_goal_key = achieved_goal_key
        self.desired_goal_key = desired_goal_key

    def forward(self, trajectories: TensorDictBase, subgoals_idxs: torch.Tensor) -> TensorDictBase:
        batch_size, trajectory_len = trajectories.shape
        for i in range(batch_size):
Contributor:
I wonder if there's a vectorized version of this? The ops seem simple enough to be executed in a vectorized way

Author:

I think I had given it a shot with vmap but indexing is not well supported with vmap. Once we pin down the API, I can give it a shot again.

            subgoal = trajectories[i][subgoals_idxs[i]][self.achieved_goal_key]
            desired_goal_shape = trajectories[i][self.desired_goal_key].shape
            trajectories[i][self.desired_goal_key] = subgoal.expand(desired_goal_shape)
            trajectories[i][subgoals_idxs[i]]["next", "done"] = True
Contributor:
if we keep the loop, I'd rather have trajectories.unbind(0) than indexing every element along dim 0, it will be faster

Author:

Once we finalize the API, I will optimize things further.

            # trajectories[i][subgoals_idxs[i] + 1:]["truncated"] = True

        return trajectories
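
Related to the vectorization question above, a hypothetical loop-free sketch (assuming the achieved/desired goals are plain tensors of shape [B, T, G], the done flag is a bool tensor of shape [B, T, 1], and there is one subgoal index per trajectory; names and shapes are illustrative, not part of the PR):

import torch

def assign_subgoals_vectorized(achieved_goal, desired_goal, done, subgoal_idxs):
    # achieved_goal, desired_goal: [B, T, G]; done: [B, T, 1] (bool); subgoal_idxs: [B] (long)
    B, T, G = achieved_goal.shape
    gather_idx = subgoal_idxs.view(B, 1, 1).expand(B, 1, G)
    subgoal = achieved_goal.gather(1, gather_idx)        # achieved goal at each subgoal step, [B, 1, G]
    new_desired = subgoal.expand(B, T, G).clone()        # broadcast the relabeled goal over time
    new_done = done.clone()
    new_done[torch.arange(B), subgoal_idxs] = True       # terminate each trajectory at its subgoal step
    return new_desired, new_done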


class HERRewardTransform(Transform):
    """This module assigns the reward to the trajectory according to the new subgoal.

    Args:
        reward_name (str): The key to the reward. Defaults to "reward".
    """
    def __init__(
        self,
    ):
        super().__init__()  # no arguments yet; still initialize the parent Transform

    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
        return trajectories


class HindsightExperienceReplayTransform(Transform):
Contributor:
Don't we need to modify the specs?
Does this work with replay buffer (static data) or only envs? If the latter, we should not be using forward.

If you look at Compose, there are a bunch of things that need to be implemented when nesting transforms, like clone, cache eraser etc.

Perhaps we could inherit from Compose and rewrite forward, _apply_transform, _call, _reset etc such that the logic hold but the extra features are included automatically?

Author:
I think it's not a method that we need to attach to an environment; it's a data augmentation method. The gist of the augmentation is: given a trajectory, we sample some intermediate states and assume that they are the goal instead. Thus, we can get some positive rewards for hard cases.

"""Hindsight Experience Replay (HER) is a technique that allows to learn from failure by creating new experiences from the failed ones.
This module is a wrapper that includes the following modules:
- SubGoalSampler: Creates new trajectories by sampling future subgoals from the same trajectory.
- SubGoalAssigner: Assigns the subgoal to the trajectory according to a given subgoal index.
- RewardTransform: Assigns the reward to the trajectory according to the new subgoal.
Args:
SubGoalSampler (Transform):
SubGoalAssigner (Transform):
RewardTransform (Transform):
"""
    def __init__(
        self,
        SubGoalSampler: Transform = HERSubGoalSampler(),
        SubGoalAssigner: Transform = HERSubGoalAssigner(),
        RewardTransform: Transform = HERRewardTransform(),
        assign_subgoal_idxs: bool = False,
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.SubGoalSampler = SubGoalSampler
        self.SubGoalAssigner = SubGoalAssigner
        self.RewardTransform = RewardTransform
        self.assign_subgoal_idxs = assign_subgoal_idxs

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        augmentation_td = self.her_augmentation(tensordict)
        return torch.cat([tensordict, augmentation_td], dim=0)

    def _inv_apply_transform(self, tensordict: TensorDictBase) -> torch.Tensor:
        return self.her_augmentation(tensordict)

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        return tensordict

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        raise ValueError(self.ENV_ERR)

    def her_augmentation(self, trajectories: TensorDictBase):
        if len(trajectories.shape) == 1:
            trajectories = trajectories.unsqueeze(0)
        batch_size, trajectory_length = trajectories.shape
Contributor:
maybe

Suggested change
batch_size, trajectory_length = trajectories.shape
*batch_size, trajectory_length = trajectories.shape


        new_trajectories = trajectories.clone(True)

        # Sample subgoal indices
        subgoal_idxs = self.SubGoalSampler(new_trajectories)

        # Create new trajectories
        augmented_trajectories = []
        list_idxs = []
        for i in range(batch_size):
Contributor:
Suggested change
for i in range(batch_size):
for i in range(batch_size.numel()):

which also works with batch_size=torch.Size([])!

            idxs = subgoal_idxs[i][self.SubGoalSampler.subgoal_idx_key]

            if "masks" in subgoal_idxs.keys():
                idxs = idxs[subgoal_idxs[i]["masks", self.SubGoalSampler.subgoal_idx_key]]

            list_idxs.append(idxs.unsqueeze(-1))
            new_traj = new_trajectories[i].expand((idxs.numel(), trajectory_length)).clone(True)

            if self.assign_subgoal_idxs:
                new_traj[self.SubGoalSampler.subgoal_idx_key] = idxs.unsqueeze(-1).repeat(1, trajectory_length)

            augmented_trajectories.append(new_traj)
        augmented_trajectories = torch.cat(augmented_trajectories, dim=0)
        associated_idxs = torch.cat(list_idxs, dim=0)

        # Assign subgoals to the new trajectories
        augmented_trajectories = self.SubGoalAssigner.forward(augmented_trajectories, associated_idxs)

        # Adjust the rewards based on the new subgoals
        augmented_trajectories = self.RewardTransform.forward(augmented_trajectories)

        return augmented_trajectories
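
For context, a hypothetical usage sketch (not part of this PR): attaching the transform to a replay buffer so that the HER augmentation runs through inv() when trajectories are written. It assumes `trajectories` is a TensorDict of batch shape [B, T] carrying "achieved_goal", "desired_goal" and ("next", "done") entries; the storage size and key names are illustrative.

from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage

her = HindsightExperienceReplayTransform(
    SubGoalSampler=HERSubGoalSampler(num_samples=4, strategy="future"),
)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(100_000), transform=her)
rb.extend(trajectories)   # the inverse transform appends the relabeled trajectories before storage
sample = rb.sample(32)    # forward() is a no-op here, so sampling returns stored data unchanged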