forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoperator.h
185 lines (156 loc) · 6.46 KB
/
operator.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
// In-memory description of all ATen ops, similar to the Caffe2 schema.
// Once c10 exists this can be removed or stubbed out, but we need it now
// to implement correct semantic checking for script.
#pragma once
#include <ATen/core/stack.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/script/function_schema_parser.h>
#include <torch/csrc/jit/operator_options.h>
#include <ATen/core/stack.h>
#include <ATen/ATen.h>
#include <ATen/core/function_schema.h>
#include <functional>
#include <initializer_list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
using ::c10::FunctionSchema;
// Factory that produces an Operation for a specific IR node. Used by
// operators whose runnable implementation depends on the node they are
// attached to (see Operator::getOperation).
using OperationCreator = std::function<Operation(const Node*)>;
/*
 * Note: JIT relies on Operator instances having static lifetime, because
 * it, for example, stores a non-owning FunctionSchema* pointer in the Node
 * class, which points to the function schema stored in the Operator instance.
 * Also, jit::Operator is meant to store more operator-related information
 * like symbolic derivatives, which also requires them to have static lifetime
 * so that changes to symbolic derivatives are remembered.
 *
 * Currently, the c10 operator library doesn't store jit::Operator instances;
 * instead, we use a listener pattern that notifies JIT about changes in the
 * c10 operator library and then registers jit::Operator instances to the JIT
 * operator registry, acting as wrappers around the c10 operators.
 *
 * However, that results in code duplication, as JIT and c10 will likely get
 * their own mechanisms for storing derivatives and other operator-related
 * information, and all of this would have to be wrapped from c10 into JIT.
 *
 * We should consider merging the JIT and c10 registries, moving jit::Operator
 * to c10 and storing these jit::Operator instances in the c10 operator library
 * instead, allowing us to have these mechanisms implemented only once.
 * However, the current jit::Operator implementation has additional features
 * like OperationCreator that aren't needed in c10 (they're only used for
 * prim ops like If/Else or While, which wouldn't be in the c10 operator
 * library), and which depend on other JIT features which we don't want to
 * move to c10 (notably jit/ir.h). We might, however, be able to split
 * jit::Operator into a c10::Operator with the core features and a
 * jit::Operator that adds the JIT-only features like OperationCreator, and
 * then use c10::Operator in the c10 operator library.
 */
struct TORCH_API Operator {
  // Operator with an already-parsed schema whose Operation is built per node
  // by `op_creator` (see getOperation()).
  Operator(
      FunctionSchema schema,
      OperationCreator op_creator,
      OperatorOptions options = OperatorOptions())
      : schema_(std::make_shared<FunctionSchema>(std::move(schema))),
        op_creator_(std::move(op_creator)),
        options_(std::move(options)) {}

  // Same as above, but the schema is given as text and is parsed lazily on
  // the first call to schema(). The string is taken by value so that callers
  // passing temporaries have it moved instead of copied.
  Operator(
      std::string schema,
      OperationCreator op_creator,
      OperatorOptions options = OperatorOptions())
      : schema_string_(std::move(schema)),
        op_creator_(std::move(op_creator)),
        options_(std::move(options)) {}

  // Helper constructor to register `op_creator` to run for _every_ IR Node
  // where n.kind() == name, regardless of arguments.
  // This is accomplished by marking the schema varargs and having no required
  // arguments. This is used for things like prim::While or prim::If that can
  // take a number of different valid input types and lengths.
  Operator(
      Symbol name,
      OperationCreator op_creator,
      OperatorOptions options = OperatorOptions())
      : Operator(
            FunctionSchema(
                name,
                "",
                {},
                {},
                /*is_vararg*/ true,
                /*is_varret*/ true),
            std::move(op_creator),
            std::move(options)) {}

  // Operator with an already-parsed schema and a fixed, node-independent
  // Operation.
  Operator(
      FunctionSchema schema,
      Operation op,
      OperatorOptions options = OperatorOptions())
      : schema_(std::make_shared<FunctionSchema>(std::move(schema))),
        op_(std::make_shared<Operation>(std::move(op))),
        options_(std::move(options)) {}

