diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 89a803d058e4..607be0a83dce 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -322,6 +322,7 @@ class IterMapRewriter : public ExprMutator { ErrorLogger(this) << "IterMapExpr or subclasses should only result from calls in " << "IterMapRewriter using DirectMutate. " << "Indirect return occurred in " << input_expr; + return input_expr; } return expr; } diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 63bb79d2b223..1676855b31a2 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -1292,5 +1292,25 @@ def test_normalize_to_iter_sum(): ) +def test_detect_iter_map_with_bufferload_recursion(): + n = tvm.tir.Var("n", "int32") + m = tvm.tir.Var("m", "int32") + divisor = tvm.tir.Var("divisor", "int32") + + i = tvm.tir.Var("i", "int32") + j = tvm.tir.Var("j", "int32") + + buffer = tvm.tir.decl_buffer((n,), "int32", name="seqlen") + + indices = [(buffer[i] + j) // divisor] + iter_vars = { + i: tvm.ir.Range(tvm.tir.const(0, "int32"), n), + j: tvm.ir.Range(tvm.tir.const(0, "int32"), m), + } + + result = tvm.arith.detect_iter_map(indices, iter_vars) + assert len(result.indices) == 0 + + if __name__ == "__main__": tvm.testing.main()