Skip to content

Commit 7086bdb

Browse files
committed
works
1 parent db34397 commit 7086bdb

File tree

1 file changed

+29
-8
lines changed

1 file changed

+29
-8
lines changed

src/relay/transforms/fold_explicit_padding.cc

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,18 @@ class SimplifyConvPad {
4949
conv1d_ = IsOp("nn.conv1d");
5050
conv2d_ = IsOp("nn.conv2d");
5151
conv3d_ = IsOp("nn.conv3d");
52-
qconv2d_ = IsOp("qnn.conv2d");
53-
conv_ = (conv1d_ || conv2d_ || conv3d_ || qconv2d_)({pad_, w_});
54-
pattern_ = conv_;
52+
53+
conv_ = (conv1d_ || conv2d_ || conv3d_)({pad_, w_});
54+
55+
input_zero_point_ = IsWildcard();
56+
kernel_zero_point_ = IsWildcard();
57+
input_scale_ = IsWildcard();
58+
kernel_scale_ = IsWildcard();
59+
60+
qconv2d_ = IsOp("qnn.conv2d")(
61+
{pad_, w_, input_zero_point_, kernel_zero_point_, input_scale_, kernel_scale_});
62+
63+
pattern_ = conv_ || qconv2d_;
5564
}
5665

5766
template <typename T>
@@ -122,26 +131,34 @@ class SimplifyConvPad {
122131
ICHECK(param);
123132
Array<Expr> args = pad_node->args;
124133

134+
auto x = node_map[x_][0];
135+
auto w = node_map[w_][0];
136+
125137
// Possibly perform more optimizations if the pad_value is 0
126138
const ConstantNode* pad_value = args[1].as<ConstantNode>();
127-
if (param->pad_mode == "constant" && pad_value && ToScalar(pad_value->data) == 0.0) {
139+
if (node_map.find(qconv2d_) != node_map.end()) {
140+
Attrs attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
141+
auto input_zero_point = node_map[input_zero_point_][0];
142+
auto kernel_zero_point = node_map[kernel_zero_point_][0];
143+
auto input_scale = node_map[input_scale_][0];
144+
auto kernel_scale = node_map[kernel_scale_][0];
145+
return Call(call_node->op,
146+
{x, w, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, attrs,
147+
call_node->type_args, call_node->span);
148+
} else if (param->pad_mode == "constant" && pad_value && ToScalar(pad_value->data) == 0.0) {
128149
Attrs attrs;
129150
if (node_map.count(conv1d_)) {
130151
attrs = GetAttrs(param, call_node->attrs.as<Conv1DAttrs>());
131152
} else if (node_map.count(conv2d_)) {
132153
attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
133154
} else if (node_map.count(conv3d_)) {
134155
attrs = GetAttrs(param, call_node->attrs.as<Conv3DAttrs>());
135-
} else if (node_map.count(qconv2d_)) {
136-
attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
137156
} else {
138157
return post;
139158
}
140159
if (!attrs.defined()) {
141160
return post;
142161
}
143-
auto x = node_map[x_][0];
144-
auto w = node_map[w_][0];
145162
return Call(call_node->op, {x, w}, attrs, call_node->type_args, call_node->span);
146163
}
147164
return post;
@@ -162,6 +179,10 @@ class SimplifyConvPad {
162179
DFPattern conv2d_;
163180
DFPattern conv3d_;
164181
DFPattern qconv2d_;
182+
DFPattern input_zero_point_;
183+
DFPattern kernel_zero_point_;
184+
DFPattern input_scale_;
185+
DFPattern kernel_scale_;
165186
};
166187

167188
class SimplifyExplicitPadding {

0 commit comments

Comments
 (0)