Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First draft for modular Hindsight Experience Replay Transform #2667

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

dtsaras
Copy link

@dtsaras dtsaras commented Dec 19, 2024

Description

I have added the Hindsight Experience Replay Transform specifically implementing the future and last strategy as described in the paper. The transform is a combination of 3 transforms:

  • HERSubGoalSampler: It's responsible for sampling indexes for the subgoals and can be changed with another subgoal sampling method for the specific use case.
  • HERSubGoalAssigner: It's the method responsible for creating new trajectories given the subgoal indices.
  • HERRewardTransform: While it might not necessarily need to be a separate transform yet, it's the method responsible for reassigning the rewards to the newly generated trajectories.

Motivation and Context

It's a modular implementation for hindsight experience replay as requested #1819

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

Copy link

pytorch-bot bot commented Dec 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2667

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 19, 2024
@vmoens vmoens added the enhancement New feature or request label Jan 8, 2025
Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me! Thanks for working on this.

I'd love to see some tests to understand better how this all works.
I left some comments here and there, mostly about formatting and high-level design decisions. Happy to give it a more thorough technical look later once there's an example of how to run it and/or some tests to rely upon.

torchrl/envs/transforms/transforms.py Outdated Show resolved Hide resolved
torchrl/envs/transforms/transforms.py Outdated Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe let's create a dedicated file for these?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Give the command on where you would like me to put these and I will do it.

torchrl/envs/transforms/transforms.py Show resolved Hide resolved
if len(trajectories.shape) == 1:
trajectories = trajectories.unsqueeze(0)

batch_size, trajectory_len = trajectories.shape
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe

Suggested change
batch_size, trajectory_len = trajectories.shape
*batch_size, trajectory_len = trajectories.shape

to account for batch size > 2

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the moment I assume that we have a single trajectory or a batch of trajectories [b, t]. I am not sure what other cases there may be, but we can think about it.

torchrl/envs/transforms/transforms.py Outdated Show resolved Hide resolved
torchrl/envs/transforms/transforms.py Outdated Show resolved Hide resolved
def her_augmentation(self, trajectories: TensorDictBase):
if len(trajectories.shape) == 1:
trajectories = trajectories.unsqueeze(0)
batch_size, trajectory_length = trajectories.shape
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe

Suggested change
batch_size, trajectory_length = trajectories.shape
*batch_size, trajectory_length = trajectories.shape

# Create new trajectories
augmented_trajectories = []
list_idxs = []
for i in range(batch_size):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for i in range(batch_size):
for i in range(batch_size.numel()):

which also works with batch_size=torch.Size([])!

return trajectories


class HindsightExperienceReplayTransform(Transform):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need to modify the specs?
Does this work with replay buffer (static data) or only envs? If the latter, we should not be using forward.

If you look at Compose, there are a bunch of things that need to be implemented when nesting transforms, like clone, cache eraser etc.

Perhaps we could inherit from Compose and rewrite forward, _apply_transform, _call, _reset etc such that the logic hold but the extra features are included automatically?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's a method that we do not need to attach to an environment but it's a data augmentation method. The gist of the augmentation is: Given a trajectory we sample some intermediate states and assume that they are the goal instead. Thus, we can get some positive rewards for hard cases.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants