2222 * \brief A pass for folding explicit pads into other ops.
2323 */
2424
25+ #include < dmlc/optional.h>
2526#include < tvm/relay/dataflow_matcher.h>
2627#include < tvm/relay/expr.h>
2728#include < tvm/relay/expr_functor.h>
2829#include < tvm/relay/transform.h>
30+ #include < tvm/runtime/data_type.h>
2931#include < tvm/runtime/logging.h>
32+ #include < tvm/tir/op.h>
33+ #include < tvm/topi/nn/pooling.h>
3034
3135#include " ../op/tensor/transform.h"
3236#include " pattern_utils.h"
@@ -35,46 +39,70 @@ namespace tvm {
3539namespace relay {
3640
3741/* !
38- * \brief SimplifyConvPad matches a pad followed by a conv/convtranspose/pool/etc
42+ * \brief SimplifyExplicitPad matches a pad followed by a conv/maxpool/avgpool
3943 * with a pad attribute and merges the padding into the kernel.
4044 */
41- class SimplifyConvPad {
45+ class SimplifyExplicitPad {
4246 public:
4347 DFPattern pattern () const { return pattern_; }
4448
45- SimplifyConvPad () {
49+ SimplifyExplicitPad () {
4650 x_ = IsWildcard ();
47- w_ = IsWildcard ();
4851 pad_ = IsOp (" nn.pad" )({x_, IsWildcard ()});
52+
53+ // pad->conv patterns
54+ w_ = IsWildcard ();
4955 conv1d_ = IsOp (" nn.conv1d" );
5056 conv2d_ = IsOp (" nn.conv2d" );
5157 conv3d_ = IsOp (" nn.conv3d" );
52-
53- conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
58+ contrib_conv2d_nchwc_ = IsOp ( " nn.contrib_conv2d_NCHWc " );
59+ conv_ = (conv1d_ || conv2d_ || conv3d_ || contrib_conv2d_nchwc_ )({pad_, w_});
5460
5561 input_zero_point_ = IsWildcard ();
5662 kernel_zero_point_ = IsWildcard ();
5763 input_scale_ = IsWildcard ();
5864 kernel_scale_ = IsWildcard ();
59-
6065 qconv2d_ = IsOp (" qnn.conv2d" )(
6166 {pad_, w_, input_zero_point_, kernel_zero_point_, input_scale_, kernel_scale_});
6267
63- pattern_ = conv_ || qconv2d_;
68+ // pad->pool patterns
69+ avg_pool1d_ = IsOp (" nn.avg_pool1d" );
70+ avg_pool2d_ = IsOp (" nn.avg_pool2d" );
71+ avg_pool3d_ = IsOp (" nn.avg_pool3d" );
72+ max_pool1d_ = IsOp (" nn.max_pool1d" );
73+ max_pool2d_ = IsOp (" nn.max_pool2d" );
74+ max_pool3d_ = IsOp (" nn.max_pool3d" );
75+ max_pool_ = max_pool1d_ || max_pool2d_ || max_pool3d_;
76+ pool_ = (max_pool_ || avg_pool1d_ || avg_pool2d_ || avg_pool3d_)({pad_});
77+
78+ pattern_ = conv_ || qconv2d_ || pool_;
6479 }
6580
6681 template <typename T>
67- Attrs MakeConvAttrs (const T* old_attrs, const Array<PrimExpr> padding) const {
68- ICHECK (old_attrs);
82+ Array<PrimExpr> get_combined_padding (const T* old_attrs, Array<PrimExpr> padding) const {
6983 ICHECK (padding.size () == old_attrs->padding .size ())
7084 << " Number of dimensions to pad and convolution padding attributes should have the same "
7185 " extent" ;
7286
73- auto new_attrs = make_object<T>();
7487 Array<PrimExpr> combined_padding;
7588 for (size_t i = 0 ; i < padding.size (); ++i) {
7689 combined_padding.push_back (padding[i] + old_attrs->padding [i]);
7790 }
91+ return combined_padding;
92+ }
93+
94+ template <typename T>
95+ Attrs MakeConvAttrs (const PadAttrs* param, const T* old_attrs) const {
96+ // Creates attrs from old_attrs with fields shared by 1D, 2D, 3D conv attrs
97+ ICHECK (old_attrs);
98+ ICHECK (param);
99+ auto padding = get_padding (param, old_attrs->data_layout );
100+ if (!padding) {
101+ return Attrs ();
102+ }
103+ auto combined_padding = get_combined_padding (old_attrs, padding.value ());
104+
105+ auto new_attrs = make_object<T>();
78106 new_attrs->strides = old_attrs->strides ;
79107 new_attrs->padding = combined_padding;
80108 new_attrs->dilation = old_attrs->dilation ;
@@ -89,22 +117,85 @@ class SimplifyConvPad {
89117 }
90118
91119 template <typename T>
92- Attrs GetAttrs (const PadAttrs* param, const T* attrs) const {
120+ Attrs MakeConv2D3DAttrs (const PadAttrs* param, const T* old_attrs) const {
121+ // Propagate additional Conv2D- and Conv3D-specific attrs
122+ auto attrs = MakeConvAttrs (param, old_attrs);
123+ if (!attrs.defined ()) {
124+ return Attrs ();
125+ }
126+
127+ T* new_attrs = const_cast <T*>(attrs.template as <T>());
128+ new_attrs->auto_scheduler_rewritten_layout = old_attrs->auto_scheduler_rewritten_layout ;
129+ return attrs;
130+ }
131+
132+ template <typename T>
133+ Attrs MakePoolAttrs (const PadAttrs* param, const T* old_attrs) const {
134+ // Creates attrs from old_attrs with fields shared by 1D, 2D, 3D pool attrs
135+ ICHECK (old_attrs);
93136 ICHECK (param);
94- ICHECK (attrs);
95- ICHECK (attrs->data_layout .size () == param->pad_width .size ())
137+ auto padding = get_padding (param, old_attrs->layout );
138+ if (!padding) {
139+ return Attrs ();
140+ }
141+ auto combined_padding = get_combined_padding (old_attrs, padding.value ());
142+
143+ auto new_attrs = make_object<T>();
144+ new_attrs->pool_size = old_attrs->pool_size ;
145+ new_attrs->strides = old_attrs->strides ;
146+ new_attrs->dilation = old_attrs->dilation ;
147+ new_attrs->padding = combined_padding;
148+ new_attrs->layout = old_attrs->layout ;
149+ new_attrs->out_layout = old_attrs->out_layout ;
150+ new_attrs->ceil_mode = old_attrs->ceil_mode ;
151+ return Attrs (new_attrs);
152+ }
153+
154+ template <typename T>
155+ Attrs MakeAvgPoolAttrs (const PadAttrs* param, const T* old_attrs) const {
156+ // Propagate additional AvgPool-specific attrs
157+ auto attrs = MakePoolAttrs (param, old_attrs);
158+ if (!attrs.defined ()) {
159+ return attrs;
160+ }
161+
162+ T* new_attrs = const_cast <T*>(attrs.template as <T>());
163+ new_attrs->count_include_pad = old_attrs->count_include_pad ;
164+ if (!new_attrs->count_include_pad ) {
165+ // AvgPool's divisor doesn't include padding, so don't fold the explicit pad
166+ // unless all original pad items are 0.
167+ for (IndexExpr pad : old_attrs->padding ) {
168+ const IntImmNode* maybe_int_imm = pad.as <IntImmNode>();
169+ if (!maybe_int_imm || maybe_int_imm->value != 0 ) {
170+ // Return undefined attrs to signal that we don't want to fold explicit pad
171+ return Attrs ();
172+ }
173+ }
174+ // Turn on `count_include_pad` to preserve original pad first, then pool behavior
175+ // where AvgPool's divisor implicitly includes padding.
176+ new_attrs->count_include_pad = true ;
177+ }
178+
179+ return attrs;
180+ }
181+
182+ static const Optional<Array<PrimExpr>> get_padding (const PadAttrs* param,
183+ std::string data_layout) {
184+ // Gets spatial axes padding from the given PadAttrs `param`. If padding
185+ // is non-zero on non-spatial axes, return NullOpt.
186+ ICHECK (param);
187+ ICHECK (data_layout.size () == param->pad_width .size ())
96188 << " Data Layout and padding attributes should have the same extent" ;
97189
98- std::string data_layout = attrs->data_layout ;
99190 std::set<char > image_dims ({' H' , ' W' , ' D' });
100191 Array<PrimExpr> padding;
101192 // If we're padding a non-spatial dimension, don't simplify
102- // Convolution can only pad on spatial axes
193+ // Convolution/Pool can only pad on spatial axes
103194 for (size_t i = 0 ; i < param->pad_width .size (); ++i) {
104195 if (!image_dims.count (data_layout[i])) {
105196 for (size_t j = 0 ; j < param->pad_width [i].size (); ++j) {
106197 if (param->pad_width [i][j] != 0 ) {
107- return Attrs () ;
198+ return NullOpt ;
108199 }
109200 }
110201 }
@@ -116,8 +207,7 @@ class SimplifyConvPad {
116207 }
117208 }
118209 }
119-
120- return MakeConvAttrs (attrs, padding);
210+ return padding;
121211 }
122212
123213 Expr callback (const Expr& pre , const Expr& post ,
@@ -131,40 +221,75 @@ class SimplifyConvPad {
131221 ICHECK (param);
132222
133223 auto x = node_map[x_][0 ];
134- auto w = node_map[w_][0 ];
135224
136- // Possibly perform more optimizations if the pad_value is 0
137225 const Expr& pv = pad_node->args [1 ];
138226 const ConstantNode* pad_value = pv.as <ConstantNode>();
227+ auto pad_scalar = ToScalar (pad_value->data );
228+
139229 if (node_map.find (qconv2d_) != node_map.end ()) {
140- Attrs attrs = GetAttrs (param, call_node->attrs .as <Conv2DAttrs>());
230+ Attrs attrs = MakeConv2D3DAttrs (param, call_node->attrs .as <Conv2DAttrs>());
231+ if (!attrs.defined ()) {
232+ return post ;
233+ }
141234 auto input_zero_point = node_map[input_zero_point_][0 ];
142235 auto kernel_zero_point = node_map[kernel_zero_point_][0 ];
143236 auto input_scale = node_map[input_scale_][0 ];
144237 auto kernel_scale = node_map[kernel_scale_][0 ];
145238 // Fold Padding and QNN Convolution only if pad value == input zero point.
146239 if (IsEqualScalar (input_zero_point, pv)) {
240+ auto w = node_map[w_][0 ];
147241 return Call (call_node->op ,
148242 {x, w, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, attrs,
149243 call_node->type_args , call_node->span );
150- } else {
151- return post ;
152244 }
153- } else if (param->pad_mode == " constant" && pad_value && ToScalar (pad_value->data ) == 0.0 ) {
245+ return post ;
246+ }
247+
248+ if (param->pad_mode == " constant" && pad_value) {
154249 Attrs attrs;
155- if (node_map.count (conv1d_)) {
156- attrs = GetAttrs (param, call_node->attrs .as <Conv1DAttrs>());
157- } else if (node_map.count (conv2d_)) {
158- attrs = GetAttrs (param, call_node->attrs .as <Conv2DAttrs>());
159- } else if (node_map.count (conv3d_)) {
160- attrs = GetAttrs (param, call_node->attrs .as <Conv3DAttrs>());
161- } else {
162- return post ;
250+ if (pad_scalar == 0.0 ) {
251+ // Fold Padding and Conv/AvgPool only if pad_value == 0.
252+ if (node_map.count (conv_)) {
253+ if (node_map.count (conv1d_)) {
254+ attrs = MakeConvAttrs (param, call_node->attrs .as <Conv1DAttrs>());
255+ } else if (node_map.count (conv2d_)) {
256+ attrs = MakeConv2D3DAttrs (param, call_node->attrs .as <Conv2DAttrs>());
257+ } else if (node_map.count (conv3d_)) {
258+ attrs = MakeConv2D3DAttrs (param, call_node->attrs .as <Conv3DAttrs>());
259+ }
260+ if (!attrs.defined ()) {
261+ return post ;
262+ }
263+ auto w = node_map[w_][0 ];
264+ return Call (call_node->op , {x, w}, attrs, call_node->type_args , call_node->span );
265+ } else if (node_map.count (avg_pool1d_)) {
266+ attrs = MakeAvgPoolAttrs (param, call_node->attrs .as <AvgPool1DAttrs>());
267+ } else if (node_map.count (avg_pool2d_)) {
268+ attrs = MakeAvgPoolAttrs (param, call_node->attrs .as <AvgPool2DAttrs>());
269+ } else if (node_map.count (avg_pool3d_)) {
270+ attrs = MakeAvgPoolAttrs (param, call_node->attrs .as <AvgPool3DAttrs>());
271+ }
272+ } else if (node_map.count (max_pool_)) {
273+ // Fold Padding and MaxPool only if pad_value is the min possible value for the dtype
274+ auto min_value = tvm::min_value (tvm::runtime::DataType (pad_value->data ->dtype ));
275+ const FloatImmNode* maybe_min_float = min_value.as <FloatImmNode>();
276+ const IntImmNode* maybe_min_int = min_value.as <IntImmNode>();
277+
278+ if ((maybe_min_float && pad_scalar == maybe_min_float->value ) ||
279+ (maybe_min_int && pad_scalar == maybe_min_int->value )) {
280+ if (node_map.count (max_pool1d_)) {
281+ attrs = MakePoolAttrs (param, call_node->attrs .as <MaxPool1DAttrs>());
282+ } else if (node_map.count (max_pool2d_)) {
283+ attrs = MakePoolAttrs (param, call_node->attrs .as <MaxPool2DAttrs>());
284+ } else if (node_map.count (max_pool3d_)) {
285+ attrs = MakePoolAttrs (param, call_node->attrs .as <MaxPool3DAttrs>());
286+ }
287+ }
163288 }
164289 if (!attrs.defined ()) {
165290 return post ;
166291 }
167- return Call (call_node->op , {x, w }, attrs, call_node->type_args , call_node->span );
292+ return Call (call_node->op , {x}, attrs, call_node->type_args , call_node->span );
168293 }
169294 return post ;
170295 }
@@ -183,18 +308,27 @@ class SimplifyConvPad {
183308 DFPattern conv1d_;
184309 DFPattern conv2d_;
185310 DFPattern conv3d_;
311+ DFPattern contrib_conv2d_nchwc_;
186312 DFPattern qconv2d_;
187313 DFPattern input_zero_point_;
188314 DFPattern kernel_zero_point_;
189315 DFPattern input_scale_;
190316 DFPattern kernel_scale_;
317+ /* ! \brief Pattern pool */
318+ DFPattern pool_;
319+ DFPattern avg_pool1d_;
320+ DFPattern avg_pool2d_;
321+ DFPattern avg_pool3d_;
322+ DFPattern max_pool1d_;
323+ DFPattern max_pool2d_;
324+ DFPattern max_pool3d_;
325+ DFPattern max_pool_;
191326};
192327
193328class SimplifyExplicitPadding {
194329 public:
195330 explicit SimplifyExplicitPadding (IRModule mod) : mod_(mod) {
196- CreateCallback (SimplifyConvPad ());
197- // TODO(mbrookhart): ConvTranspose(Pad(x)), Pool(Pad(x))
331+ CreateCallback (SimplifyExplicitPad ());
198332 }
199333 template <typename T>
200334 void CreateCallback (const T& pattern) {
0 commit comments