diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index f784f7b49aac..3b22ffc71173 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -334,13 +334,9 @@ inline Array TransformShape(const Array& src_shape, // for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule, // e.g., (C * 16 + c) / 32 std::unordered_map bind_map; - std::unordered_set symbolic_var_set; for (size_t i = 0; i < src_shape.size(); ++i) { PrimExpr orig_shape = src_shape[i]; IterVar orig_axis = src_axis[i]; - if (orig_shape.as()) { - symbolic_var_set.insert(i); - } if (!LayoutAxis::Get(orig_axis).IsPrimal()) { if (orig_shape.defined()) { const auto* orig_shape_const = orig_shape.as(); @@ -369,11 +365,7 @@ inline Array TransformShape(const Array& src_shape, if (!LayoutAxis::Get(axis).IsPrimal()) { result.push_back(axis->dom->extent); } else { - if (symbolic_var_set.count(i)) { - result.push_back(tir::Any()); - } else { - result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); - } + result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); } } diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index f602a17e2412..37aa2271a520 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -446,6 +446,15 @@ def test_any_layout_transform(): verify_any_layout_transform((16, 1), "CH", "C4cH", (16, 1), (4, 4, 1)) +def test_bilayout_with_any(): + bilayout = tvm.tir.bijective_layout("NCHW", "NHWC") + assert isinstance(bilayout, tvm.tir.BijectiveLayout) + dst_shape = bilayout.forward_shape((relay.Any(), 32, 7, relay.Any())) + assert dst_shape[3] == 32 + src_shape = bilayout.backward_shape(dst_shape) + assert src_shape[1] == 32 + + def verify_any_expand_dims(data_shape, axis, num_newaxis, static_data_shape, ref_out_shape): mod = tvm.IRModule() dtype = "float32"