[BUG] info['_weight'] device for Importance Sampling in PER #2518
Comments
That, and we should also be able to execute this directly on device. I'll push some changes.
Just FYI, you could do this instead:
# From documentation
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler, TensorDictReplayBuffer
from tensordict import TensorDict
import torch
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(10, device=torch.device('cuda')),
                            sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
rb.add(data_0)
rb.add(data_1)
rb.update_priority(torch.tensor([0, 1]), priority=priority)
sample = rb.sample(10)
# Check devices
print(f"sample device: {sample.device}\n"
      f"sample['_weight'] device: {sample['_weight'].device}")
which will put your weights on cuda. There are two issues in patching the PRB (prioritized replay buffer) to account for the device of the storage:
# From documentation
import functools
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
from tensordict import TensorDict
import torch
device = "cuda"
# patch
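# keep a reference to the original method under a name that is not reused below
_orig_sample = PrioritizedSampler.sample
@functools.wraps(_orig_sample)
def new_sample(self, *args, **kwargs):
    out = _orig_sample(self, *args, **kwargs)
    # move every tensor returned by the sampler (indices and info) to `device`
    out = torch.utils._pytree.tree_map(lambda x: x.to(device), out)
    return out
PrioritizedSampler.sample = new_sample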
rb = ReplayBuffer(storage=LazyTensorStorage(10, device=torch.device(device)), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
# map back content on cpu
rb.append_transform(lambda x: x.to("cpu"))
priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
rb.add(data_0)
rb.add(data_1)
rb.update_priority(torch.tensor([0, 1]), priority=priority)
sample, info = rb.sample(10, return_info=True)
# Check devices
print(f"sample device: {sample.device}\n"
      f"info['_weight'] device: {info['_weight'].device}")
So to recap:
Maybe I'm missing something, but def sample(self, storage: Storage, batch_size: int) accepts the storage as an argument, thus we can query
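One way to read this suggestion, assuming the intent is to take the device from the storage argument, is sketched below. `ToyStorage` and `ToyPrioritizedSampler` are hypothetical stand-ins written for illustration only; they are not torchrl classes and do not mirror torchrl's internals.

```python
import torch


class ToyStorage:
    """Hypothetical stand-in for a storage that knows its device."""

    def __init__(self, max_size: int, device: torch.device):
        self.max_size = max_size
        self.device = device


class ToyPrioritizedSampler:
    """Hypothetical sampler; NOT torchrl's PrioritizedSampler."""

    def __init__(self, priorities: torch.Tensor, beta: float = 1.0):
        self.priorities = priorities.float()  # kept on CPU in this sketch
        self.beta = beta

    def sample(self, storage: ToyStorage, batch_size: int):
        probs = self.priorities / self.priorities.sum()
        index = torch.multinomial(probs, batch_size, replacement=True)
        # Importance-sampling weights, normalised by their maximum.
        weight = (len(probs) * probs[index]).pow(-self.beta)
        weight = weight / weight.max()
        # Query the storage for its device instead of hard-coding one.
        info = {"_weight": weight.to(storage.device)}
        return index.to(storage.device), info


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
storage = ToyStorage(10, device=device)
sampler = ToyPrioritizedSampler(torch.tensor([1.0, 3.0]))
index, info = sampler.sample(storage, batch_size=4)
print(f"info['_weight'] device: {info['_weight'].device}")
```

With this shape the sampler needs no `device` argument of its own, since the storage passed to `sample` already carries that information.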
That's a valid point. I wanted to suggest adding
Describe the bug
The device of info['_weight'] doesn't match the storage device.
To Reproduce
sample device: cuda:0
info['_weight'] device: cpu
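The mismatch matters as soon as the weights meet tensors that live on the storage device. Below is a minimal sketch of a user-side workaround, using plain torch and made-up values. `td_error` and `weights` are illustrative stand-ins, not torchrl outputs. It moves the weights to the data's device before weighting the loss:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for per-sample TD errors computed from a batch on the storage device.
td_error = torch.tensor([0.5, -1.0, 2.0], device=device)
# Stand-in for importance-sampling weights that came back on CPU.
weights = torch.tensor([1.0, 0.5, 0.25], device="cpu")

# Move the weights next to the data before combining them.
weights = weights.to(td_error.device)
loss = (weights * td_error.pow(2)).mean()
print(f"loss device: {loss.device}")
```

On a CPU-only machine the `.to(...)` call is a no-op, so the snippet runs either way.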
Expected behavior
Both should be on the same device defined in storage(..., device), as these weights are later used to compute the loss.
System info
Reason and Possible fixes
Specify device argument in samplers.py (L508):
Checklist