forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnative_parse.py
407 lines (365 loc) · 18.1 KB
/
native_parse.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
from __future__ import print_function
import re
import yaml
import pprint
import sys
import copy
try:
# use faster C loader if available
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader
# [temp translations]
# We're incrementally moving from the custom func schema to the
# JIT signature schema. This will reduce overall complexity
# and increase consistency between these components. So for now we do simple
# type translations to continue to emit the legacy func schema for further
# processing by downstream tools. This helps us avoid having to prematurely
# change all downstream tools to detect these new types.
def type_argument_translations(arg):
    """Translate one JIT-schema argument into the legacy func-schema form.

    See [temp translations]. ``arg`` is a single comma-separated entry of a
    schema argument or return list, e.g. ``"Tensor(a!) self"``,
    ``"int[2] stride=1"``, or just ``"Tensor"`` for an unnamed return.

    Returns:
        A tuple ``(t, name, default, nullable, size, annotation)``:
        t          -- translated type string (e.g. ``int`` -> ``int64_t``,
                      ``int[]`` -> ``IntArrayRef``); a trailing ``?`` is kept.
        name       -- the argument name, or ``''`` when ``arg`` has none.
        default    -- the translated default (str, bool, int or float), or
                      None when ``arg`` has no ``=default`` part.
        nullable   -- True when the type contains ``?`` (``Generator?`` is
                      excluded and translated to a pointer type instead).
        size       -- N for an ``int[N]`` type, otherwise None.
        annotation -- the alias annotation of a Tensor type, e.g. ``'a!'``
                      for ``Tensor(a!)``, otherwise None.

    Raises:
        RuntimeError: for legacy spellings that must not appear in schemas
        (``int64_t``, ``int64_t?``, ``double``, ``std::array``, ``true``,
        ``false``).
    """
    type_and_name = [a.strip() for a in arg.rsplit(' ', 1)]
    name = ''
    if len(type_and_name) > 1:
        name = type_and_name[1]
    t = type_and_name[0]

    # Split on the first '=' only, so a default is never truncated.
    name_and_default = name.split('=', 1)
    name = name_and_default[0]
    default = name_and_default[1] if len(name_and_default) > 1 else None
    size = None  # Only set for int[N] arguments

    # Strip an alias annotation: "Tensor(a!)" -> type "Tensor", annotation "a!".
    annotation = None
    match = re.match(r'(Tensor.*)\((.+)\)(.*)', t)
    if match:
        t = match.group(1) + match.group(3)
        annotation = match.group(2)

    # XXX: is_nullable flag can only annotate entire type as optional type,
    # need to special case Generator? logic to make ? only available in jit
    # TODO: deprecate is_nullable global flag, and parse the type
    # to support annotating complicated types with optional annotation
    nullable = (t != 'Generator?' and '?' in t)

    # Fixed-size array types; matched once here and used in the chain below.
    int_array = re.match(r'int\[(\d+)\]', t)
    bool_array = re.match(r'bool\[(\d+)\]', t)

    # This enables "Generator? x = None and translates to legacy
    # "Generator* x = nullptr". See [temp translations].
    if t == 'Generator?' and default == 'None':
        t = 'Generator*'
        default = 'nullptr'
    # Enables Generator? by translating to legacy Generator*.
    elif t == 'Generator?':
        t = 'Generator*'
    # Enables Tensor[] by translating to legacy TensorList.
    elif t in ('Tensor[]', 'Tensor?[]'):
        t = 'TensorList'
    # Enables int[] by translating to legacy IntArrayRef.
    elif t == 'int[]':
        t = 'IntArrayRef'
    # Enables int by translating to legacy int64_t.
    elif t == 'int':
        t = 'int64_t'
    elif t == 'int?':
        t = 'int64_t?'
    elif t == 'int64_t':
        raise RuntimeError("Please use int and not int64_t. "
                           "See [temp translations] for details.")
    elif t == 'int64_t?':
        raise RuntimeError("Please use int? and not int64_t?. "
                           "See [temp translations] for details.")
    # Enables float by translating to legacy double.
    elif t == 'float':
        t = 'double'
    # Enables str by translating to legacy std::string.
    elif t == 'str':
        t = 'std::string'
    elif t == 'double':
        raise RuntimeError("Please use float and not double. "
                           "See [temp translations] for details.")
    # Enables int[x] by translating to legacy IntArrayRef[x]. See [temp translations]
    elif int_array:
        t = 'IntArrayRef'
        size = int(int_array.group(1))
    # Enables bool[x] by translating to legacy std::array<bool,x>. See [temp translations]
    elif bool_array:
        t = 'std::array<bool,{}>'.format(bool_array.group(1))
    elif t.startswith('std::array'):
        # Note the trailing space: the two literals are concatenated.
        raise RuntimeError("Please use array notation, e.g. bool[3] and not std::array. "
                           "See [temp translations] for details.")

    # Legacy type sanitization. TODO: Do we really need this?
    if t == 'Generator*':
        t = 'Generator *'

    if not default:
        pass
    # This enables Tensor? x=None and translates to legacy
    # "Tensor? x={}". See [temp translations].
    elif t.startswith('Tensor?') and default == 'None':
        default = "{}"
    elif default == 'True':
        default = True
    elif default == 'False':
        default = False
    elif default == 'true':
        raise RuntimeError("Please use True and not true. "
                           "See [temp translations] for details.")
    elif default == 'false':
        raise RuntimeError("Please use False and not false. "
                           "See [temp translations] for details.")
    # Enables default argument [] by translating to legacy {}.
    # See [temp translations]
    elif default == '[]':
        default = '{}'
    # Enables lists by translating to legacy {.*}.
    # See [temp translations]
    elif re.match(r'\[.*\]', default):
        default = "{" + default[1:-1] + "}"
    elif default == 'None':
        default = 'c10::nullopt'
    # The JIT signature schema uses Mean, but in particular C++ needs
    # the legacy Reduction::Mean. So we'll continue emitting that until
    # we change this at either a JIT schema or C++ level.
    elif default == 'Mean':
        default = 'Reduction::Mean'
    else:
        # Numeric defaults become int or float; anything else stays a string.
        try:
            default = int(default)
        except ValueError:
            try:
                default = float(default)
            except ValueError:
                pass

    return t, name, default, nullable, size, annotation
