-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathPackedParams.h
147 lines (133 loc) · 4.63 KB
/
PackedParams.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/core/ivalue.h>
// Interface for packed linear parameter objects. Derives from
// torch::jit::CustomClassHolder.
//
// apply, apply_relu, apply_dynamic, apply_dynamic_relu, unpack and bias are
// pure virtual and must be overridden by every concrete subclass. All other
// members are optional: their default implementations throw
// std::runtime_error to signal that the concrete packed parameter type does
// not provide them.
struct LinearPackedParamsBase : public torch::jit::CustomClassHolder {
  // Takes the input tensor together with the output scale and zero point and
  // returns the resulting tensor.
  virtual at::Tensor apply(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) = 0;

  // Same signature as apply.
  // NOTE(review): presumably fuses a ReLU into the result - confirm against
  // the implementations.
  virtual at::Tensor apply_relu(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) = 0;

  // out variant of LinearPackedParamsBase::apply
  // Default implementation: always throws std::runtime_error. Parameter names
  // are commented out because the default body does not use them.
  virtual at::Tensor& apply_out(
      const at::Tensor& /*input*/,
      double /*output_scale*/,
      int64_t /*output_zero_point*/,
      at::Tensor& /*output*/) {
    throw std::runtime_error(
        "apply_out is not implemented for this packed "
        "parameter type");
  }

  // out variant of LinearPackedParamsBase::apply_relu
  // Default implementation: always throws std::runtime_error.
  virtual at::Tensor& apply_relu_out(
      const at::Tensor& /*input*/,
      double /*output_scale*/,
      int64_t /*output_zero_point*/,
      at::Tensor& /*output*/) {
    throw std::runtime_error(
        "apply_relu_out is not implemented for this packed "
        "parameter type");
  }

  // Corresponding pattern (the ops with `*` are part of the pattern that
  // represents the computation of
  // quantized::linear_with_input_q_dq_qweight_dq_output_fp32):
  // input -> q* -> dq* -> linear* ->
  //           qweight -> dq* /
  //
  // After fusion:
  // input -> quantized::linear_with_input_q_dq_qweight_dq_output_fp32* ->
  //           qweight /
  //
  // Additional Note: the weight is packed as well
  // Params:
  //    input: float32 Tensor, will be quantized to quint8 in the op
  //    input_scale, input_zero_point: passed alongside the input
  //        (presumably the quantization parameters used for it - confirm
  //        against the implementations)
  //    The packed qint8 quantized weight and bias (called W_prepack in the
  //    original note) are not parameters of this method.
  // Returns:
  //    float32 Tensor
  // Default implementation: always throws std::runtime_error.
  virtual at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32(
      at::Tensor /*input*/,
      double /*input_scale*/,
      int64_t /*input_zero_point*/) {
    throw std::runtime_error(
        "apply_with_input_q_dq_qweight_dq_output_fp32 is not implemented for this packed "
        "parameter type");
  }

  // Corresponding pattern (the ops with `*` are part of the pattern that
  // represents the computation of
  // quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32):
  // input -> q* -> dq* -> linear* -> relu* ->
  //           qweight -> dq* /
  //
  // After fusion:
  // input -> quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32* ->
  //           qweight /
  //
  // Additional Note: the weight is packed as well
  // Params:
  //    input: float32 Tensor, will be quantized to quint8 in the op
  // Returns:
  //    float32 Tensor
  // Default implementation: always throws std::runtime_error.
  virtual at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32(
      at::Tensor /*input*/,
      double /*input_scale*/,
      int64_t /*input_zero_point*/) {
    throw std::runtime_error(
        "apply_with_input_q_dq_qweight_dq_relu_output_fp32 is not implemented for this packed "
        "parameter type");
  }

  // Takes the input tensor and a reduce_range flag (false when omitted) and
  // returns the resulting tensor.
  // NOTE(review): the default argument sits on a virtual function, so the
  // value that applies is chosen by the static type used at the call site;
  // overrides should repeat the same default.
  virtual at::Tensor apply_dynamic(
      at::Tensor input,
      bool reduce_range = false) = 0;

  // Same signature as apply_dynamic.
  // NOTE(review): presumably fuses a ReLU into the result - confirm against
  // the implementations.
  virtual at::Tensor apply_dynamic_relu(
      at::Tensor input,
      bool reduce_range = false) = 0;

  // out variant of LinearPackedParamsBase::apply_dynamic
  // Default implementation: always throws std::runtime_error.
  virtual at::Tensor& apply_dynamic_out(
      const at::Tensor& /* input */,
      at::Tensor& /* output */,
      bool /* reduce_range */) {
    throw std::runtime_error(
        "apply_dynamic_out is not implemented for this packed "
        "parameter type");
  }

  // out variant of LinearPackedParamsBase::apply_dynamic_relu
  // Default implementation: always throws std::runtime_error.
  virtual at::Tensor& apply_dynamic_relu_out(
      const at::Tensor& /* input */,
      at::Tensor& /* output */,
      bool /* reduce_range */) {
    throw std::runtime_error(
        "apply_dynamic_relu_out is not implemented for this packed "
        "parameter type");
  }

  // Returns a tuple of a tensor and an optional tensor.
  // NOTE(review): presumably (weight, bias) - confirm against the
  // implementations.
  virtual std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() = 0;

  // Returns the optional bias tensor.
  virtual c10::optional<at::Tensor> bias() = 0;

  // Default implementation: always throws std::runtime_error.
  virtual void set_bias(c10::optional<at::Tensor> /*bias*/) {
    throw std::runtime_error(
        "set_bias is not implemented for this packed "
        "parameter type");
  }
};
// Interface for packed convolution parameter objects, parameterized by the
// compile-time integer kSpatialDim (2 when not specified). Derives from
// torch::jit::CustomClassHolder. Every member is pure virtual, so a concrete
// subclass has to override all of them.
// NOTE(review): kSpatialDim is not referenced inside this interface itself;
// presumably it is the number of spatial dimensions of the convolution -
// confirm against the implementations.
template <int kSpatialDim = 2>
struct ConvPackedParamsBase : public torch::jit::CustomClassHolder {
  // --- Computation entry points -------------------------------------------

  // Takes the input tensor together with the output scale and zero point and
  // returns the resulting tensor.
  virtual at::Tensor apply(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) = 0;

  // Same signature as apply.
  // NOTE(review): presumably fuses a ReLU into the result - confirm against
  // the implementations.
  virtual at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) = 0;

  // Takes the input tensor and a reduce_range flag (no default value here)
  // and returns the resulting tensor.
  virtual at::Tensor apply_dynamic(
      const at::Tensor& input,
      bool reduce_range) = 0;

  // --- Parameter access ---------------------------------------------------

  // Returns a tuple of a tensor and an optional tensor.
  // NOTE(review): presumably (weight, bias) - confirm against the
  // implementations.
  virtual std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() = 0;

  // Const accessors for the convolution configuration; the four list-valued
  // ones each return a torch::List<int64_t>.
  virtual torch::List<int64_t> stride() const = 0;
  virtual torch::List<int64_t> padding() const = 0;
  virtual torch::List<int64_t> output_padding() const = 0;
  virtual torch::List<int64_t> dilation() const = 0;
  virtual int64_t groups() const = 0;
  virtual bool transpose() const = 0;
};