First draft for modular Hindsight Experience Replay Transform #2667
base: main
Changes from 1 commit

@@ -40,6 +40,7 @@
    TensorDictBase,
    unravel_key,
    unravel_key_list,
    pad_sequence,
)
from tensordict.nn import dispatch, TensorDictModuleBase
from tensordict.utils import (

@@ -9264,3 +9265,165 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
            high=torch.iinfo(torch.int64).max,
        )
        return super().transform_observation_spec(observation_spec)

class HERSubGoalSampler(Transform):
    """Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] representing the subgoal indices.

    Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoals from the future states of the same trajectory.

    Args:
        num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4.
        subgoal_idx_key (str): The key to store the subgoal index. Defaults to "subgoal_idx".
        strategy (str): The subgoal sampling strategy, either "last" or "future". Defaults to "future".
    """

    def __init__(
        self,
        num_samples: int = 4,
        subgoal_idx_key: str = "subgoal_idx",
        strategy: str = "future",
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.num_samples = num_samples
        self.subgoal_idx_key = subgoal_idx_key
        self.strategy = strategy

    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
        if len(trajectories.shape) == 1:
            trajectories = trajectories.unsqueeze(0)

        batch_size, trajectory_len = trajectories.shape

Reviewer comment: maybe change this unpacking to account for batch size > 2.
Author reply: At the moment I assume that we have a single trajectory or a batch of trajectories [b, t]. I am not sure what other cases there may be, but we can think about it.

        if self.strategy == "last":
            return TensorDict(
                {"subgoal_idx": torch.full((batch_size, 1), -1)},
                batch_size=batch_size,
            )
        else:
            subgoal_idxs = []
            for i in range(batch_size):
                subgoal_idxs.append(
                    TensorDict(
                        {"subgoal_idx": (torch.randperm(trajectory_len - 2) + 1)[: self.num_samples]},
                        batch_size=torch.Size(),
                    )
                )
            return pad_sequence(subgoal_idxs, pad_dim=0, return_mask=True)

Reviewer comment (on the `for i in range(batch_size)` loop): I guess this needs a change for batch_size with more than one dim.
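
A minimal usage sketch for the sampler (not part of the diff; the keys and shapes below are illustrative assumptions):

    import torch
    from tensordict import TensorDict

    # a dummy batch of 2 trajectories, each of length 6
    traj = TensorDict(
        {
            "achieved_goal": torch.randn(2, 6, 3),
            "desired_goal": torch.zeros(2, 6, 3),
        },
        batch_size=[2, 6],
    )
    sampler = HERSubGoalSampler(num_samples=4, strategy="future")
    idx_td = sampler(traj)
    # idx_td["subgoal_idx"] holds up to 4 intermediate indices per trajectory;
    # with return_mask=True, pad_sequence also adds a "masks" entry marking
    # the valid (non-padded) samples
    print(idx_td["subgoal_idx"].shape)  # torch.Size([2, 4])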

class HERSubGoalAssigner(Transform):
    """This module assigns the subgoal to the trajectory according to a given subgoal index.

    Args:
        achieved_goal_key (str): The key of the achieved goal whose value at the subgoal index becomes the new goal. Defaults to "achieved_goal".
        desired_goal_key (str): The key of the desired goal that the subgoal observation is assigned to. Defaults to "desired_goal".
    """

    def __init__(
        self,
        achieved_goal_key: str = "achieved_goal",
        desired_goal_key: str = "desired_goal",
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.achieved_goal_key = achieved_goal_key
        self.desired_goal_key = desired_goal_key

    def forward(self, trajectories: TensorDictBase, subgoals_idxs: torch.Tensor) -> TensorDictBase:
        batch_size, trajectory_len = trajectories.shape
        for i in range(batch_size):
            subgoal = trajectories[i][subgoals_idxs[i]][self.achieved_goal_key]
            desired_goal_shape = trajectories[i][self.desired_goal_key].shape
            trajectories[i][self.desired_goal_key] = subgoal.expand(desired_goal_shape)
            trajectories[i][subgoals_idxs[i]]["next", "done"] = True
            # trajectories[i][subgoals_idxs[i]+1:]["truncated"] = True

        return trajectories

Reviewer comment (on the loop): I wonder if there's a vectorized version of this? The ops seem simple enough to be executed in a vectorized way.
Author reply: I think I had given it a shot with vmap but indexing is not well supported with vmap. Once we pin down the API, I can give it a shot again.

Reviewer comment: if we keep the loop, I'd rather have …
Author reply: Once we finalize the API, I will optimize things further.
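
Indexing-heavy loops like this can often be vectorized without vmap. One possible sketch (an illustration only, not from the PR; it assumes `subgoal_idxs` is a LongTensor of shape [B] and a boolean ("next", "done") entry of shape [B, T, 1]):

    import torch

    def assign_subgoals_vectorized(trajectories, subgoal_idxs,
                                   achieved_goal_key="achieved_goal",
                                   desired_goal_key="desired_goal"):
        batch_size, trajectory_len = trajectories.shape
        batch_range = torch.arange(batch_size)
        # pick the achieved goal at each subgoal step: [B, *goal_shape]
        subgoal = trajectories[achieved_goal_key][batch_range, subgoal_idxs]
        # broadcast it across the time dimension as the new desired goal
        trajectories[desired_goal_key] = subgoal.unsqueeze(1).expand_as(
            trajectories[desired_goal_key]
        )
        # mark the subgoal step as done ([B, T] mask, broadcast to [B, T, 1])
        step_mask = torch.arange(trajectory_len) == subgoal_idxs.view(-1, 1)
        trajectories["next", "done"] = trajectories["next", "done"] | step_mask.unsqueeze(-1)
        return trajectories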

class HERRewardTransform(Transform):
    """This module re-assigns the reward of the trajectory according to the new subgoal.

    In this first draft it is an identity placeholder: `forward` returns the trajectories unchanged.
    """

    def __init__(self):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )

    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
        return trajectories
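
Since the draft leaves rewards untouched, a concrete reward transform would be supplied per task. A hypothetical sparse goal-reaching variant (not part of this PR; the key layout and threshold are assumptions) might look like:

    import torch

    class SparseGoalRewardTransform(HERRewardTransform):
        """Hypothetical: reward 1.0 when the achieved goal lies within
        `threshold` of the relabeled desired goal, else 0.0."""

        def __init__(self, threshold: float = 0.05):
            super().__init__()
            self.threshold = threshold

        def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
            # assumes the achieved goal is also stored under ("next", "achieved_goal")
            dist = torch.linalg.vector_norm(
                trajectories["next", "achieved_goal"] - trajectories["desired_goal"],
                dim=-1,
                keepdim=True,
            )
            trajectories["next", "reward"] = (dist < self.threshold).float()
            return trajectories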

class HindsightExperienceReplayTransform(Transform):
    """Hindsight Experience Replay (HER) is a technique that allows learning from failure by creating new experiences out of the failed ones.

    This module is a wrapper that includes the following modules:
    - SubGoalSampler: Creates new trajectories by sampling future subgoals from the same trajectory.
    - SubGoalAssigner: Assigns the subgoal to the trajectory according to a given subgoal index.
    - RewardTransform: Assigns the reward to the trajectory according to the new subgoal.

    Args:
        SubGoalSampler (Transform): samples the subgoal indices for each trajectory. Defaults to HERSubGoalSampler().
        SubGoalAssigner (Transform): writes the sampled subgoal into each augmented trajectory. Defaults to HERSubGoalAssigner().
        RewardTransform (Transform): recomputes the rewards of the relabeled trajectories. Defaults to HERRewardTransform().
        assign_subgoal_idxs (bool): if True, stores the sampled subgoal index alongside each augmented trajectory. Defaults to False.
    """

Reviewer comment: Don't we need to modify the specs? If you look at … Perhaps we could inherit from Compose and rewrite …
Author reply: I think it's a method that we do not need to attach to an environment but it's a data augmentation method. The gist of the augmentation is: given a trajectory, we sample some intermediate states and assume that they are the goal instead. Thus, we can get some positive rewards for hard cases.

    def __init__(
        self,
        SubGoalSampler: Transform = HERSubGoalSampler(),
        SubGoalAssigner: Transform = HERSubGoalAssigner(),
        RewardTransform: Transform = HERRewardTransform(),
        assign_subgoal_idxs: bool = False,
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.SubGoalSampler = SubGoalSampler
        self.SubGoalAssigner = SubGoalAssigner
        self.RewardTransform = RewardTransform
        self.assign_subgoal_idxs = assign_subgoal_idxs

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        augmentation_td = self.her_augmentation(tensordict)
        return torch.cat([tensordict, augmentation_td], dim=0)

    def _inv_apply_transform(self, tensordict: TensorDictBase) -> torch.Tensor:
        return self.her_augmentation(tensordict)

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        return tensordict

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        raise ValueError(self.ENV_ERR)

    def her_augmentation(self, trajectories: TensorDictBase):
        if len(trajectories.shape) == 1:
            trajectories = trajectories.unsqueeze(0)
        batch_size, trajectory_length = trajectories.shape

Reviewer comment (on the shape unpacking): maybe the same change as suggested above.

        new_trajectories = trajectories.clone(True)

        # Sample subgoal indices
        subgoal_idxs = self.SubGoalSampler(new_trajectories)

        # Create new trajectories
        augmented_trajectories = []
        list_idxs = []
        for i in range(batch_size):

Reviewer comment (on the loop): suggested a change which also works with …

            idxs = subgoal_idxs[i][self.SubGoalSampler.subgoal_idx_key]

            if "masks" in subgoal_idxs.keys():
                idxs = idxs[subgoal_idxs[i]["masks", self.SubGoalSampler.subgoal_idx_key]]

            list_idxs.append(idxs.unsqueeze(-1))
            new_traj = new_trajectories[i].expand((idxs.numel(), trajectory_length)).clone(True)

            if self.assign_subgoal_idxs:
                new_traj[self.SubGoalSampler.subgoal_idx_key] = idxs.unsqueeze(-1).repeat(1, trajectory_length)

            augmented_trajectories.append(new_traj)
        augmented_trajectories = torch.cat(augmented_trajectories, dim=0)
        associated_idxs = torch.cat(list_idxs, dim=0)

        # Assign subgoals to the new trajectories
        augmented_trajectories = self.SubGoalAssigner.forward(augmented_trajectories, associated_idxs)

        # Adjust the rewards based on the new subgoals
        augmented_trajectories = self.RewardTransform.forward(augmented_trajectories)

        return augmented_trajectories
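
Putting the pieces together, a minimal end-to-end sketch of the augmentation path (illustrative only; the dummy keys and shapes are assumptions):

    import torch
    from tensordict import TensorDict

    B, T = 2, 6
    traj = TensorDict(
        {
            "achieved_goal": torch.randn(B, T, 3),
            "desired_goal": torch.zeros(B, T, 3),
            "next": TensorDict(
                {"done": torch.zeros(B, T, 1, dtype=torch.bool)},
                batch_size=[B, T],
            ),
        },
        batch_size=[B, T],
    )

    her = HindsightExperienceReplayTransform()
    augmented = her.her_augmentation(traj)
    # each original trajectory yields up to num_samples relabeled copies
    print(augmented.shape)  # e.g. torch.Size([8, 6]) with num_samples=4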

Reviewer comment: maybe let's create a dedicated file for these?
Author reply: Give the command on where you would like me to put these and I will do it.