def parse_arguments(args, func_variants, declaration, func_return):
    """Parse a schema argument list into legacy argument dicts and sanity-check it.

    Args:
        args: the text between the schema's parentheses, e.g.
            ``"Tensor self, int dim, *, Tensor(a!) out"``.
        func_variants: the ``variants`` value of the YAML entry; functions
            with output arguments may only use the function variant.
        declaration: the partially built declaration. ``name`` and
            ``inplace`` are read; ``name`` gets an ``_out`` suffix (mutated in
            place) when the function has output arguments.
        func_return: the list produced by ``parse_return_arguments``.

    Returns:
        A list of argument dicts with keys ``type``, ``name``,
        ``is_nullable``, ``annotation`` and optionally ``size``, ``default``,
        ``kwarg_only`` and ``output``. Output arguments are moved to the
        front, and a supported run of dtype/layout/device/pin_memory
        arguments is collapsed into one ``TensorOptions options`` argument.

    Raises:
        AssertionError or RuntimeError when a sanity check below fails.
    """
    arguments = []
    kwarg_only = False

    if len(args.strip()) == 0:
        return arguments

    # TODO: Use a real parser here; this will get bamboozled
    # by signatures that contain things like std::array<bool, 2> (note the space)
    for arg in args.split(', '):
        type_and_name = [a.strip() for a in arg.rsplit(' ', 1)]
        # A bare '*' separator makes every later argument keyword-only.
        if type_and_name == ['*']:
            assert not kwarg_only
            kwarg_only = True
            continue

        t, name, default, nullable, size, annotation = type_argument_translations(arg)

        # The '?' suffix is dropped from the type; nullability lives in 'is_nullable'.
        argument_dict = {'type': t.rstrip('?'), 'name': name, 'is_nullable': nullable, 'annotation': annotation}
        if size:
            argument_dict['size'] = size
        if default is not None:
            argument_dict['default'] = default
        if kwarg_only:
            argument_dict['kwarg_only'] = True
        arguments.append(argument_dict)

    # A keyword-only Tensor whose annotation ends in '!' (e.g. "*, Tensor(a!) out")
    # is an output argument: it stops being kwarg_only and is moved to the front.
    is_out_fn = False
    arguments_out = []
    arguments_other = []
    for argument in arguments:
        if argument['type'] == "Tensor" and \
                argument['annotation'] and \
                re.match(r'^(.*!)$', argument['annotation']) and \
                argument.get('kwarg_only'):
            argument['output'] = True
            argument['kwarg_only'] = False
            arguments_out.append(argument)
            is_out_fn = True
        else:
            arguments_other.append(argument)

    arguments = arguments_out + arguments_other

    # Keep the un-suffixed name for error messages and the _out check below.
    name = declaration['name']
    if is_out_fn:
        declaration['name'] += "_out"

    # Reverse splat of TensorOptions
    # As we move towards the JIT function schema for native_functions.yaml we need to support
    # the expanded version of TensorOptions. For now we look for a run of four arguments named
    # dtype, layout, device and pin_memory. If their full dicts (type, nullability, defaults,
    # kwarg_only) exactly match one of supported_topt_arguments, the run is replaced by the
    # single TensorOptions argument at the same index of corresponding_topts.
    # They must appear in this order and in this order only in order for us to be able to process them.
    # In the future we will get rid of this specific processing as downstream consumers start relying
    # less on the content of Declarations.yaml. If you want to support more than this you'll
    # potentially have to extend the JIT.
    supported_topt_arguments = [
        [
            {'name': 'dtype', 'type': 'ScalarType', 'is_nullable': False, 'annotation': None},
            {'name': 'layout', 'type': 'Layout', 'is_nullable': False, 'annotation': None},
            {'name': 'device', 'type': 'Device', 'is_nullable': False, 'annotation': None},
            {'name': 'pin_memory', 'type': 'bool', 'is_nullable': False, 'annotation': None, 'default': False},
        ]
    ]
    # Variant 1: the same four arguments, keyword-only.
    supported_topt_arguments.append(copy.deepcopy(supported_topt_arguments[0]))
    for topt in supported_topt_arguments[1]:
        topt.update({'kwarg_only': True})
    # Variant 2: keyword-only and nullable, all defaulting to c10::nullopt.
    supported_topt_arguments.append(copy.deepcopy(supported_topt_arguments[1]))
    for topt in supported_topt_arguments[2]:
        topt.update({'default': 'c10::nullopt', 'is_nullable': True})
    # add explicit support for what is needed for tril_indices / triu_indices
    supported_topt_arguments.append(
        [
            {'name': 'dtype', 'type': 'ScalarType', 'annotation': None, 'kwarg_only': True,
             'default': 'long', 'is_nullable': True},
            {'name': 'layout', 'type': 'Layout', 'annotation': None, 'kwarg_only': True,
             'default': 'c10::nullopt', 'is_nullable': True},
            {'name': 'device', 'type': 'Device', 'annotation': None, 'kwarg_only': True,
             'default': 'c10::nullopt', 'is_nullable': True},
            {'name': 'pin_memory', 'type': 'bool', 'annotation': None, 'kwarg_only': True,
             'default': 'c10::nullopt', 'is_nullable': True},
        ]
    )

    # corresponding_topts[i] is the merged argument for supported_topt_arguments[i].
    corresponding_topts = [
        {'type': 'TensorOptions', 'name': 'options', 'is_nullable': False, 'annotation': None},
    ]
    corresponding_topts.append(corresponding_topts[0].copy())
    corresponding_topts[1]['kwarg_only'] = True
    corresponding_topts.append(corresponding_topts[1].copy())
    corresponding_topts[2]['default'] = '{}'
    corresponding_topts.append(
        {'type': 'TensorOptions', 'name': 'options', 'is_nullable': False, 'annotation': None,
         'kwarg_only': True, 'default': 'at::kLong'})

    def check_topt_representation(topt_representation):
        # Return the merged TensorOptions argument for an exact match, else None.
        for idx, supported_topt in enumerate(supported_topt_arguments):
            matches = all(topt_representation[i] == topt for i, topt in enumerate(supported_topt))
            if matches:
                return corresponding_topts[idx]
        return None

    def is_tensor_option(argument):
        return argument['name'] in ['dtype', 'layout', 'device', 'pin_memory']

    number_of_arguments = len(supported_topt_arguments[0])
    new_arguments = []
    idx = 0
    while idx < len(arguments):
        argument = arguments[idx]
        if is_tensor_option(argument) and len(arguments) - idx >= number_of_arguments:
            # Greedily consume consecutive tensor-option arguments (at most four).
            topt_representation = []
            for i in range(number_of_arguments):
                argument = arguments[idx]
                if not is_tensor_option(argument):
                    break
                topt_representation.append(argument)
                idx += 1
            if len(topt_representation) == number_of_arguments:
                merged_argument = check_topt_representation(topt_representation)
                assert merged_argument, \
                    "Unsupported combination of TensorOptions {}, the only currently supported combinations are {}"\
                    .format(str(topt_representation), str(supported_topt_arguments))
                new_arguments.append(merged_argument)
            else:
                # A partial run is kept unmerged; idx now points at the
                # first non-tensor-option argument.
                new_arguments += topt_representation
        else:
            new_arguments.append(argument)
            idx += 1

    arguments = new_arguments

    # Sanity checks

    # Check the count first: the per-index loop below reads
    # func_return[arg_idx] and would otherwise raise an uninformative
    # IndexError when there are more outputs than returns.
    assert len(arguments_out) <= len(func_return), "func {} must return at least as many Tensors " \
        "as can be passed as output.".format(name)

    # TODO: convention is that the ith-argument correspond to the i-th return, but it would
    # be better if we just named everything and matched by name.
    for arg_idx, argument in enumerate(arguments_out):
        assert argument['annotation'] == func_return[arg_idx]['annotation'], \
            "For func {} writeable keyword Tensor arguments need to have a matching return Tensor. Further, " \
            "the ith-argument needs to correspond to the i-th return.".format(name)

    if name.endswith('_out'):
        raise RuntimeError("Native function {} may not be suffixed with _out as we transition to a unified schema. "
                           "Otherwise you will cause confusion amongst consumers of native functions.".format(name))

    if is_out_fn and func_variants not in [[], 'function', ['function']]:
        raise RuntimeError("Native functions with output MUST be declared with only the function variant; "
                           "e.g., variants: function; otherwise you will tickle a Python argument binding bug "
                           "(which usually manifests itself as the result variable being undefined.) "
                           "The culprit was: {}".format(name))
    if not is_out_fn:
        assert len(arguments_out) == 0, "func {} is not marked as output yet contains output " \
            "keyword arguments".format(name)

    # TODO: Explicit checking for void is a hack and should disappear after a more
    # functionally complete implementation of Tensor aliases.
    if declaration['inplace'] and len(func_return) > 0 and func_return[0]['type'] != "void":
        found_self = False
        for arg_idx, argument in enumerate(arguments):
            if argument['name'] == "self":
                assert argument['annotation'] and argument['annotation'].endswith("!"), \
                    "Inplace function \"{}\" needs to annotate Tensor argument named self " \
                    "as mutable.".format(name)
                found_self = True
                assert argument['annotation'] == func_return[arg_idx]['annotation'], \
                    "Inplace function annotations of function {} need to match between " \
                    "input and correponding output.".format(name)
                assert argument['name'] == func_return[arg_idx]['name'] or \
                    argument['name'] == func_return[arg_idx]['name'] + "_return"
                assert argument['type'] == func_return[arg_idx]['type']
        assert found_self, "Inplace function \"{}\" needs Tensor argument named self.".format(name)

    return arguments
