diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index e518350dc..b347fe95d 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -216,20 +216,13 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: # warp_specialized pass will pack the if stmt into the block # so we need to lower the opaque block first mod = tilelang.transform.LowerOpaqueBlock()(mod) - mod = tilelang.transform.MergeIfStmt()(mod) if is_hopper(target): mod = tilelang.transform.RewriteWgmmaSync()(mod) - mod = tilelang.transform.InjectFenceProxy()(mod) else: mod = tilelang.transform.IfStmtBinding()(mod) mod = tilelang.transform.PlanAndUpdateBufferAllocationLocation()(mod) mod = tilelang.transform.PipelinePlanning()(mod) mod = tilelang.transform.InjectSoftwarePipeline()(mod) - mod = tilelang.transform.MergeIfStmt()(mod) - if allow_fence_proxy(target=target): - # in hopper device, wgmma is an async proxy - # so we need to inject a fence proxy before it - mod = tilelang.transform.InjectFenceProxy()(mod) mod = tilelang.transform.LowerOpaqueBlock()(mod) mod = tilelang.transform.Simplify()(mod) @@ -278,8 +271,16 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: # because the merged allocation site is at the beginning of each device function enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target) mod = tilelang.transform.MergeSharedMemoryAllocations(enable_aggressive_merge=enable_aggressive_merge)(mod) + if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target): + mod = tilelang.transform.InjectFenceProxy()(mod) + else: + if allow_fence_proxy(target=target): + # in hopper device, wgmma is an async proxy + # so we need to inject a fence proxy before it + mod = tilelang.transform.InjectFenceProxy()(mod) mod = tilelang.transform.ThreadSync("shared")(mod) mod = tilelang.transform.ThreadSync("shared.dyn")(mod) + mod = tilelang.transform.MergeIfStmt()(mod) # Inject PTX async copy must behind the thread sync pass # as ptx async copy won't be recognized as a valid buffer load mod = tilelang.transform.InjectPTXAsyncCopy()(mod)