
Commit 8b583d8

Add FlattenAtrousConv transformation
1 parent 9ca2139 commit 8b583d8

File tree

4 files changed: +655 -0 lines changed

python/tvm/relay/transform/transform.py

Lines changed: 27 additions & 0 deletions
@@ -1293,6 +1293,33 @@ def FakeQuantizationToInteger(hard_fail=False, use_qat=False):
    return _ffi_api.FakeQuantizationToInteger(hard_fail, use_qat)


def FlattenAtrousConv():
    # pylint: disable=anomalous-backslash-in-string
    """
    The purpose of this pass is to find a sequence of space_to_batch_nd-conv2d-batch_to_space_nd
    operations:

    .. code-block:: text

          x     w
          |     |
         s2b    |
           \\   /
           conv2d
             |
            b2s

    and convert them into subgraphs with a convolution with the modified "dilation" and
    recalculated "padding" parameters.

    Returns
    -------
    ret : tvm.transform.Pass
        The registered FlattenAtrousConv pass.
    """
    return _ffi_api.FlattenAtrousConv()


def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1):
    """
    Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version
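For context, applying the new pass from Python looks like this. A minimal sketch, assuming illustrative NHWC shapes and TF-style space_to_batch_nd parameters (the shapes and values are not part of the commit):

import tvm
from tvm import relay

# An NHWC input and HWIO kernel; the pass checks data_layout == "NHWC".
x = relay.var("x", shape=(1, 8, 8, 3), dtype="float32")
w = relay.var("w", shape=(3, 3, 3, 2), dtype="float32")

# The exact sequence this pass looks for, as typically emitted for atrous convolution.
s2b = relay.nn.space_to_batch_nd(x, block_shape=[2, 2], paddings=[[2, 2], [2, 2]])
conv = relay.nn.conv2d(s2b, w, kernel_size=(3, 3), channels=2,
                       data_layout="NHWC", kernel_layout="HWIO")
b2s = relay.nn.batch_to_space_nd(conv, block_shape=[2, 2], crops=[[0, 0], [0, 0]])

mod = tvm.IRModule.from_expr(relay.Function([x, w], b2s))
mod = relay.transform.FlattenAtrousConv()(mod)  # InferType runs first as a declared requirement
print(mod)  # expect a single nn.conv2d with dilation=(2, 2) and recomputed padding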

src/relay/qnn/utils.h

Lines changed: 6 additions & 0 deletions
@@ -268,6 +268,12 @@ static inline std::vector<float> GetFloatVectorFromConstant(const Expr& expr) {
  return vals;
}

Expr MakeQnnConv2D(Expr data, Expr weight, Expr input_zero_point, Expr kernel_zero_point,
                   Expr input_scale, Expr kernel_scale, Array<IndexExpr> strides,
                   Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
                   IndexExpr channels, Array<IndexExpr> kernel_size, String data_layout,
                   String kernel_layout, String out_layout, DataType out_dtype);

}  // namespace qnn
}  // namespace relay
}  // namespace tvm
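On the Python side, the quantized branch rebuilt through MakeQnnConv2D corresponds to relay.qnn.op.conv2d. A sketch with assumed scales and zero points (illustrative constants, not from the commit):

from tvm import relay

data = relay.var("data", shape=(1, 8, 8, 3), dtype="int8")
weight = relay.var("weight", shape=(3, 3, 3, 2), dtype="int8")
conv = relay.qnn.op.conv2d(
    data, weight,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),
    input_scale=relay.const(0.5, "float32"),
    kernel_scale=relay.const(0.5, "float32"),
    kernel_size=(3, 3), channels=2,
    dilation=(2, 2),  # what the pass writes in place of the s2b/b2s pair
    data_layout="NHWC", kernel_layout="HWIO",
    out_dtype="int32",
)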
src/relay/transforms/flatten_atrous_conv.cc

Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/transforms/flatten_atrous_conv.cc
 * \brief This transform flattens atrous convolution, which corresponds to the sequence of
 * operations: "space_to_batch_nd"->"conv2d"->"batch_to_space_nd".
 */

#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/qnn/attrs.h>
#include <tvm/relay/transform.h>
#include <tvm/topi/broadcast.h>

#include <array>
#include <set>
#include <unordered_map>

#include "../qnn/utils.h"
#include "pattern_utils.h"

namespace tvm {
namespace relay {

/* Description of FlattenAtrousConv
 *
 * The purpose of this pass is to find a sequence of space_to_batch_nd-conv2d-batch_to_space_nd
 * operations:
 *
 *       x     w
 *       |     |
 *      s2b    |
 *        \   /
 *        conv2d
 *          |
 *         b2s
 *
 * and convert them into subgraphs with a convolution with the modified "dilation" and
 * recalculated "padding" parameters.
 */

using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;

class FlattenAtrousConvSubgraphMutator {
 public:
  Expr MutateSubgraph(const Expr& expr) {
    try {
      // Validate each node before dereferencing it.
      const CallNode* b2s_node_ = expr.as<CallNode>();
      ICHECK(b2s_node_ != nullptr);
      const CallNode* conv2d_node_ = b2s_node_->args[0].as<CallNode>();
      ICHECK(conv2d_node_ != nullptr);
      const CallNode* s2b_node_ = conv2d_node_->args[0].as<CallNode>();
      ICHECK(s2b_node_ != nullptr);

      const auto* b2s_attrs = b2s_node_->attrs.as<BatchToSpaceNDAttrs>();
      ICHECK(b2s_attrs != nullptr);

      // The block shape of batch_to_space_nd becomes the dilation of the flattened conv2d.
      Array<PrimExpr> dilation = {b2s_attrs->block_shape[0], b2s_attrs->block_shape[1]};

      const auto* conv2d_attrs = conv2d_node_->attrs.as<Conv2DAttrs>();
      ICHECK(conv2d_attrs != nullptr);

      Array<PrimExpr> kernel_shape = conv2d_attrs->kernel_size;
      PrimExpr kernel_h = kernel_shape[0];
      PrimExpr kernel_w = kernel_shape[1];

      const auto* s2b_attrs = s2b_node_->attrs.as<SpaceToBatchNDAttrs>();
      ICHECK(s2b_attrs != nullptr);

      // The flattened convolution reads directly from the input of space_to_batch_nd.
      Expr data = s2b_node_->args[0];
      ICHECK(conv2d_attrs->data_layout == "NHWC");
      Array<PrimExpr> data_shape = transform::InferTypeLocal(data).as<TensorTypeNode>()->shape;
      PrimExpr in_h = data_shape[1];
      PrimExpr in_w = data_shape[2];

      PrimExpr dilation_h = dilation[0];
      PrimExpr dilation_w = dilation[1];

      PrimExpr dilated_kernel_h = (kernel_h - 1) * dilation_h + 1;
      PrimExpr dilated_kernel_w = (kernel_w - 1) * dilation_w + 1;

      Array<PrimExpr> strides = {1, 1};
      PrimExpr stride_h = strides[0];
      PrimExpr stride_w = strides[1];

      // Recompute "SAME" padding for the dilated kernel; the padding is split
      // between before/after, with the extra pixel going after.
      auto _get_pad_pair = [](PrimExpr input1d, PrimExpr kernel1d,
                              PrimExpr stride1d) -> Array<PrimExpr> {
        PrimExpr out1d = truncdiv((input1d + stride1d - 1), stride1d);
        PrimExpr pad = topi::maximum(((out1d - 1) * stride1d + kernel1d - input1d), 0);
        PrimExpr pad_before = truncdiv(pad, 2);
        PrimExpr pad_after = pad - pad_before;
        return {pad_before, pad_after};
      };

      Array<PrimExpr> pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h);
      Array<PrimExpr> pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w);

      Array<IndexExpr> padding = {pad_v[0], pad_h[0], pad_v[1], pad_h[1]};

      Expr weight = conv2d_node_->args[1];

      if (conv2d_node_->op == Op::Get("nn.conv2d")) {
        return Conv2D(data, weight, strides, padding, dilation, conv2d_attrs->groups,
                      conv2d_attrs->channels, conv2d_attrs->kernel_size, conv2d_attrs->data_layout,
                      conv2d_attrs->kernel_layout, conv2d_attrs->out_layout,
                      conv2d_attrs->out_dtype);
      }

      if (conv2d_node_->op == Op::Get("qnn.conv2d")) {
        Expr input_zero_point = conv2d_node_->args[2];
        Expr kernel_zero_point = conv2d_node_->args[3];
        Expr input_scale = conv2d_node_->args[4];
        Expr kernel_scale = conv2d_node_->args[5];
        return qnn::MakeQnnConv2D(data, weight, input_zero_point, kernel_zero_point, input_scale,
                                  kernel_scale, strides, padding, dilation, conv2d_attrs->groups,
                                  conv2d_attrs->channels, conv2d_attrs->kernel_size,
                                  conv2d_attrs->data_layout, conv2d_attrs->kernel_layout,
                                  conv2d_attrs->out_layout, conv2d_attrs->out_dtype);
      }

      DLOG(INFO) << "Ran into an unhandled convolution, skipping " << expr << std::endl;
      return expr;
    } catch (std::exception& e) {
      DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping " << expr << " with "
                 << e.what() << std::endl;
      return expr;
    }
  }
};

class FlattenAtrousConvRewriter : public MixedModeMutator {
 protected:
  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
    if (const CallNode* call_node = post.as<CallNode>()) {
      if (ops_[op_iter_].count(call_node->op)) {
        ++op_iter_;
        if (op_iter_ == ops_.size()) {
          op_iter_ = 0;
          return FlattenAtrousConvSubgraphMutator().MutateSubgraph(post);
        }
      } else {
        op_iter_ = 0;
      }
    }
    return post;
  }

 private:
  // Track progress through the op sequence
  // space_to_batch_nd -> {nn,qnn}.conv2d -> batch_to_space_nd
  // as calls are visited in post-dfs order; any mismatch resets the match.
  size_t op_iter_ = 0;
  const std::array<ExprSet, 3> ops_ = {
      ExprSet{Op::Get("nn.space_to_batch_nd")},
      ExprSet{Op::Get("nn.conv2d"), Op::Get("qnn.conv2d")},
      ExprSet{Op::Get("nn.batch_to_space_nd")},
  };
};

Expr FlattenAtrousConv(const Expr& expr, const IRModule& mod) {
  return FlattenAtrousConvRewriter().Mutate(expr);
}

namespace transform {

Pass FlattenAtrousConv() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(FlattenAtrousConv(f, m));
      };
  return CreateFunctionPass(pass_func, 0, "FlattenAtrousConv", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.FlattenAtrousConv").set_body_typed(FlattenAtrousConv);

}  // namespace transform

}  // namespace relay
}  // namespace tvm
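The _get_pad_pair lambda above recomputes TensorFlow-style "SAME" padding for the dilated kernel. A quick Python check of the arithmetic, with assumed sizes (8x8 input, 3x3 kernel, block shape 2; the numbers are illustrative):

def get_pad_pair(input1d, kernel1d, stride1d):
    out1d = (input1d + stride1d - 1) // stride1d               # ceil(input / stride)
    pad = max((out1d - 1) * stride1d + kernel1d - input1d, 0)  # total padding needed
    pad_before = pad // 2                                      # extra pixel goes after
    return pad_before, pad - pad_before

dilated_kernel = (3 - 1) * 2 + 1             # (kernel - 1) * dilation + 1 = 5
print(get_pad_pair(8, dilated_kernel, 1))    # (2, 2), so padding = [2, 2, 2, 2]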
