-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathempty_tensor.py
62 lines (50 loc) · 1.6 KB
/
empty_tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
from base_tensor import BaseTensor
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._pytree import tree_map
from utils import no_dispatch
class EmptyTensor(BaseTensor):
    """Tensor subclass that carries only metadata, not data.

    ``__new__`` builds the wrapper with ``torch.Tensor._make_wrapper_subclass``.
    It copies the size, strides, storage offset, dtype, layout, device and
    ``requires_grad`` flag of ``elem``, but does not keep a reference to
    ``elem`` itself.

    When an operator is dispatched:
    - each wrapper argument is replaced by ``torch.ones_like(...)`` of it;
    - the operator runs on those stand-ins;
    - every plain tensor in the result is wrapped back into this class.

    Only metadata is meaningful on the outputs; values are placeholders.
    """

    @staticmethod
    def __new__(cls, elem):
        # Copy metadata only; ``elem``'s data is intentionally not retained.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            layout=elem.layout,
            requires_grad=elem.requires_grad,
            device=elem.device,
        )

    def __init__(self, elem):
        # All construction happens in ``__new__``; nothing to store here.
        pass

    def __repr__(self):
        # TODO: this is wrong
        # NOTE(review): the original author flagged this repr as wrong. It
        # reports only the size (no dtype/device); confirm the intended form.
        return f"{type(self).__name__}({self.size()})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` on ones-filled stand-ins and re-wrap tensor outputs.

        Args:
            func: the operator being dispatched.
            types: tensor subclass types involved; forwarded unchanged.
            args: positional operator arguments (any pytree).
            kwargs: keyword operator arguments, or ``None``.

        Returns:
            The parent handler's result, with every plain ``torch.Tensor``
            leaf wrapped in ``cls``.
        """
        # Normalise so the pytree map and the parent handler always see a dict.
        if kwargs is None:
            kwargs = {}

        def inflate(t):
            # Swap a metadata-only wrapper for a concrete tensor of ones.
            # no_dispatch() presumably disables Python dispatch so ones_like
            # does not re-enter this handler.
            if isinstance(t, cls):
                with no_dispatch():
                    return torch.ones_like(t, device=t.device)
            return t

        def deflate(t):
            # Re-wrap concrete tensor outputs. Using ``cls`` (rather than a
            # hard-coded EmptyTensor) keeps subclasses closed under dispatch,
            # matching the ``isinstance(t, cls)`` check in ``inflate``.
            if isinstance(t, torch.Tensor) and not isinstance(t, cls):
                return cls(t)
            return t

        out = super().__torch_dispatch__(
            func, types, tree_map(inflate, args), tree_map(inflate, kwargs)
        )
        return tree_map(deflate, out)
class EmptyTensorTest(TestCase):
    """Smoke tests for metadata propagation through ``EmptyTensor``."""

    def test_basic(self):
        # Adding two metadata-only tensors should produce another
        # metadata-only tensor with the same shape and dtype as the inputs.
        x = EmptyTensor(torch.randn(4))
        y = EmptyTensor(torch.randn(4))
        r = x + y
        self.assertIsInstance(r, EmptyTensor)
        self.assertEqual(r.shape, (4,))
        self.assertEqual(r.dtype, x.dtype)
if __name__ == "__main__":
    # Executed only when this file is run as a script; hands control to
    # PyTorch's internal test runner (``run_tests``).
    run_tests()