def parse_return_arguments(return_decl, inplace, func_decl):
    """Parse the return part of a schema (the text after ``->``).

    Args:
        return_decl: the return declaration, e.g. ``"Tensor"``,
            ``"Tensor(a!)"`` or ``"(Tensor values, Tensor indices)"``.
        inplace: whether the function was detected as in-place. An unnamed
            ``Tensor`` return of an in-place function must carry a mutable
            (``!``) annotation and is named ``self``.
        func_decl: the raw YAML entry; only ``func_decl['func']`` (the full
            schema string) is read.

    Returns:
        A list of dicts with keys ``type``, ``name``, ``annotation``,
        ``output`` (always True) and, for named returns, ``field_name``.
        An empty tuple ``()`` yields an empty list.

    Raises:
        RuntimeError: if ``return_decl`` is empty.
        AssertionError: if an in-place unnamed Tensor return is not mutable.
    """
    arguments = []
    if return_decl == '()':
        # Previously '()' was stripped to '' and parsed into a bogus
        # untyped 'result' entry.
        return arguments
    if not return_decl:
        raise RuntimeError("Expected a return type after '->' in {}".format(func_decl['func']))

    # TODO: Use a real parser here; this will get bamboozled
    # by signatures that contain things like std::array<bool, 2> (note the space)
    if return_decl.startswith('(') and return_decl.endswith(')'):
        return_decl = return_decl[1:-1]

    return_decls = return_decl.split(', ')
    multiple_args = len(return_decls) > 1
    # Everything before '->': used to detect return names clashing with arguments.
    signature_args = func_decl['func'].split('->')[0]
    for arg_idx, arg in enumerate(return_decls):
        t, name, _, _, _, annotation = type_argument_translations(arg)
        # name of arguments and name of return sometimes have collision
        # in this case, we rename the return name to <name>_return.
        # NOTE(review): this is a plain substring test on the argument text,
        # so it also fires on partial matches (e.g. "values" in "max_values");
        # kept unchanged because generated names depend on it.
        return_name = name
        if name in signature_args:
            return_name = name + "_return"
        argument_dict = {'type': t, 'name': return_name, 'annotation': annotation}
        if name:
            # See Note [field_name versus name]
            argument_dict['field_name'] = name
        else:
            if t == "Tensor" and inplace:
                assert annotation and annotation.endswith("!"), \
                    "Return Tensor of function \"{}\" flagged as inplace needs to be " \
                    "annotated as mutable".format(func_decl['func'])
                argument_dict['name'] = 'self'
            else:
                argument_dict['name'] = 'result' if not multiple_args else 'result' + str(arg_idx)
        argument_dict['output'] = True
        arguments.append(argument_dict)
    return arguments
