diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc
index 719703a3ec84..b67a638dd6af 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -348,7 +348,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
   }
 
   bool IsStatic(const Expr& expr, std::vector* vars_collector = nullptr) {
-    if (expr->IsInstance() || expr->IsInstance()) {
+    if (expr->IsInstance() || expr->IsInstance() ||
+        expr->IsInstance()) {
       return true;
     }
     if (const auto* prim_value = expr.as()) {
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 73aaf4dac539..dc115939a7e4 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -18,9 +18,11 @@
 import pytest
 
 import tvm
-from tvm import relax
-from tvm.script import tir as T, relax as R, ir as I
 import tvm.testing
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
 
 
 class BaseCompare(tvm.testing.CompareBeforeAfter):
@@ -704,5 +706,56 @@ def main():
     tvm.ir.assert_structural_equal(Before, AfterWhenDisabled)
 
 
+def test_static_args():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main():
+            storage0 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
+            alloc0 = R.memory.alloc_tensor(storage0, 0, R.shape([8]), "float32")
+            _ = R.call_packed("dummy_func", alloc0, R.dtype("float32"), R.str("string"))
+            return R.tuple()
+
+    @I.ir_module
+    class Expected:
+        @R.function(private=True)
+        def cuda_graph_alloc() -> R.Tuple(R.Object):
+            R.func_attr({"relax.force_pure": True})
+            storage0: R.Object = R.memory.alloc_storage(
+                R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            gv: R.Tuple(R.Object) = (storage0,)
+            return gv
+
+        @R.function(private=True)
+        def cuda_graph_capture(alloc0: R.Tensor((8,), dtype="float32")) -> R.Tuple:
+            R.func_attr({"relax.force_pure": True})
+            _: R.Object = R.call_packed("dummy_func", alloc0, R.dtype("float32"), R.str("string"))
+            gv: R.Tuple = R.tuple()
+            return gv
+
+        @R.function
+        def main() -> R.Tuple:
+            cls = Expected
+            gv: R.Tuple(R.Object) = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.get_cached_alloc",
+                (cls.cuda_graph_alloc, R.prim_value(0)),
+                sinfo_args=(R.Tuple(R.Object),),
+            )
+            storage0: R.Object = gv[0]
+            alloc0: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
+                storage0, R.prim_value(0), R.shape([8]), R.dtype("float32")
+            )
+            gv1: R.Tuple = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.run_or_capture",
+                (cls.cuda_graph_capture, (alloc0,), R.prim_value(0)),
+                sinfo_args=(R.Tuple,),
+            )
+            return R.tuple()
+
+    mod = relax.transform.RewriteCUDAGraph()(Before)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()