#pragma once

#include <ATen/core/function_schema.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/csrc/jit/script/function_schema_parser.h>
#include <vector>

namespace caffe2 {
namespace detail {

constexpr const char* PREALLOCATED_OUTPUT_ARGNAME =
    "_caffe2_preallocated_outputs";

using _CallCaffe2OpFunc = std::vector<at::Tensor>(
    const c10::FunctionSchema& schema,
    std::vector<c10::IValue>&& inputs,
    std::vector<at::Tensor>&& outputs);

template <class Caffe2Operator>
inline std::vector<at::Tensor> _call_caffe2_op(
    const c10::FunctionSchema& schema,
    std::vector<c10::IValue>&& inputs,
    std::vector<at::Tensor>&& outputs) {
  Caffe2Operator op(schema, std::move(inputs), std::move(outputs));
  op.Run();
  return std::move(op).move_newstyle_outputs();
}
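
// A minimal sketch of the interface that _call_caffe2_op above assumes the
// Caffe2Operator template parameter provides. The class name is hypothetical;
// the real operator classes are defined elsewhere in caffe2:
//
//   class MyCaffe2Operator {
//    public:
//     MyCaffe2Operator(
//         const c10::FunctionSchema& schema,
//         std::vector<c10::IValue>&& inputs,
//         std::vector<at::Tensor>&& outputs);
//     void Run(); // the call site above ignores any return value
//     std::vector<at::Tensor> move_newstyle_outputs() &&;
//   };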

// This function is inline in the hope that compilers optimizing for speed will
// inline it into call_caffe2_op_from_c10, allowing call_op to be inlined and
// avoiding the function pointer indirection, while compilers optimizing for
// binary size will keep it a separate function instead of inlining it into
// a template and will reuse the binary code of this function between ops.
// We measured and confirmed that the binary size of the Instagram iOS app is
// reduced when having _call_caffe2_op_from_c10 separate from the templated
// call_caffe2_op_from_c10.
inline void _call_caffe2_op_from_c10(
    c10::Stack* stack,
    const c10::FunctionSchema& schema,
    _CallCaffe2OpFunc* call_op) {
  // precondition: on the stack, there's one IValue for each argument of the
  // c10 schema. The last argument is an optional tensor list that
  // (if not ivalue::None) contains a preallocated output tensor for each
  // operator output.
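  // As a hypothetical illustration of that precondition: for a schema like
  //   _caffe2::MyOp(Tensor a, int b, Tensor[]? _caffe2_preallocated_outputs) -> (Tensor, Tensor)
  // the stack holds [a, b, preallocated_outputs] on entry and, per the
  // postcondition below, holds the two output tensors when this returns.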
  AT_ASSERT(
      schema.arguments().size() != 0 &&
      schema.arguments().back().type()->isSubtypeOf(
          OptionalType::create(ListType::ofTensors())));
  IValue preallocated_outputs = torch::jit::pop(*stack);

  const size_t num_outputs = schema.returns().size();
  const size_t num_inputs = schema.arguments().size() -
      1; // -1 because the last argument is the list of preallocated tensors

  std::vector<at::Tensor> outputs;
  if (preallocated_outputs.isNone()) {
    // either the schema doesn't support preallocated outputs or it does but
    // they haven't been passed in. Pass a list of uninitialized tensors to
    // the caffe2 operator as preallocated outputs.
    outputs.resize(num_outputs);
  } else {
    AT_ASSERT(preallocated_outputs.isTensorList());
    outputs =
        std::move(*std::move(preallocated_outputs).toTensorList()).elements();
  }

  // TODO Avoid vector allocation. One idea would be to keep the std::vector
  // instances in the cache.
  std::vector<IValue> inputs = torch::jit::pop(*stack, num_inputs);

  outputs = (*call_op)(schema, std::move(inputs), std::move(outputs));

  for (auto&& output : std::move(outputs)) {
    torch::jit::push(*stack, std::move(output));
  }

  // postcondition: All inputs are cleared from the stack, there's now one
  // IValue for each output which holds the result. This
  // might reuse one of the preallocated tensors but doesn't have to.
}

template <const c10::FunctionSchema& (*Schema)(), class Caffe2Operator>
void call_caffe2_op_from_c10(
    c10::Stack* stack,
    c10::KernelCache* cache) { // TODO Pass in correct cache type
  _call_caffe2_op_from_c10(stack, Schema(), &_call_caffe2_op<Caffe2Operator>);
}

inline FunctionSchema make_function_schema_for_c10(const char* schema_str) {
  c10::FunctionSchema parsed_schema = torch::jit::parseSchema(schema_str);
  std::vector<c10::Argument> arguments = parsed_schema.arguments();
  arguments.emplace_back(
      PREALLOCATED_OUTPUT_ARGNAME,
      c10::OptionalType::create(c10::ListType::ofTensors()),
      nullopt,
      IValue());

  return FunctionSchema(
      parsed_schema.name(),
      parsed_schema.overload_name(),
      std::move(arguments),
      parsed_schema.returns(),
      parsed_schema.is_vararg(),
      parsed_schema.is_varret());
}
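
// For illustration: given a hypothetical schema string such as
//   "_caffe2::MyOp(Tensor input1, int argument2) -> (Tensor output1)",
// make_function_schema_for_c10 returns a schema roughly equivalent to
//   "_caffe2::MyOp(Tensor input1, int argument2,
//        Tensor[]? _caffe2_preallocated_outputs=None) -> (Tensor output1)".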

inline std::unique_ptr<c10::KernelCache> noCache() {
  return nullptr;
}

} // namespace detail
} // namespace caffe2

/**
 * To register a caffe2 operator caffe2::MyOperator with the c10 dispatcher,
 * call:
 *
 * In caffe2/operators/MyOperator.h:
 *
 * > C10_DECLARE_CAFFE2_OPERATOR(C10MyOperator) // C10MyOperator is the name
 * >                                            // used by c10 for this operator
 *
 * In caffe2/operators/MyOperator.cc:
 *
 * > C10_REGISTER_CAFFE2_OPERATOR_CPU(
 * >     C10MyOperator,
 * >     "_caffe2::C10MyOperator(Tensor input1, int argument2, float argument3) -> (Tensor output1, Tensor output2)",
 * >     caffe2::MyOperator<caffe2::CPUContext> // This is the caffe2 operator
 * >                                            // class template
 * > )
 *
 * In caffe2/operators/MyOperator.cu:
 *
 * > C10_REGISTER_CAFFE2_OPERATOR_CUDA(
 * >     C10MyOperator,
 * >     caffe2::MyOperator<caffe2::CUDAContext>)
 *
 * Notes:
 * - all macros must be called in the top level namespace, not in namespace
 *   caffe2.
 * - all operators must call C10_DECLARE_CAFFE2_OPERATOR and
 *   C10_REGISTER_CAFFE2_OPERATOR_CPU.
 * - calling C10_REGISTER_CAFFE2_OPERATOR_CUDA is optional and can be omitted
 *   if you don't want to expose the operator for CUDA operations.
 * - caffe2 arguments must come after caffe2 inputs; in other words, any tensor
 *   inputs must precede any non-tensor arguments in the schema.
 *
 * More complex use cases:
 * - If your operator has a variable number of input tensors, make the first
 *   (!) input an input of type TensorList. There must be no other tensor
 *   inputs. A hypothetical example follows below.
 */
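
/* A hypothetical sketch of the variable-input case mentioned above. The
 * operator name and schema string are made up for illustration and are not
 * registered anywhere; the point is that the TensorList input comes first
 * and is the only tensor input:
 *
 * > C10_REGISTER_CAFFE2_OPERATOR_CPU(
 * >     C10MyConcatLikeOperator,
 * >     "_caffe2::C10MyConcatLikeOperator(Tensor[] inputs, int axis) -> (Tensor output)",
 * >     caffe2::MyConcatLikeOperator<caffe2::CPUContext>)
 */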
#ifndef C10_MOBILE
#define C10_DECLARE_CAFFE2_OPERATOR(OperatorName) \
namespace caffe2 { \
namespace _c10_ops { \
CAFFE2_API const FunctionSchema& schema_##OperatorName(); \
} \
}

#define C10_REGISTER_CAFFE2_OPERATOR_CPU( \
    OperatorName, OperatorSchema, OperatorClass) \
/* Register the op schema with the c10 dispatcher */ \
namespace caffe2 { \
namespace _c10_ops { \
C10_EXPORT const FunctionSchema& schema_##OperatorName() { \
static const FunctionSchema schema = \
::caffe2::detail::make_function_schema_for_c10(OperatorSchema); \
return schema; \
} \
} \
} \
/* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \
static auto registry_##OperatorName##_##__COUNTER__ = \
::c10::RegisterOperators().op( \
::caffe2::_c10_ops::schema_##OperatorName(), \
::c10::kernel( \
&::caffe2::detail::call_caffe2_op_from_c10< \
::caffe2::_c10_ops::schema_##OperatorName, \
OperatorClass>, \
&::caffe2::detail::noCache), \
::c10::dispatchKey(::c10::CPUTensorId()));

#define C10_REGISTER_CAFFE2_OPERATOR_CUDA(OperatorName, OperatorClass) \
/* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \
static auto registry_##OperatorName##_##__COUNTER__ = \
::c10::RegisterOperators().op( \
::caffe2::_c10_ops::schema_##OperatorName(), \
::c10::kernel( \
&::caffe2::detail::call_caffe2_op_from_c10< \
::caffe2::_c10_ops::schema_##OperatorName, \
OperatorClass>, \
&::caffe2::detail::noCache), \
::c10::dispatchKey(::c10::CUDATensorId()));

// You should never manually call the C10_REGISTER_CAFFE2_OPERATOR_HIP macro.
// The C10_REGISTER_CAFFE2_OPERATOR_CUDA macro from above will be automatically
// rewritten to C10_REGISTER_CAFFE2_OPERATOR_HIP by hipify.
#define C10_REGISTER_CAFFE2_OPERATOR_HIP(OperatorName, OperatorClass) \
/* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \
static auto registry_##OperatorName##_##__COUNTER__ = \
::c10::RegisterOperators().op( \
::caffe2::_c10_ops::schema_##OperatorName(), \
::c10::kernel( \
&::caffe2::detail::call_caffe2_op_from_c10< \
::caffe2::_c10_ops::schema_##OperatorName, \
OperatorClass>, \
&::caffe2::detail::noCache), \
::c10::dispatchKey(::c10::HIPTensorId()));

#else
// Don't use c10 dispatcher on mobile because of binary size
#define C10_DECLARE_CAFFE2_OPERATOR(OperatorName)
#define C10_REGISTER_CAFFE2_OPERATOR_CPU(OperatorName, OperatorSchema, OperatorClass)
#define C10_REGISTER_CAFFE2_OPERATOR_CUDA(OperatorName, OperatorClass)
#define C10_REGISTER_CAFFE2_OPERATOR_HIP(OperatorName, OperatorClass)
#endif