@@ -49,9 +49,18 @@ class SimplifyConvPad {
4949 conv1d_ = IsOp (" nn.conv1d" );
5050 conv2d_ = IsOp (" nn.conv2d" );
5151 conv3d_ = IsOp (" nn.conv3d" );
52- qconv2d_ = IsOp (" qnn.conv2d" );
53- conv_ = (conv1d_ || conv2d_ || conv3d_ || qconv2d_)({pad_, w_});
54- pattern_ = conv_;
52+
53+ conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
54+
55+ input_zero_point_ = IsWildcard ();
56+ kernel_zero_point_ = IsWildcard ();
57+ input_scale_ = IsWildcard ();
58+ kernel_scale_ = IsWildcard ();
59+
60+ qconv2d_ = IsOp (" qnn.conv2d" )(
61+ {pad_, w_, input_zero_point_, kernel_zero_point_, input_scale_, kernel_scale_});
62+
63+ pattern_ = conv_ || qconv2d_;
5564 }
5665
5766 template <typename T>
@@ -122,26 +131,34 @@ class SimplifyConvPad {
122131 ICHECK (param);
123132 Array<Expr> args = pad_node->args ;
124133
134+ auto x = node_map[x_][0 ];
135+ auto w = node_map[w_][0 ];
136+
125137 // Possibly perform more optimizations if the pad_value is 0
126138 const ConstantNode* pad_value = args[1 ].as <ConstantNode>();
127- if (param->pad_mode == " constant" && pad_value && ToScalar (pad_value->data ) == 0.0 ) {
139+ if (node_map.find (qconv2d_) != node_map.end ()) {
140+ Attrs attrs = GetAttrs (param, call_node->attrs .as <Conv2DAttrs>());
141+ auto input_zero_point = node_map[input_zero_point_][0 ];
142+ auto kernel_zero_point = node_map[kernel_zero_point_][0 ];
143+ auto input_scale = node_map[input_scale_][0 ];
144+ auto kernel_scale = node_map[kernel_scale_][0 ];
145+ return Call (call_node->op ,
146+ {x, w, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, attrs,
147+ call_node->type_args , call_node->span );
148+ } else if (param->pad_mode == " constant" && pad_value && ToScalar (pad_value->data ) == 0.0 ) {
128149 Attrs attrs;
129150 if (node_map.count (conv1d_)) {
130151 attrs = GetAttrs (param, call_node->attrs .as <Conv1DAttrs>());
131152 } else if (node_map.count (conv2d_)) {
132153 attrs = GetAttrs (param, call_node->attrs .as <Conv2DAttrs>());
133154 } else if (node_map.count (conv3d_)) {
134155 attrs = GetAttrs (param, call_node->attrs .as <Conv3DAttrs>());
135- } else if (node_map.count (qconv2d_)) {
136- attrs = GetAttrs (param, call_node->attrs .as <Conv2DAttrs>());
137156 } else {
138157 return post ;
139158 }
140159 if (!attrs.defined ()) {
141160 return post ;
142161 }
143- auto x = node_map[x_][0 ];
144- auto w = node_map[w_][0 ];
145162 return Call (call_node->op , {x, w}, attrs, call_node->type_args , call_node->span );
146163 }
147164 return post ;
@@ -162,6 +179,10 @@ class SimplifyConvPad {
162179 DFPattern conv2d_;
163180 DFPattern conv3d_;
164181 DFPattern qconv2d_;
182+ DFPattern input_zero_point_;
183+ DFPattern kernel_zero_point_;
184+ DFPattern input_scale_;
185+ DFPattern kernel_scale_;
165186};
166187
167188class SimplifyExplicitPadding {
0 commit comments