Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an example on how to add a new device out of tree #36

Merged
merged 3 commits into from
Jun 3, 2022
Merged

Conversation

albanD
Copy link
Owner

@albanD albanD commented Jun 2, 2022

This is composed of:

  • An "has a" torch dispatch class that holds onto the user raw data (a numpy array here for simplicity, but it can be any python/c++ object).
  • A torch function mode to properly capture:
    • factory functions and handle the newly added device.
    • cross device functions

Fixes #35

new_device.py Outdated
@staticmethod
def __new__(cls, size, dtype, raw_data=None, requires_grad=False):
# Use a meta Tensor here to be used as the wrapper
return torch.Tensor._make_subclass(cls, torch.empty(size, dtype=dtype, device="meta"), require_grad=requires_grad)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No way to set the device alas!

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, all the C++ code that sets "device" will make sure it is one of the ones in our hardcoded list :(


def __repr__(self):
st = super().__repr__()
st = st.replace("device='meta'", "device='my_device'")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks more confusing than hell ha ha ha

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this most likely will need to be tweaked by real backends. This is just to be able to easily re-use the plain Tensor print.

new_device.py Outdated
except Exception as e:
print(e)
raise e
raise RuntimeError(f"No implementation for 'my_device' for {func}, {args}, {kwargs}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you use some of the tech from @eellison's fake tensor you should also be able to get the device property to report correctly as well

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't looked into that one yet.
The problem is that even if that tries to create a torch.device based on the result from c++, that will fail.

@ezyang
Copy link
Collaborator

ezyang commented Jun 2, 2022

cc @bdhirsh

@albanD albanD merged commit 55d6be0 into main Jun 3, 2022
@albanD albanD deleted the backend_tensor branch June 3, 2022 18:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

What's the right way to use torch dispatch?
2 participants