diff --git a/3rdparty/tvm b/3rdparty/tvm index 001022bdb..e47e76a2a 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 001022bdb2dbb337d242eed9d208f8555b8edc98 +Subproject commit e47e76a2a0d565e02b6474c06f9f47e1374821f3 diff --git a/src/transform/lower_pdl.cc b/src/transform/lower_pdl.cc index 4460ad258..13150bf67 100644 --- a/src/transform/lower_pdl.cc +++ b/src/transform/lower_pdl.cc @@ -57,17 +57,10 @@ class MarkCudaSyncCalls : public StmtExprMutator { } PrimExpr VisitExpr_(const tir::CallNode *op) final { - if (op && op->op.same_as(builtin::call_extern())) { - if (!op->args.empty()) { - if (const auto *str_node = op->args[0].as()) { - std::string func_name = str_node->value; - if (func_name == "cudaTriggerProgrammaticLaunchCompletion") { - has_trigger_launch_ = true; - } else if (func_name == "cudaGridDependencySynchronize") { - has_grid_sync_ = true; - } - } - } + if (op->op.same_as(tl::pdl_trigger())) { + has_trigger_launch_ = true; + } else if (op->op.same_as(tl::pdl_sync())) { + has_grid_sync_ = true; } return StmtExprMutator::VisitExpr_(op); } diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index dc4cd30de..7d3686abc 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -164,17 +164,8 @@ class WarpSpecializedRoleMarker : public StmtVisitor { if (call->op.same_as(loop_break())) { role = Role::kBoth; } - if (call->op.same_as(builtin::call_extern())) { - if (!call->args.empty()) { - if (const auto *str_node = - call->args[0].as()) { - std::string func_name = str_node->value; - if (func_name == "cudaGridDependencySynchronize" || - func_name == "cudaTriggerProgrammaticLaunchCompletion") { - role = Role::kBoth; - } - } - } + if (call->op.same_as(pdl_sync()) || call->op.same_as(pdl_trigger())) { + role = Role::kBoth; } } SetRole(op, role); diff --git a/testing/python/language/test_tilelang_language_pdl.py b/testing/python/language/test_tilelang_language_pdl.py index 77fe984ea..3f9a3d782 100644 --- a/testing/python/language/test_tilelang_language_pdl.py +++ b/testing/python/language/test_tilelang_language_pdl.py @@ -48,10 +48,10 @@ def test_pdl_trigger(): def test_pdl_sync(): N = 64 program = kernels_with_pdl_sync(N) - pdl_kernel = tilelang.compile(program, target="cuda -arch=sm_90") code = pdl_kernel.get_kernel_source() assert "cudaGridDependencySynchronize" in code + assert "__restrict__" not in code if __name__ == "__main__":