-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
First draft for modular Hindsight Experience Replay Transform #2667
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2667
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense to me! Thanks for working on this.
I'd love to see some tests to understand better how this all works.
I left some comments here and there, mostly about formatting and high-level design decisions. Happy to give it a more thorough technical look later once there's an example of how to run it and/or some tests to rely upon.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe let's create a dedicated file for these?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Give the command on where you would like me to put these and I will do it.
if len(trajectories.shape) == 1: | ||
trajectories = trajectories.unsqueeze(0) | ||
|
||
batch_size, trajectory_len = trajectories.shape |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe
batch_size, trajectory_len = trajectories.shape | |
*batch_size, trajectory_len = trajectories.shape |
to account for batch size > 2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At the moment I assume that we have a single trajectory or a batch of trajectories [b, t]. I am not sure what other cases there may be, but we can think about it.
def her_augmentation(self, trajectories: TensorDictBase): | ||
if len(trajectories.shape) == 1: | ||
trajectories = trajectories.unsqueeze(0) | ||
batch_size, trajectory_length = trajectories.shape |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe
batch_size, trajectory_length = trajectories.shape | |
*batch_size, trajectory_length = trajectories.shape |
# Create new trajectories | ||
augmented_trajectories = [] | ||
list_idxs = [] | ||
for i in range(batch_size): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for i in range(batch_size): | |
for i in range(batch_size.numel()): |
which also works with batch_size=torch.Size([])
!
return trajectories | ||
|
||
|
||
class HindsightExperienceReplayTransform(Transform): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't we need to modify the specs?
Does this work with replay buffer (static data) or only envs? If the latter, we should not be using forward
.
If you look at Compose
, there are a bunch of things that need to be implemented when nesting transforms, like clone
, cache eraser etc.
Perhaps we could inherit from Compose and rewrite forward
, _apply_transform
, _call
, _reset
etc such that the logic hold but the extra features are included automatically?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's a method that we do not need to attach to an environment but it's a data augmentation method. The gist of the augmentation is: Given a trajectory we sample some intermediate states and assume that they are the goal instead. Thus, we can get some positive rewards for hard cases.
Description
I have added the Hindsight Experience Replay Transform specifically implementing the
future
andlast
strategy as described in the paper. The transform is a combination of 3 transforms:Motivation and Context
It's a modular implementation for hindsight experience replay as requested #1819
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!