  // Textual schema (parsed lazily by schema()) and a fixed Operation. The
  // string is taken by value so temporaries are moved rather than copied.
  Operator(
      std::string schema,
      Operation op,
      OperatorOptions options = OperatorOptions())
      : schema_string_(std::move(schema)),
        op_(std::make_shared<Operation>(std::move(op))),
        options_(std::move(options)) {}

  // Defined out of line. Presumably reports whether `node` is compatible with
  // this operator's schema — NOTE(review): confirm against the definition.
  bool matches(const Node* node) const;

  // Returns the runnable Operation for this operator.
  // If the Operator was constructed with an Operation, that Operation is
  // returned and `node` is ignored. Otherwise the stored OperationCreator is
  // invoked with `node`, which must then be non-null (asserted).
  Operation getOperation(const Node* node = nullptr) const {
    if (op_) {
      return *op_;
    }
    AT_ASSERT(node != nullptr);
    return op_creator_(node);
  }

  // Returns the operator's schema.
  // Schemas given as strings are parsed on first access rather than at
  // construction, so that static operator registration does less work. After
  // a successful parse the string is discarded and the parsed schema cached.
  // NOTE(review): this const accessor writes `schema_` / `schema_string_`
  // without any synchronization visible here, so concurrent first calls on
  // the same Operator could race — confirm callers serialize access.
  const FunctionSchema& schema() const {
    if (!schema_) {
      schema_ =
          std::make_shared<FunctionSchema>(parseSchema(schema_string_.value()));
      schema_string_ = c10::nullopt;
    }
    return *schema_;
  }

  // Returns the options supplied at construction.
  const OperatorOptions& options() const {
    return options_;
  }

 private:
  // Unparsed schema text. Engaged only for Operators constructed from a
  // string, and only until schema() parses it.
  mutable c10::optional<std::string> schema_string_;
  // Parsed schema (null until parsed when constructed from a string).
  // We cannot use c10::optional here because Windows has issues that require
  // an assignment operator to be generated, and we cannot use std::unique_ptr
  // because initializer lists of Operators end up copying the Operator.
  mutable std::shared_ptr<FunctionSchema> schema_;
  // Essentially a variant<Operation, OperationCreator>: exactly one of these
  // is set by the constructors above.
  // NB: std::function has a default state (where it == nullptr).
  std::shared_ptr<Operation> op_;
  OperationCreator op_creator_;
  OperatorOptions options_;
};
// The functions below are defined out of line; their exact semantics live in
// the implementation file. All are exported with TORCH_API so that code
// outside the library (including callers of the inline getOperation below)
// can link against them.

// Presumably renders `schema` as a canonical string — confirm in definition.
TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema);
// Returns the registered Operators associated with `name` (by reference).
TORCH_API const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(
    Symbol name);
// Looks up an Operator for `node`. NOTE(review): presumably returns nullptr
// when nothing matches — confirm against the definition.
TORCH_API std::shared_ptr<Operator> findOperatorFor(const Node* node);
// Returns the Operator for `node`; guarantees the result matches `node`.
TORCH_API const Operator& getOperatorFor(const Node* node);
// Convenience wrapper: the runnable Operation for `node`.
inline Operation getOperation(const Node* node) {
  // note: getOperatorFor ensures that getOperatorFor(node).matches(node) ==
  // true, so passing `node` to Operator::getOperation is always valid.
  return getOperatorFor(node).getOperation(node);
}
// Presumably suggests operator names similar to `input_op` (e.g. for error
// messages) — confirm in definition.
TORCH_API std::vector<Symbol> findSimilarOperators(Symbol input_op);
// Adds `op` to the JIT operator registry.
TORCH_API void registerOperator(Operator&& op);
// XXX: this function is meant to be used with string literals only!
TORCH_API Operator& sig(const char* signature_literal);
struct OperatorSet {
OperatorSet(std::initializer_list<const char*> sig_literals);
// XXX: Returns a nullptr if no Operator in the set matches n
Operator* find(const Node* n) const;
private:
std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>> ops;
};
} // namespace jit
} // namespace torch