use devices/dtypes based on passed in tensors #68
Hi,

Thanks for putting together this repository. I'm using only the loss functions as part of another project, switching between CPU and GPU at different times, and I sometimes use half-precision training. This PR makes the loss functions use the device and dtype of the passed-in tensors rather than always assuming the GPU and `torch.float32`. Since the training code still uses `get_torch_device()` and float32 tensors, this should change behavior only when someone uses the loss functions on their own (my use case).

Creating this PR in case it's useful to others. Obviously feel free to reject if you want to keep things as is.
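For illustration, here is a minimal sketch of the pattern this PR applies. The `ExampleLoss` class and its shapes are hypothetical, not code from this repository; the point is that auxiliary tensors created inside a loss should inherit `device` and `dtype` from the tensors passed in, instead of being pinned to `get_torch_device()` and `torch.float32`:

```python
import torch
import torch.nn as nn

class ExampleLoss(nn.Module):
    """Illustrative loss that follows the device/dtype of its inputs."""

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Before: auxiliary tensors were created with a fixed device/dtype, e.g.
        #   weights = torch.ones(pred.shape[0], device=get_torch_device(),
        #                        dtype=torch.float32)
        # After: inherit both from the tensor that was passed in, so the same
        # loss works on CPU, on GPU, and under half-precision training.
        weights = torch.ones(pred.shape[0], device=pred.device, dtype=pred.dtype)
        return (weights * (pred - target) ** 2).mean()

# With this change the loss runs wherever its inputs live, e.g. half precision on CPU:
loss_fn = ExampleLoss()
pred = torch.randn(4, dtype=torch.float16)
target = torch.randn(4, dtype=torch.float16)
loss = loss_fn(pred, target)  # no device transfer or dtype cast needed
```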