Commit 45bed88

[Pass] Add MaxPool, AvgPool to FoldExplicitPadding (#11494)
* fold first steps
* spitballing
* check pad is really optd away
* new pool test passes
* stuff
* refactoring midway
* things actually kinda work
* complete tests
* lint and complete tests
* clean
* fix comments
1 parent 2b0e082 commit 45bed88
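
In short, the pass now also rewrites `nn.pad` followed by a pooling op into the pooling op alone, with the pad widths folded into the pool's `padding` attribute: for max pool this is only valid when the pad value is the dtype's minimum, and for avg pool only when it is zero (with divisor semantics preserved, see the AvgPool attrs below). A minimal before/after sketch at the Relay level; this assumes the pass's standard Python entry point `relay.transform.FoldExplicitPadding` and an illustrative pad value:

```python
import numpy as np
import tvm
from tvm import relay

# pad -> max_pool2d, where the pad value is float32's minimum
# (the identity element for max), so folding is safe.
x = relay.var("x", shape=(1, 3, 32, 32), dtype="float32")
pad = relay.nn.pad(x, pad_width=((0, 0), (0, 0), (1, 1), (1, 1)),
                   pad_value=float(np.finfo("float32").min))
pool = relay.nn.max_pool2d(pad, pool_size=(3, 3))
mod = tvm.IRModule.from_expr(relay.Function([x], pool))

mod = relay.transform.InferType()(mod)
folded = relay.transform.FoldExplicitPadding()(mod)
# Expect: the nn.pad call is gone and nn.max_pool2d now carries
# padding=[1, 1, 1, 1].
print(folded)
```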

File tree: 3 files changed (+351, -49 lines)

include/tvm/relay/attrs/nn.h

Lines changed: 6 additions & 8 deletions
```diff
@@ -891,10 +891,9 @@ struct MaxPool1DAttrs : public tvm::AttrsNode<MaxPool1DAttrs> {
       .set_default(Array<IndexExpr>({0}))
       .describe(
           "If padding is non-zero, then the input is implicitly zero-padded"
-          "Padding support both symmetric and asymmetric as"
-          "one int : same padding used on all sides"
-          "three int : back, bottom, right will use same padding as front, top, left"
-          "six int : padding width in the order of (front, top, left, back, bottom, right)");
+          "Padding supports both symmetric and asymmetric as"
+          "one int : same padding used on each side"
+          "two int : indicates left padding, right padding");
   TVM_ATTR_FIELD(layout).set_default("NCW").describe(
       "Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
       "'N', 'C', 'W' stands for batch, channel, and width"
@@ -933,10 +932,9 @@ struct AvgPool1DAttrs : public tvm::AttrsNode<AvgPool1DAttrs> {
       .set_default(Array<IndexExpr>({0}))
       .describe(
           "If padding is non-zero, then the input is implicitly zero-padded"
-          "Padding support both symmetric and asymmetric as"
-          "one int : same padding used on all sides"
-          "three int : back, bottom, right will use same padding as front, top, left"
-          "six int : padding width in the order of (front, top, left, back, bottom, right)");
+          "Padding supports both symmetric and asymmetric as"
+          "one int : same padding used on each side"
+          "two int : indicates left padding, right padding");
   TVM_ATTR_FIELD(layout).set_default("NCW").describe(
       "Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
       "'N', 'C', 'W' stands for batch, channel, and width"
```

src/relay/transforms/fold_explicit_padding.cc

Lines changed: 170 additions & 36 deletions
```diff
@@ -22,11 +22,15 @@
  * \brief A pass for folding explicit pads into other ops.
  */
 
+#include <dmlc/optional.h>
 #include <tvm/relay/dataflow_matcher.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
+#include <tvm/runtime/data_type.h>
 #include <tvm/runtime/logging.h>
+#include <tvm/tir/op.h>
+#include <tvm/topi/nn/pooling.h>
 
 #include "../op/tensor/transform.h"
 #include "pattern_utils.h"
```
```diff
@@ -35,46 +39,70 @@ namespace tvm {
 namespace relay {
 
 /*!
- * \brief SimplifyConvPad matches a pad followed by a conv/convtranspose/pool/etc
+ * \brief SimplifyExplicitPad matches a pad followed by a conv/maxpool/avgpool
  * with a pad attribute and merges the padding into the kernel.
  */
-class SimplifyConvPad {
+class SimplifyExplicitPad {
  public:
   DFPattern pattern() const { return pattern_; }
 
-  SimplifyConvPad() {
+  SimplifyExplicitPad() {
     x_ = IsWildcard();
-    w_ = IsWildcard();
     pad_ = IsOp("nn.pad")({x_, IsWildcard()});
+
+    // pad->conv patterns
+    w_ = IsWildcard();
     conv1d_ = IsOp("nn.conv1d");
     conv2d_ = IsOp("nn.conv2d");
     conv3d_ = IsOp("nn.conv3d");
-
-    conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
+    contrib_conv2d_nchwc_ = IsOp("nn.contrib_conv2d_NCHWc");
+    conv_ = (conv1d_ || conv2d_ || conv3d_ || contrib_conv2d_nchwc_)({pad_, w_});
 
     input_zero_point_ = IsWildcard();
     kernel_zero_point_ = IsWildcard();
     input_scale_ = IsWildcard();
     kernel_scale_ = IsWildcard();
-
     qconv2d_ = IsOp("qnn.conv2d")(
         {pad_, w_, input_zero_point_, kernel_zero_point_, input_scale_, kernel_scale_});
 
-    pattern_ = conv_ || qconv2d_;
+    // pad->pool patterns
+    avg_pool1d_ = IsOp("nn.avg_pool1d");
+    avg_pool2d_ = IsOp("nn.avg_pool2d");
+    avg_pool3d_ = IsOp("nn.avg_pool3d");
+    max_pool1d_ = IsOp("nn.max_pool1d");
+    max_pool2d_ = IsOp("nn.max_pool2d");
+    max_pool3d_ = IsOp("nn.max_pool3d");
+    max_pool_ = max_pool1d_ || max_pool2d_ || max_pool3d_;
+    pool_ = (max_pool_ || avg_pool1d_ || avg_pool2d_ || avg_pool3d_)({pad_});
+
+    pattern_ = conv_ || qconv2d_ || pool_;
   }
 
   template <typename T>
-  Attrs MakeConvAttrs(const T* old_attrs, const Array<PrimExpr> padding) const {
-    ICHECK(old_attrs);
+  Array<PrimExpr> get_combined_padding(const T* old_attrs, Array<PrimExpr> padding) const {
     ICHECK(padding.size() == old_attrs->padding.size())
         << "Number of dimensions to pad and convolution padding attributes should have the same "
            "extent";
 
-    auto new_attrs = make_object<T>();
     Array<PrimExpr> combined_padding;
     for (size_t i = 0; i < padding.size(); ++i) {
       combined_padding.push_back(padding[i] + old_attrs->padding[i]);
     }
+    return combined_padding;
+  }
+
+  template <typename T>
+  Attrs MakeConvAttrs(const PadAttrs* param, const T* old_attrs) const {
+    // Creates attrs from old_attrs with fields shared by 1D, 2D, 3D conv attrs
+    ICHECK(old_attrs);
+    ICHECK(param);
+    auto padding = get_padding(param, old_attrs->data_layout);
+    if (!padding) {
+      return Attrs();
+    }
+    auto combined_padding = get_combined_padding(old_attrs, padding.value());
+
+    auto new_attrs = make_object<T>();
     new_attrs->strides = old_attrs->strides;
     new_attrs->padding = combined_padding;
     new_attrs->dilation = old_attrs->dilation;
```
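
For readers more at home on the Python side, the pattern built in the constructor above maps directly onto `tvm.relay.dataflow_pattern`. A rough equivalent of the new pool branch, useful for experimenting with what the pass will match (illustrative only, not part of this commit):

```python
from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

x = wildcard()
pad = is_op("nn.pad")(x, wildcard())
# Pooling ops take a single data argument, so the pattern is pool(pad(x)).
pool = (is_op("nn.max_pool1d") | is_op("nn.max_pool2d") | is_op("nn.max_pool3d"))(pad)

inp = relay.var("inp", shape=(1, 1, 4, 4))
expr = relay.nn.max_pool2d(
    relay.nn.pad(inp, ((0, 0), (0, 0), (1, 1), (1, 1))), pool_size=(2, 2))
assert pool.match(expr)
```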
```diff
@@ -89,22 +117,85 @@ class SimplifyConvPad {
   }
 
   template <typename T>
-  Attrs GetAttrs(const PadAttrs* param, const T* attrs) const {
+  Attrs MakeConv2D3DAttrs(const PadAttrs* param, const T* old_attrs) const {
+    // Propagate additional Conv2D- and Conv3D-specific attrs
+    auto attrs = MakeConvAttrs(param, old_attrs);
+    if (!attrs.defined()) {
+      return Attrs();
+    }
+
+    T* new_attrs = const_cast<T*>(attrs.template as<T>());
+    new_attrs->auto_scheduler_rewritten_layout = old_attrs->auto_scheduler_rewritten_layout;
+    return attrs;
+  }
+
+  template <typename T>
+  Attrs MakePoolAttrs(const PadAttrs* param, const T* old_attrs) const {
+    // Creates attrs from old_attrs with fields shared by 1D, 2D, 3D pool attrs
+    ICHECK(old_attrs);
     ICHECK(param);
-    ICHECK(attrs);
-    ICHECK(attrs->data_layout.size() == param->pad_width.size())
+    auto padding = get_padding(param, old_attrs->layout);
+    if (!padding) {
+      return Attrs();
+    }
+    auto combined_padding = get_combined_padding(old_attrs, padding.value());
+
+    auto new_attrs = make_object<T>();
+    new_attrs->pool_size = old_attrs->pool_size;
+    new_attrs->strides = old_attrs->strides;
+    new_attrs->dilation = old_attrs->dilation;
+    new_attrs->padding = combined_padding;
+    new_attrs->layout = old_attrs->layout;
+    new_attrs->out_layout = old_attrs->out_layout;
+    new_attrs->ceil_mode = old_attrs->ceil_mode;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs MakeAvgPoolAttrs(const PadAttrs* param, const T* old_attrs) const {
+    // Propagate additional AvgPool-specific attrs
+    auto attrs = MakePoolAttrs(param, old_attrs);
+    if (!attrs.defined()) {
+      return attrs;
+    }
+
+    T* new_attrs = const_cast<T*>(attrs.template as<T>());
+    new_attrs->count_include_pad = old_attrs->count_include_pad;
+    if (!new_attrs->count_include_pad) {
+      // AvgPool's divisor doesn't include padding, so don't fold the explicit pad
+      // unless all original pad items are 0.
+      for (IndexExpr pad : old_attrs->padding) {
+        const IntImmNode* maybe_int_imm = pad.as<IntImmNode>();
+        if (!maybe_int_imm || maybe_int_imm->value != 0) {
+          // Return undefined attrs to signal that we don't want to fold explicit pad
+          return Attrs();
+        }
+      }
+      // Turn on `count_include_pad` to preserve original pad first, then pool behavior
+      // where AvgPool's divisor implicitly includes padding.
+      new_attrs->count_include_pad = true;
+    }
+
+    return attrs;
+  }
+
+  static const Optional<Array<PrimExpr>> get_padding(const PadAttrs* param,
+                                                     std::string data_layout) {
+    // Gets spatial axes padding from the given PadAttrs `param`. If padding
+    // is non-zero on non-spatial axes, return NullOpt.
+    ICHECK(param);
+    ICHECK(data_layout.size() == param->pad_width.size())
         << "Data Layout and padding attributes should have the same extent";
 
-    std::string data_layout = attrs->data_layout;
     std::set<char> image_dims({'H', 'W', 'D'});
     Array<PrimExpr> padding;
     // If we're padding a non-spatial dimension, don't simplify
-    // Convolution can only pad on spatial axes
+    // Convolution/Pool can only pad on spatial axes
     for (size_t i = 0; i < param->pad_width.size(); ++i) {
       if (!image_dims.count(data_layout[i])) {
         for (size_t j = 0; j < param->pad_width[i].size(); ++j) {
           if (param->pad_width[i][j] != 0) {
-            return Attrs();
+            return NullOpt;
           }
         }
       }
```
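
The `count_include_pad` handling is the subtle part: an explicit `nn.pad` adds zeros that a subsequent average counts in its divisor, while `avg_pool*d` with `count_include_pad=False` divides only by the number of in-bounds elements. The fold is therefore only correct when `count_include_pad` is flipped to `true` on the folded op, and it is refused outright when the pool already had nonzero padding of its own. A small NumPy check of the two divisor conventions, 1D case:

```python
import numpy as np

x = np.array([2.0, 4.0], dtype="float32")

# Explicit pad, then 2-wide average windows at stride 1:
# the padded zero participates in the average.
padded = np.pad(x, (1, 0))                                # [0, 2, 4]
pad_then_pool = [(padded[i] + padded[i + 1]) / 2 for i in range(2)]
print(pad_then_pool)                                      # [1.0, 3.0]

# avg_pool1d(padding=1, count_include_pad=False) would instead divide the
# first window by 1 (only one in-bounds element): 2 / 1 = 2.0, not 1.0.
# Hence the fold sets count_include_pad=True to reproduce pad-then-pool.
```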
```diff
@@ -116,8 +207,7 @@ class SimplifyConvPad {
         }
       }
     }
-
-    return MakeConvAttrs(attrs, padding);
+    return padding;
   }
 
   Expr callback(const Expr& pre, const Expr& post,
```
```diff
@@ -131,40 +221,75 @@ class SimplifyConvPad {
     ICHECK(param);
 
     auto x = node_map[x_][0];
-    auto w = node_map[w_][0];
 
-    // Possibly perform more optimizations if the pad_value is 0
     const Expr& pv = pad_node->args[1];
     const ConstantNode* pad_value = pv.as<ConstantNode>();
+    auto pad_scalar = ToScalar(pad_value->data);
+
     if (node_map.find(qconv2d_) != node_map.end()) {
-      Attrs attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
+      Attrs attrs = MakeConv2D3DAttrs(param, call_node->attrs.as<Conv2DAttrs>());
+      if (!attrs.defined()) {
+        return post;
+      }
       auto input_zero_point = node_map[input_zero_point_][0];
       auto kernel_zero_point = node_map[kernel_zero_point_][0];
       auto input_scale = node_map[input_scale_][0];
       auto kernel_scale = node_map[kernel_scale_][0];
       // Fold Padding and QNN Convolution only if pad value == input zero point.
       if (IsEqualScalar(input_zero_point, pv)) {
+        auto w = node_map[w_][0];
         return Call(call_node->op,
                     {x, w, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, attrs,
                     call_node->type_args, call_node->span);
-      } else {
-        return post;
       }
-    } else if (param->pad_mode == "constant" && pad_value && ToScalar(pad_value->data) == 0.0) {
+      return post;
+    }
+
+    if (param->pad_mode == "constant" && pad_value) {
       Attrs attrs;
-      if (node_map.count(conv1d_)) {
-        attrs = GetAttrs(param, call_node->attrs.as<Conv1DAttrs>());
-      } else if (node_map.count(conv2d_)) {
-        attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
-      } else if (node_map.count(conv3d_)) {
-        attrs = GetAttrs(param, call_node->attrs.as<Conv3DAttrs>());
-      } else {
-        return post;
+      if (pad_scalar == 0.0) {
+        // Fold Padding and Conv/AvgPool only if pad_value == 0.
+        if (node_map.count(conv_)) {
+          if (node_map.count(conv1d_)) {
+            attrs = MakeConvAttrs(param, call_node->attrs.as<Conv1DAttrs>());
+          } else if (node_map.count(conv2d_)) {
+            attrs = MakeConv2D3DAttrs(param, call_node->attrs.as<Conv2DAttrs>());
+          } else if (node_map.count(conv3d_)) {
+            attrs = MakeConv2D3DAttrs(param, call_node->attrs.as<Conv3DAttrs>());
+          }
+          if (!attrs.defined()) {
+            return post;
+          }
+          auto w = node_map[w_][0];
+          return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
+        } else if (node_map.count(avg_pool1d_)) {
+          attrs = MakeAvgPoolAttrs(param, call_node->attrs.as<AvgPool1DAttrs>());
+        } else if (node_map.count(avg_pool2d_)) {
+          attrs = MakeAvgPoolAttrs(param, call_node->attrs.as<AvgPool2DAttrs>());
+        } else if (node_map.count(avg_pool3d_)) {
+          attrs = MakeAvgPoolAttrs(param, call_node->attrs.as<AvgPool3DAttrs>());
+        }
+      } else if (node_map.count(max_pool_)) {
+        // Fold Padding and MaxPool only if pad_value is the min possible value for the dtype
+        auto min_value = tvm::min_value(tvm::runtime::DataType(pad_value->data->dtype));
+        const FloatImmNode* maybe_min_float = min_value.as<FloatImmNode>();
+        const IntImmNode* maybe_min_int = min_value.as<IntImmNode>();
+
+        if ((maybe_min_float && pad_scalar == maybe_min_float->value) ||
+            (maybe_min_int && pad_scalar == maybe_min_int->value)) {
+          if (node_map.count(max_pool1d_)) {
+            attrs = MakePoolAttrs(param, call_node->attrs.as<MaxPool1DAttrs>());
+          } else if (node_map.count(max_pool2d_)) {
+            attrs = MakePoolAttrs(param, call_node->attrs.as<MaxPool2DAttrs>());
+          } else if (node_map.count(max_pool3d_)) {
+            attrs = MakePoolAttrs(param, call_node->attrs.as<MaxPool3DAttrs>());
+          }
+        }
       }
       if (!attrs.defined()) {
         return post;
       }
-      return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
+      return Call(call_node->op, {x}, attrs, call_node->type_args, call_node->span);
     }
     return post;
   }
```
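
For max pooling the identity element is the dtype's minimum, so the fold is only safe when the explicit pad used exactly that value; anything larger could leak into the output. A NumPy sketch of the equivalence being relied on:

```python
import numpy as np

x = np.array([-5.0, -7.0], dtype="float32")
lowest = np.finfo("float32").min  # assumed to agree with tvm::min_value(float32)

# Pad on the left with the minimum, then take 2-wide max windows at stride 1.
padded = np.concatenate(([lowest], x))                  # [min, -5, -7]
out = [max(padded[i], padded[i + 1]) for i in range(2)]
print(out)                                              # [-5.0, -5.0]; the pad never wins

# Padding with 0 instead would give [0.0, -5.0], which is why a zero pad can
# be folded into conv/avg_pool but not into max_pool.
```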
```diff
@@ -183,18 +308,27 @@ class SimplifyConvPad {
   DFPattern conv1d_;
   DFPattern conv2d_;
   DFPattern conv3d_;
+  DFPattern contrib_conv2d_nchwc_;
   DFPattern qconv2d_;
   DFPattern input_zero_point_;
   DFPattern kernel_zero_point_;
   DFPattern input_scale_;
   DFPattern kernel_scale_;
+  /*! \brief Pattern pool */
+  DFPattern pool_;
+  DFPattern avg_pool1d_;
+  DFPattern avg_pool2d_;
+  DFPattern avg_pool3d_;
+  DFPattern max_pool1d_;
+  DFPattern max_pool2d_;
+  DFPattern max_pool3d_;
+  DFPattern max_pool_;
 };
 
 class SimplifyExplicitPadding {
  public:
   explicit SimplifyExplicitPadding(IRModule mod) : mod_(mod) {
-    CreateCallback(SimplifyConvPad());
-    // TODO(mbrookhart): ConvTranspose(Pad(x)), Pool(Pad(x))
+    CreateCallback(SimplifyExplicitPad());
   }
   template <typename T>
   void CreateCallback(const T& pattern) {
```
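
`CreateCallback` registers the pattern together with `SimplifyExplicitPad::callback` in TVM's rewriter. The same machinery is exposed in Python as `DFPatternCallback`; the hypothetical sketch below mimics the simplest conv case (NCHW `nn.conv2d`, no pad-value or layout checks) just to show the `pre`/`post`/`node_map` flow:

```python
from tvm import relay
from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, rewrite, wildcard

class FoldPadIntoConv2d(DFPatternCallback):
    """Hypothetical, simplified analogue of SimplifyExplicitPad."""

    def __init__(self):
        super().__init__()
        self.x = wildcard()
        self.w = wildcard()
        self.pad = is_op("nn.pad")(self.x, wildcard())
        self.pattern = is_op("nn.conv2d")(self.pad, self.w)

    def callback(self, pre, post, node_map):
        pad = node_map[self.pad][0]
        pw = [[int(a), int(b)] for a, b in pad.attrs.pad_width]
        old = [int(p) for p in post.attrs.padding]  # (top, left, bottom, right)
        # Merge the H/W pad widths (NCHW axes 2 and 3) into the conv's padding.
        merged = [old[0] + pw[2][0], old[1] + pw[3][0],
                  old[2] + pw[2][1], old[3] + pw[3][1]]
        return relay.nn.conv2d(node_map[self.x][0], node_map[self.w][0],
                               padding=merged, strides=post.attrs.strides,
                               dilation=post.attrs.dilation)

# Usage: expr = rewrite(FoldPadIntoConv2d(), expr)
```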