def parse_native_yaml(path):
    """Read the YAML file at ``path`` and return its parsed contents.

    Parsing uses the module-level ``Loader`` (yaml's CLoader when available,
    otherwise yaml's Loader). NOTE(review): neither is a safe loader, so this
    is only appropriate for trusted, in-repository YAML files.
    """
    with open(path, 'r') as yaml_file:
        parsed = yaml.load(yaml_file, Loader=Loader)
    return parsed
def propagate_field_names(output_arguments, return_arguments):
    """Copy ``field_name`` from each return onto the matching output argument.

    Returns and output arguments are paired positionally (the i-th return
    corresponds to the i-th output argument). ``output_arguments`` is
    mutated in place; nothing is returned.

    ``zip`` stops at the shorter list: parse_arguments permits fewer output
    arguments than returns, and indexing ``output_arguments[i]`` raised
    IndexError in that case.
    """
    for output_argument, return_argument in zip(output_arguments, return_arguments):
        if 'field_name' in return_argument:
            output_argument['field_name'] = return_argument['field_name']
def run(paths):
    """Parse every native-functions YAML file in ``paths`` into declarations.

    Each entry's ``func`` schema (``name(args) -> returns``) is parsed into a
    declaration dict holding the name, arguments, returns, variants and the
    per-function flags copied from the entry (with defaults when absent).

    Returns:
        The list of declaration dicts, in file order then entry order.

    Raises:
        Any exception raised while processing an entry is re-raised with its
        original traceback, after the entry and the partially built
        declaration are printed to stderr.
    """
    declarations = []
    for path in paths:
        for func in parse_native_yaml(path):
            declaration = {'mode': 'native'}
            try:
                declaration['schema_string'] = "aten::" + func['func']
                if '->' in func['func']:
                    func_decl, return_decl = [x.strip() for x in func['func'].split('->')]
                else:
                    raise RuntimeError('Expected return declaration')
                fn_name, arguments = func_decl.split('(', 1)
                assert arguments[-1] == ")", "Expecting closing ) for {}".format(func['func'])
                arguments = arguments[:-1]  # Expect closing )
                declaration['name'] = func.get('name', fn_name)
                # In-place: names starting with "__i" or ending in a single "_".
                declaration['inplace'] = re.search('(^__i|[^_]_$)', fn_name) is not None
                return_arguments = parse_return_arguments(return_decl, declaration['inplace'], func)
                arguments = parse_arguments(arguments, func.get('variants', []), declaration, return_arguments)
                output_arguments = [x for x in arguments if x.get('output')]
                propagate_field_names(output_arguments, return_arguments)
                # Out functions report their output arguments as the returns.
                declaration['return'] = return_arguments if len(output_arguments) == 0 else output_arguments
                declaration['variants'] = func.get('variants', ['function'])
                declaration['requires_tensor'] = func.get('requires_tensor', False)
                declaration['matches_jit_signature'] = func.get('matches_jit_signature', True)
                declaration['cpu_half'] = func.get('cpu_half', False)
                declaration['cpu_bool'] = func.get('cpu_bool', False)
                declaration['cuda_bool'] = func.get('cuda_bool', False)
                declaration['deprecated'] = func.get('deprecated', False)
                declaration['device_guard'] = func.get('device_guard', True)
                declaration['arguments'] = func.get('arguments', arguments)
                declaration['type_method_definition_dispatch'] = func.get('dispatch', declaration['name'])
                declaration['python_module'] = func.get('python_module', '')
                declarations.append(declaration)
            except Exception:
                msg = '''Exception raised in processing function:
{func}
Generated partial declaration:
{decl}'''.format(func=pprint.pformat(func), decl=pprint.pformat(declaration))
                print(msg, file=sys.stderr)
                # Bare raise keeps the original traceback (``raise e`` drops it on Python 2).
                raise

    return declarations