import torch
from base_tensor import BaseTensor
from torch.fx import Interpreter, Node
from torch.testing._internal.common_utils import run_tests, TestCase
from tracer_tensor import dispatch_trace

# https://github.com/albanD/subclass_zoo/blob/33d7afe63c2a336e01eaf3e81fba085a68e3955f/bug_zoo.py#L18-L24

# How to do speculate-and-validate:
# - we need a function to run under trace (dispatch_trace)
# - the first time, run it with the normal TracerTensor
# - the second time, run it with VerifierTensor
# Recovery is not necessary.
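#
# A minimal usage sketch (hypothetical root/x/y; the tests at the bottom of
# this file are the runnable versions):
#
#   f = SpeculatingJit(root)
#   r1 = f(x, y)    # first call: run eagerly under dispatch_trace, save graph
#   r2 = f(x2, y2)  # later calls: run the Interpreter for the answer, while
#                   # replaying root on VerifierTensors to check the trace
#                   # still describes what root does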


class Verifier:
    def __init__(self, interpreter, node):
        self.node = node
        # We aren't actually going to run the interpreter; it's only here
        # so we can fetch_attr the constants captured by the graph.
        self.interpreter = interpreter
        # Maps constant tensors to their get_attr nodes.
        # TODO: not sure if there's a better way to do this
        self.constant_map = {}

    def advance(self):
        node = self.node
        self.node = node.next
        # Whenever constant tensors show up in a trace, FX records them as
        # get_attr nodes.  These nodes are not relevant when verifying
        # torch dispatch calls, but we do need to remember them so that we
        # can check the user is reusing the correct constants (see the
        # example before constant_node below).
        while node.op == "get_attr":
            self.constant_map[self.interpreter.fetch_attr(node.target)] = node
            node = self.node
            self.node = node.next
        return node
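
    # Example (hypothetical values): if the traced function closed over
    # x = torch.zeros(2), FX records a get_attr node for x; advance()
    # stashes the tensor -> node mapping so constant_node can confirm that
    # a replayed call reuses the very same constant tensor.
    #
    #   VERIFIER.constant_node(x)  # returns the get_attr node for x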
    def constant_node(self, t):
        return self.constant_map[t]


VERIFIER = None


class VerifierTensor(BaseTensor):
    @staticmethod
    def __new__(cls, elem, node):
        return super().__new__(cls, elem)

    def __init__(self, elem, node):
        self.node = node

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Verify that this call matches the next node in the recorded graph
        node = VERIFIER.advance()
        assert node.op == "call_function", node.op
        assert node.target == func

        def translate(n, v):
            if isinstance(n, Node):
                if isinstance(v, VerifierTensor):
                    assert n is v.node
                    return v
                else:
                    assert n is VERIFIER.constant_node(v)
                    # Need to translate constants to meta so that
                    # we satisfy device checks
                    return v.to("meta")
            else:
                assert n == v
                return v

        meta_args = []
        meta_kwargs = {}
        for i, n in enumerate(node.args):
            meta_args.append(translate(n, args[i]))
        for k, n in node.kwargs.items():
            meta_kwargs[k] = translate(n, kwargs[k])
        assert len(node.kwargs) == len(kwargs)

        r = super().__torch_dispatch__(func, types, tuple(meta_args), meta_kwargs)
        # For multi-output ops we would also need to advance the verifier
        # past the FX indexing (getitem) nodes; a hedged sketch of what
        # that might look like follows the tuple branch below.
        if isinstance(r, list):
            raise NotImplementedError
        elif isinstance(r, tuple):
            raise NotImplementedError
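            # A hedged sketch of what tuple handling might look like
            # (untested; it assumes each output is consumed through an
            # operator.getitem call_function node recorded in the graph):
            #
            #   outs = []
            #   for i, out in enumerate(r):
            #       getitem = VERIFIER.advance()
            #       assert getitem.op == "call_function"
            #       assert getitem.target is operator.getitem
            #       assert getitem.args == (node, i)
            #       outs.append(VerifierTensor(out, getitem))
            #   return tuple(outs)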
        else:
            return VerifierTensor(r, node)


class SpeculatingJit:
    def __init__(self, root):
        self.root = root
        self.graph = None
        self.interpreter = None

    # Identity by default; see the sketch below for how a subclass might
    # override this to rewrite the traced graph.
    def transform(self, graph):
        return graph
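
    # A hypothetical sketch (not part of the original file) of overriding
    # transform to run an FX graph pass before building the Interpreter:
    #
    #   class DCEJit(SpeculatingJit):
    #       def transform(self, graph):
    #           # graph is the GraphModule produced by dispatch_trace
    #           graph.graph.eliminate_dead_code()
    #           graph.recompile()
    #           return graph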

    def __call__(self, *args):
        if self.graph is None:
            r, self.graph = dispatch_trace(self.root, args)
            self.interpreter = Interpreter(self.transform(self.graph))
            return r
        else:
            # Assume the placeholder nodes come first.
            # TODO: there is a problem with the verifier design here, which
            # is that it is not possible to free constants that are captured
            # by the graph; this might matter for memory usage if an FX
            # transformation did weight transformation.  I think what you
            # want to do is stub out the tensors with meta "shadows" that
            # have a correspondence to getattr nodes, but it is a little
            # fiddly to implement.
            global VERIFIER
            VERIFIER = Verifier(
                Interpreter(self.graph), next(iter(self.graph.graph.nodes))
            )
            verifier_args = []
            for a in args:
                n = VERIFIER.advance()
                assert n.op == "placeholder"
                verifier_args.append(VerifierTensor(a.to("meta"), n))
            # Run the compiled graph for the real result, while replaying
            # the original function on VerifierTensors to validate the trace.
            r = self.interpreter.run(*args)
            verifier_r = self.root(*verifier_args)
            VERIFIER = None
            assert r.shape == verifier_r.shape
            assert r.dtype == verifier_r.dtype
            return r


class VerifierTensorTest(TestCase):
    def test_basic(self):
        def root(x, y):
            # TODO: x + y is annoying to debug because the exception gets
            # swallowed
            return torch.add(x, y)

        f = SpeculatingJit(root)
        r = f(torch.zeros(2), torch.zeros(2))
        self.assertEqual(r, torch.zeros(2))
        r2 = f(torch.ones(2), torch.zeros(2))
        self.assertEqual(r2, torch.ones(2))

    def test_constant(self):
        x = torch.zeros(2)

        def root(y):
            return torch.add(x, y)

        f = SpeculatingJit(root)
        r = f(torch.zeros(2))
        self.assertEqual(r, torch.zeros(2))
        r2 = f(torch.ones(2))
        self.assertEqual(r2, torch.ones(2))

    def test_validation_failure(self):
        i = 0

        def root(x, y):
            nonlocal i
            i += 1
            if i == 1:
                return torch.add(x, y)
            else:
                return torch.mul(x, y)

        f = SpeculatingJit(root)
        r = f(torch.zeros(2), torch.zeros(2))
        self.assertEqual(r, torch.zeros(2))
        self.assertRaises(AssertionError, lambda: f(torch.ones(2), torch.zeros(2)))


if __name__ == "__main__":
    run_tests()