Skip to content

Commit 965490e

Browse files
authored
[Relay] Optimize transform shape (#13519)
* [Relay] Optimize transform shape * add test
1 parent 3482eab commit 965490e

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

src/tir/ir/data_layout.cc

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -334,13 +334,9 @@ inline Array<PrimExpr> TransformShape(const Array<PrimExpr>& src_shape,
334334
// for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule,
335335
// e.g., (C * 16 + c) / 32
336336
std::unordered_map<const tir::VarNode*, PrimExpr> bind_map;
337-
std::unordered_set<size_t> symbolic_var_set;
338337
for (size_t i = 0; i < src_shape.size(); ++i) {
339338
PrimExpr orig_shape = src_shape[i];
340339
IterVar orig_axis = src_axis[i];
341-
if (orig_shape.as<tir::AnyNode>()) {
342-
symbolic_var_set.insert(i);
343-
}
344340
if (!LayoutAxis::Get(orig_axis).IsPrimal()) {
345341
if (orig_shape.defined()) {
346342
const auto* orig_shape_const = orig_shape.as<IntImmNode>();
@@ -369,11 +365,7 @@ inline Array<PrimExpr> TransformShape(const Array<PrimExpr>& src_shape,
369365
if (!LayoutAxis::Get(axis).IsPrimal()) {
370366
result.push_back(axis->dom->extent);
371367
} else {
372-
if (symbolic_var_set.count(i)) {
373-
result.push_back(tir::Any());
374-
} else {
375-
result.push_back(ana.Simplify(tir::Substitute(rule, bind_map)));
376-
}
368+
result.push_back(ana.Simplify(tir::Substitute(rule, bind_map)));
377369
}
378370
}
379371

tests/python/relay/test_any.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,15 @@ def test_any_layout_transform():
446446
verify_any_layout_transform((16, 1), "CH", "C4cH", (16, 1), (4, 4, 1))
447447

448448

449+
def test_bilayout_with_any():
450+
bilayout = tvm.tir.bijective_layout("NCHW", "NHWC")
451+
assert isinstance(bilayout, tvm.tir.BijectiveLayout)
452+
dst_shape = bilayout.forward_shape((relay.Any(), 32, 7, relay.Any()))
453+
assert dst_shape[3] == 32
454+
src_shape = bilayout.backward_shape(dst_shape)
455+
assert src_shape[1] == 32
456+
457+
449458
def verify_any_expand_dims(data_shape, axis, num_newaxis, static_data_shape, ref_out_shape):
450459
mod = tvm.IRModule()
451460
dtype = "float32"

0 commit comments

Comments
 (0)