diff --git a/test/TritonGPU/amd/amd-convert-warp-pipeline-invalid.mlir b/test/TritonGPU/amd/amd-convert-warp-pipeline-invalid.mlir new file mode 100644 index 000000000000..532cfcd00f23 --- /dev/null +++ b/test/TritonGPU/amd/amd-convert-warp-pipeline-invalid.mlir @@ -0,0 +1,134 @@ +// RUN: triton-opt %s -split-input-file -convert-warp-pipeline="gfx-arch=gfx950" -verify-diagnostics + +// validatePipelinedForBody runs upfront, before any IR mutation, so a +// malformed `pipelined_for` body fails the pass with no partial conversion. + +// ==== Non-warp-pipeline scf.execute_region inside a pipelined_for body ==== + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + tt.func @bad_unmarked_execute_region(%n: index, %ptr: !tt.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %v0 = arith.constant 0.0 : f32 + %v1 = arith.constant 1.0 : f32 + + scf.for %i = %c0 to %n step %c1 { + scf.execute_region { + tt.store %ptr, %v0 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "stage0"} + + // expected-error @+1 {{non-warp-pipeline scf.execute_region inside pipelined_for body}} + scf.execute_region { + tt.store %ptr, %v1 : !tt.ptr + scf.yield + } + + scf.yield + } {triton.warp_pipeline.pipelined_for} + + tt.return + } +} + +// ----- + +// ==== Multiple pre-existing barriers between two stages ==== + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + tt.func @bad_double_barrier_between_stages(%n: index, %ptr: !tt.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %v0 = arith.constant 0.0 : f32 + %v1 = arith.constant 1.0 : f32 + + scf.for %i = %c0 to %n step %c1 { + scf.execute_region { + tt.store %ptr, %v0 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "stage0"} + + amdg.async_wait {num_inst = 0 : i32} + // expected-error @+1 {{multiple pre-existing 
barriers between pipeline stages}} + amdg.async_wait {num_inst = 0 : i32} + + scf.execute_region { + tt.store %ptr, %v1 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "stage1"} + + scf.yield + } {triton.warp_pipeline.pipelined_for} + + tt.return + } +} + +// ----- + +// ==== Both top-of-loop and bottom-of-loop pre-existing barriers ==== + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + tt.func @bad_top_and_bottom_barriers(%n: index, %ptr: !tt.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %v0 = arith.constant 0.0 : f32 + %v1 = arith.constant 1.0 : f32 + + // expected-error @+1 {{both top-of-loop and bottom-of-loop pre-existing barriers}} + scf.for %i = %c0 to %n step %c1 { + amdg.async_wait {num_inst = 0 : i32} + + scf.execute_region { + tt.store %ptr, %v0 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "stage0"} + + scf.execute_region { + tt.store %ptr, %v1 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "stage1"} + + amdg.async_wait {num_inst = 0 : i32} + + scf.yield + } {triton.warp_pipeline.pipelined_for} + + tt.return + } +} + +// ----- + +// ==== Unexpected op inside a pipelined_for body ==== +// +// Anything that is not a warp-pipeline stage, an ignorable barrier/wait, +// or scf.yield must be rejected upfront. 
+ +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + tt.func @bad_unexpected_op_in_body(%n: index, %ptr: !tt.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %v0 = arith.constant 0.0 : f32 + %v1 = arith.constant 1.0 : f32 + + scf.for %i = %c0 to %n step %c1 { + scf.execute_region { + tt.store %ptr, %v0 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "stage0"} + + // expected-error @+1 {{unexpected op inside pipelined_for body}} + %x = arith.addi %i, %c1 : index + + scf.execute_region { + tt.store %ptr, %v1 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "stage1"} + + scf.yield + } {triton.warp_pipeline.pipelined_for} + + tt.return + } +} diff --git a/test/TritonGPU/amd/amd-convert-warp-pipeline.mlir b/test/TritonGPU/amd/amd-convert-warp-pipeline.mlir index 222dad43961f..1532c2617aed 100644 --- a/test/TritonGPU/amd/amd-convert-warp-pipeline.mlir +++ b/test/TritonGPU/amd/amd-convert-warp-pipeline.mlir @@ -445,3 +445,936 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // CHECK: rocdl.sched.barrier // CHECK: amdg.cond_barrier // CHECK: tt.return + +// ----- + +// ---- Back-to-back: cross-pipeline LDS dep covered by A's wrap-around ---- +// +// Both loops access the same shared buffer (read + write). Loop 1's +// stage1 writes smem and loop 2's stage0 reads it — a cross-pipeline RAW. +// +// Loop 1's wrap-around barrier (bars[0]) is LOCAL because of the in-loop +// RAW between stage1 (write) and the next iteration's stage0 (read). +// That barrier physically sits at the bottom of loop 1's body and is the +// most recent LDS sync after the loop exits, so it already covers the +// (a1, b0) cross-pipeline dep at the boundary. The boundary barriers +// can therefore be eliminated. 
+// +// Expected: +// ttg.barrier local (pre-barrier for loop 1) +// amdg.cond_barrier (#1 phase shift for loop 1) +// scf.for { loop 1 } +// NO amdg.cond_barrier (#2 eliminated — wrap-around covers) +// NO ttg.barrier local (prelude eliminated) +// NO amdg.cond_barrier (#3 eliminated) +// scf.for { loop 2 } +// amdg.cond_barrier (#4 post-loop reconverge for loop 2) + +#b2b_blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#b2b_mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [2, 4], instrShape = [16, 16, 16], isTransposed = true}> +#b2b_shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> +#b2b_smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func @back_to_back_wrap_around_covers_dep( + %lb: i32, %ub: i32, %step: i32, + %acc: tensor<256x256xf32, #b2b_mma>, + %ptr: tensor<256x64x!tt.ptr, #b2b_blocked>) { + + %smem = ttg.local_alloc : () -> !ttg.memdesc<256x64xf16, #b2b_shared, #b2b_smem, mutable> + + // Loop 1: stage0 reads LDS, stage1 writes LDS + %r1:2 = scf.for %i = %lb to %ub step %step + iter_args(%a1 = %acc, %s1 = %smem) + -> (tensor<256x256xf32, #b2b_mma>, !ttg.memdesc<256x64xf16, #b2b_shared, #b2b_smem, mutable>) : i32 { + %ld1 = scf.execute_region -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2b_mma, kWidth = 4}>> no_inline { + %sub = ttg.memdesc_subslice %s1[0, 0] : !ttg.memdesc<256x64xf16, #b2b_shared, #b2b_smem, mutable> -> !ttg.memdesc<256x16xf16, #b2b_shared, #b2b_smem, mutable, 256x64> + %v = ttg.local_load %sub : !ttg.memdesc<256x16xf16, #b2b_shared, #b2b_smem, mutable, 256x64> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2b_mma, kWidth = 4}>> + scf.yield %v : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2b_mma, kWidth = 4}>> + } {triton.warp_pipeline.stage = "lds_load"} + + %st1 = scf.execute_region -> 
!ttg.memdesc<256x64xf16, #b2b_shared, #b2b_smem, mutable> no_inline { + %data = tt.load %ptr : tensor<256x64x!tt.ptr, #b2b_blocked> + ttg.local_store %data, %s1 : tensor<256x64xf16, #b2b_blocked> -> !ttg.memdesc<256x64xf16, #b2b_shared, #b2b_smem, mutable> + scf.yield %s1 : !ttg.memdesc<256x64xf16, #b2b_shared, #b2b_smem, mutable> + } {triton.warp_pipeline.stage = "global_load_and_store"} + + scf.yield %a1, %st1 : tensor<256x256xf32, #b2b_mma>, !ttg.memdesc<256x64xf16, #b2b_shared, #b2b_smem, mutable> + } {triton.warp_pipeline.pipelined_for} + + // Loop 2: same structure — reads+writes the SAME buffer → cross-pipeline RAW + %r2:2 = scf.for %j = %lb to %ub step %step + iter_args(%a2 = %r1#0, %s2 = %r1#1) + -> (tensor<256x256xf32, #b2b_mma>, !ttg.memdesc<256x64xf16, #b2b_shared, #b2b_smem, mutable>) : i32 { + %ld2 = scf.execute_region -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2b_mma, kWidth = 4}>> no_inline { + %sub2 = ttg.memdesc_subslice %s2[0, 0] : !ttg.memdesc<256x64xf16, #b2b_shared, #b2b_smem, mutable> -> !ttg.memdesc<256x16xf16, #b2b_shared, #b2b_smem, mutable, 256x64> + %v2 = ttg.local_load %sub2 : !ttg.memdesc<256x16xf16, #b2b_shared, #b2b_smem, mutable, 256x64> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2b_mma, kWidth = 4}>> + scf.yield %v2 : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2b_mma, kWidth = 4}>> + } {triton.warp_pipeline.stage = "epilogue_lds_load"} + + %st2 = scf.execute_region -> !ttg.memdesc<256x64xf16, #b2b_shared, #b2b_smem, mutable> no_inline { + %data2 = tt.load %ptr : tensor<256x64x!tt.ptr, #b2b_blocked> + ttg.local_store %data2, %s2 : tensor<256x64xf16, #b2b_blocked> -> !ttg.memdesc<256x64xf16, #b2b_shared, #b2b_smem, mutable> + scf.yield %s2 : !ttg.memdesc<256x64xf16, #b2b_shared, #b2b_smem, mutable> + } {triton.warp_pipeline.stage = "epilogue_global_load_and_store"} + + scf.yield %a2, %st2 : tensor<256x256xf32, #b2b_mma>, !ttg.memdesc<256x64xf16, #b2b_shared, #b2b_smem, mutable> + } 
{triton.warp_pipeline.pipelined_for} + + ttg.local_dealloc %smem : !ttg.memdesc<256x64xf16, #b2b_shared, #b2b_smem, mutable> + tt.return + } +} + +// CHECK-LABEL: tt.func @back_to_back_wrap_around_covers_dep +// Pre-barrier and phase shift for loop 1. +// CHECK: ttg.barrier local +// CHECK: amdg.cond_barrier +// CHECK: scf.for +// Wrap-around barrier inside loop 1 (LOCAL — covers cross-pipeline dep). +// CHECK: ttg.barrier local +// CHECK: scf.yield +// Boundary barriers are eliminated: A's wrap-around already provides the +// LDS sync needed for loop 2's first read; phase carries over. +// CHECK-NOT: amdg.cond_barrier +// CHECK-NOT: ttg.barrier local +// CHECK: scf.for +// Post-loop reconverge for loop 2. +// CHECK: amdg.cond_barrier +// CHECK: tt.return + +// ----- + +// ---- Flat (unrolled) pipeline: execute_regions outside scf.for ---- +// +// Simulates the output of WarpPipeliner::createFlatPipeline — +// 4 execute_regions from a 2-iteration × 2-stage unrolled epilogue. +// ConvertWarpPipeline should insert pre-barrier, phase shift, +// cluster barriers, priority, and reconverge around them. 
+ +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + tt.func @flat_pipeline_backend(%ptr0: !tt.ptr, %ptr1: !tt.ptr) { + %v0 = arith.constant 0.0 : f32 + %v1 = arith.constant 1.0 : f32 + %v2 = arith.constant 2.0 : f32 + %v3 = arith.constant 3.0 : f32 + + // Iteration 0, stage 0 + scf.execute_region no_inline { + tt.store %ptr0, %v0 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "stage0_epi", triton.warp_pipeline.priority = 1 : i32} + + // Iteration 0, stage 1 + scf.execute_region no_inline { + tt.store %ptr1, %v1 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "stage1_epi", triton.warp_pipeline.priority = 0 : i32} + + // Iteration 1, stage 0 + scf.execute_region no_inline { + tt.store %ptr0, %v2 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "stage0_epi", triton.warp_pipeline.priority = 1 : i32} + + // Iteration 1, stage 1 + scf.execute_region no_inline { + tt.store %ptr1, %v3 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "stage1_epi", triton.warp_pipeline.priority = 0 : i32} + + tt.return + } +} + +// CHECK-LABEL: tt.func @flat_pipeline_backend +// All execute_regions must be inlined. +// CHECK-NOT: no_inline +// +// Pre-barrier + phase shift. +// CHECK: ttg.barrier local +// CHECK: %[[WARPLOW:.+]] = arith.cmpi eq +// CHECK: %[[WARPHIGH:.+]] = arith.cmpi ne +// CHECK: amdg.cond_barrier %[[WARPHIGH]] +// +// Stage 0 priority. +// CHECK: rocdl.s.setprio 1 +// Stage 0 ops (inlined). +// CHECK: tt.store +// +// Cluster barrier between stages 0 and 1. +// CHECK: rocdl.s.setprio 0 +// CHECK: rocdl.sched.barrier +// CHECK: rocdl.s.barrier +// CHECK: rocdl.sched.barrier +// Stage 1 ops. +// CHECK: tt.store +// +// Cluster barrier between iteration 0 stage 1 and iteration 1 stage 0. 
+// CHECK: rocdl.s.setprio 1 +// CHECK: rocdl.sched.barrier +// CHECK: rocdl.s.barrier +// CHECK: rocdl.sched.barrier +// CHECK: tt.store +// +// Cluster barrier between iteration 1 stages. +// CHECK: rocdl.s.setprio 0 +// CHECK: rocdl.sched.barrier +// CHECK: rocdl.s.barrier +// CHECK: rocdl.sched.barrier +// CHECK: tt.store +// +// Post-sequence priority reset + reconverge. +// CHECK: rocdl.s.setprio 0 +// CHECK: amdg.cond_barrier %[[WARPLOW]] +// CHECK: tt.return + +// ----- + +// ---- Back-to-back: pipelined scf.for + flat (unrolled) pipeline ---- +// +// Loop 1 (scf.for) followed immediately by a flat pipeline with no +// intervening operations. The post-loop reconverge, prelude barrier, +// and phase shift are all eliminated — same logic as back-to-back +// scf.for loops. +// +// Expected: +// ttg.barrier local (pre-barrier for loop 1) +// amdg.cond_barrier (#1 phase shift for loop 1) +// scf.for { loop 1 } +// NO amdg.cond_barrier (#2 eliminated) +// NO ttg.barrier local (pre-barrier eliminated) +// NO amdg.cond_barrier (#3 eliminated) +// [flat pipeline stages] +// amdg.cond_barrier (#4 reconverge for flat pipeline) + +#b2bf_blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#b2bf_mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [2, 4], instrShape = [16, 16, 16], isTransposed = true}> +#b2bf_shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> +#b2bf_smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func @back_to_back_for_then_flat( + %lb: i32, %ub: i32, %step: i32, + %acc: tensor<256x256xf32, #b2bf_mma>, + %ptr: tensor<256x64x!tt.ptr, #b2bf_blocked>, + %sptr: !tt.ptr) { + + %smem = ttg.local_alloc : () -> !ttg.memdesc<256x64xf16, #b2bf_shared, #b2bf_smem, mutable> + %v0 = arith.constant 0.0 : f32 + %v1 = arith.constant 1.0 : f32 + + // Loop 1: 
local_load + local_store → wrap-around is ttg.barrier local + %r1:2 = scf.for %i = %lb to %ub step %step + iter_args(%a1 = %acc, %s1 = %smem) + -> (tensor<256x256xf32, #b2bf_mma>, !ttg.memdesc<256x64xf16, #b2bf_shared, #b2bf_smem, mutable>) : i32 { + %ld1 = scf.execute_region -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2bf_mma, kWidth = 4}>> no_inline { + %sub = ttg.memdesc_subslice %s1[0, 0] : !ttg.memdesc<256x64xf16, #b2bf_shared, #b2bf_smem, mutable> -> !ttg.memdesc<256x16xf16, #b2bf_shared, #b2bf_smem, mutable, 256x64> + %v = ttg.local_load %sub : !ttg.memdesc<256x16xf16, #b2bf_shared, #b2bf_smem, mutable, 256x64> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2bf_mma, kWidth = 4}>> + scf.yield %v : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2bf_mma, kWidth = 4}>> + } {triton.warp_pipeline.stage = "lds_load"} + + %st1 = scf.execute_region -> !ttg.memdesc<256x64xf16, #b2bf_shared, #b2bf_smem, mutable> no_inline { + %data = tt.load %ptr : tensor<256x64x!tt.ptr, #b2bf_blocked> + ttg.local_store %data, %s1 : tensor<256x64xf16, #b2bf_blocked> -> !ttg.memdesc<256x64xf16, #b2bf_shared, #b2bf_smem, mutable> + scf.yield %s1 : !ttg.memdesc<256x64xf16, #b2bf_shared, #b2bf_smem, mutable> + } {triton.warp_pipeline.stage = "global_load_and_store"} + + scf.yield %a1, %st1 : tensor<256x256xf32, #b2bf_mma>, !ttg.memdesc<256x64xf16, #b2bf_shared, #b2bf_smem, mutable> + } {triton.warp_pipeline.pipelined_for} + + // Flat (unrolled) pipeline: 2 stages, simple stores (no LDS dep) + scf.execute_region no_inline { + tt.store %sptr, %v0 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "flat_stage0"} + + scf.execute_region no_inline { + tt.store %sptr, %v1 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "flat_stage1"} + + ttg.local_dealloc %smem : !ttg.memdesc<256x64xf16, #b2bf_shared, #b2bf_smem, mutable> + tt.return + } +} + +// CHECK-LABEL: tt.func @back_to_back_for_then_flat +// Pre-barrier and phase shift for loop 1 are kept. 
+// CHECK: ttg.barrier local +// CHECK: amdg.cond_barrier +// CHECK: scf.for +// Wrap-around barrier inside loop 1. +// CHECK: ttg.barrier local +// CHECK: scf.yield +// Between loop 1 and flat pipeline: no cond_barriers, no ttg.barrier local +// (no intervening ops → phase carries over, prelude barrier redundant). +// CHECK-NOT: amdg.cond_barrier +// CHECK-NOT: ttg.barrier local +// Flat pipeline stages (inlined after conversion). +// CHECK: tt.store +// CHECK: rocdl.s.barrier +// CHECK: tt.store +// Reconverge for flat pipeline is kept. +// CHECK: amdg.cond_barrier +// CHECK: tt.return + +// ----- + +// ---- Flat pipeline with pre-existing barrier between stages ---- +// +// When an async_wait (or similar barrier op) already exists between +// flat pipeline stages, the pass should wrap it with sched_barriers +// instead of inserting a redundant s_barrier. +// +// Stage layout: stage0 -- async_wait -- stage1 -- (nothing) -- stage2 +// +// Expected between stage0 and stage1: +// sched_barrier + async_wait + sched_barrier (wrapped, no s_barrier) +// Expected between stage1 and stage2: +// sched_barrier + s_barrier + sched_barrier (inserted, no async_wait) + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + tt.func @flat_pipeline_existing_barrier(%ptr: !tt.ptr) { + %v0 = arith.constant 0.0 : f32 + %v1 = arith.constant 1.0 : f32 + %v2 = arith.constant 2.0 : f32 + + scf.execute_region no_inline { + tt.store %ptr, %v0 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "stage0"} + + amdg.async_wait {num_inst = 0 : i32} + + scf.execute_region no_inline { + tt.store %ptr, %v1 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "stage1"} + + scf.execute_region no_inline { + tt.store %ptr, %v2 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "stage2"} + + tt.return + } +} + +// CHECK-LABEL: tt.func @flat_pipeline_existing_barrier +// CHECK-NOT: no_inline +// +// 
Pre-barrier + phase shift. +// CHECK: ttg.barrier local +// CHECK: amdg.cond_barrier +// +// Stage 0 ops. +// CHECK: tt.store +// +// Between stage 0 and 1: existing async_wait wrapped, no s_barrier. +// CHECK: rocdl.sched.barrier +// CHECK-NEXT: amdg.async_wait +// CHECK-NEXT: rocdl.sched.barrier +// CHECK-NOT: rocdl.s.barrier +// Stage 1 ops. +// CHECK: tt.store +// +// Between stage 1 and 2: no pre-existing barrier, so s_barrier inserted. +// CHECK: rocdl.sched.barrier +// CHECK-NEXT: rocdl.s.barrier +// CHECK-NEXT: rocdl.sched.barrier +// Stage 2 ops. +// CHECK: tt.store +// +// Reconverge. +// CHECK: amdg.cond_barrier +// CHECK: tt.return + +// ----- + +// ---- Back-to-back: no cross-pipeline LDS dep → barriers eliminated ---- +// +// Loop 1 reads+writes shared memory. Loop 2 only does global ops (no LDS). +// No cross-pipeline LDS dependency exists, so the boundary barriers are +// safely eliminated and the phase carries over. +// +// Expected: +// ttg.barrier local (pre-barrier for loop 1) +// amdg.cond_barrier (#1 phase shift for loop 1) +// scf.for { loop 1 } +// NO amdg.cond_barrier (eliminated) +// NO ttg.barrier local (eliminated) +// NO amdg.cond_barrier (eliminated) +// scf.for { loop 2 } +// amdg.cond_barrier (#4 reconverge for loop 2) + +#b2bnd_blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#b2bnd_mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [2, 4], instrShape = [16, 16, 16], isTransposed = true}> +#b2bnd_shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> +#b2bnd_smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func @back_to_back_no_dep_elimination( + %lb: i32, %ub: i32, %step: i32, + %acc: tensor<256x256xf32, #b2bnd_mma>, + %ptr: tensor<256x64x!tt.ptr, #b2bnd_blocked>, + %gptr: !tt.ptr) { + + %smem = ttg.local_alloc : () -> 
!ttg.memdesc<256x64xf16, #b2bnd_shared, #b2bnd_smem, mutable> + %v0 = arith.constant 0.0 : f32 + %v1 = arith.constant 1.0 : f32 + + // Loop 1: stage0 reads LDS, stage1 writes LDS + %r1:2 = scf.for %i = %lb to %ub step %step + iter_args(%a1 = %acc, %s1 = %smem) + -> (tensor<256x256xf32, #b2bnd_mma>, !ttg.memdesc<256x64xf16, #b2bnd_shared, #b2bnd_smem, mutable>) : i32 { + %ld1 = scf.execute_region -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2bnd_mma, kWidth = 4}>> no_inline { + %sub = ttg.memdesc_subslice %s1[0, 0] : !ttg.memdesc<256x64xf16, #b2bnd_shared, #b2bnd_smem, mutable> -> !ttg.memdesc<256x16xf16, #b2bnd_shared, #b2bnd_smem, mutable, 256x64> + %v = ttg.local_load %sub : !ttg.memdesc<256x16xf16, #b2bnd_shared, #b2bnd_smem, mutable, 256x64> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2bnd_mma, kWidth = 4}>> + scf.yield %v : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2bnd_mma, kWidth = 4}>> + } {triton.warp_pipeline.stage = "lds_load"} + + %st1 = scf.execute_region -> !ttg.memdesc<256x64xf16, #b2bnd_shared, #b2bnd_smem, mutable> no_inline { + %data = tt.load %ptr : tensor<256x64x!tt.ptr, #b2bnd_blocked> + ttg.local_store %data, %s1 : tensor<256x64xf16, #b2bnd_blocked> -> !ttg.memdesc<256x64xf16, #b2bnd_shared, #b2bnd_smem, mutable> + scf.yield %s1 : !ttg.memdesc<256x64xf16, #b2bnd_shared, #b2bnd_smem, mutable> + } {triton.warp_pipeline.stage = "global_load_and_store"} + + scf.yield %a1, %st1 : tensor<256x256xf32, #b2bnd_mma>, !ttg.memdesc<256x64xf16, #b2bnd_shared, #b2bnd_smem, mutable> + } {triton.warp_pipeline.pipelined_for} + + // Loop 2: global-only ops — no LDS access at all + scf.for %j = %lb to %ub step %step : i32 { + scf.execute_region no_inline { + tt.store %gptr, %v0 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "global_store_0"} + + scf.execute_region no_inline { + tt.store %gptr, %v1 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "global_store_1"} + + scf.yield + } 
{triton.warp_pipeline.pipelined_for} + + ttg.local_dealloc %smem : !ttg.memdesc<256x64xf16, #b2bnd_shared, #b2bnd_smem, mutable> + tt.return + } +} + +// CHECK-LABEL: tt.func @back_to_back_no_dep_elimination +// Pre-barrier and phase shift for loop 1. +// CHECK: ttg.barrier local +// CHECK: amdg.cond_barrier +// CHECK: scf.for +// Wrap-around barrier inside loop 1. +// CHECK: ttg.barrier local +// CHECK: scf.yield +// No cross-pipeline LDS dep → barriers eliminated, phase carries over. +// CHECK-NOT: amdg.cond_barrier +// CHECK-NOT: ttg.barrier local +// CHECK: scf.for +// Post-loop reconverge for loop 2. +// CHECK: amdg.cond_barrier +// CHECK: tt.return + +// ----- + +// ---- Back-to-back: cross-pipeline dep covered by loop A's barrier ---- +// +// Loop 1 has 3 stages: stage0 writes LDS, stage1 reads LDS, stage2 is +// compute-only. The circular dependency analysis places a LOCAL barrier +// between stage1 and stage2 (covering the wrap-around WAR: stage1 reads +// the buffer that the next iteration's stage0 overwrites). +// +// Loop 2 has 2 stages: stage0 reads the SAME LDS buffer, stage1 is +// compute-only. There IS a cross-pipeline dependency (loop1.stage0 writes +// smem that loop2.stage0 reads), but it is already covered by loop 1's +// barrier between stage1 and stage2. +// +// At the boundary with no barrier: warp0 runs b0, warp1 runs a2. +// Since a2 has no LDS access and the LOCAL barrier before a2 already +// flushed all prior LDS writes, b0's read is safe. 
+// +// Expected: +// ttg.barrier local (pre-barrier for loop 1) +// amdg.cond_barrier (phase shift for loop 1) +// scf.for { loop 1 — 3 stages } +// NO amdg.cond_barrier (eliminated) +// NO ttg.barrier local (eliminated) +// NO amdg.cond_barrier (eliminated) +// scf.for { loop 2 — 2 stages } +// amdg.cond_barrier (reconverge for loop 2) + +#b2bcov_blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#b2bcov_mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [2, 4], instrShape = [16, 16, 16], isTransposed = true}> +#b2bcov_shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> +#b2bcov_smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func @back_to_back_dep_covered_elimination( + %lb: i32, %ub: i32, %step: i32, + %acc: tensor<256x256xf32, #b2bcov_mma>, + %ptr: tensor<256x64x!tt.ptr, #b2bcov_blocked>, + %gptr: !tt.ptr) { + + %smem = ttg.local_alloc : () -> !ttg.memdesc<256x64xf16, #b2bcov_shared, #b2bcov_smem, mutable> + %v0 = arith.constant 0.0 : f32 + + // Loop 1: 3 stages + // stage0: writes LDS (local_store) + // stage1: reads LDS (local_load) → RAW with stage0 + // stage2: compute-only (global store, no LDS) + // Circular analysis: barrier between stage1 and stage2 is LOCAL. 
+ %r1:2 = scf.for %i = %lb to %ub step %step + iter_args(%a1 = %acc, %s1 = %smem) + -> (tensor<256x256xf32, #b2bcov_mma>, !ttg.memdesc<256x64xf16, #b2bcov_shared, #b2bcov_smem, mutable>) : i32 { + %st1 = scf.execute_region -> !ttg.memdesc<256x64xf16, #b2bcov_shared, #b2bcov_smem, mutable> no_inline { + %data = tt.load %ptr : tensor<256x64x!tt.ptr, #b2bcov_blocked> + ttg.local_store %data, %s1 : tensor<256x64xf16, #b2bcov_blocked> -> !ttg.memdesc<256x64xf16, #b2bcov_shared, #b2bcov_smem, mutable> + scf.yield %s1 : !ttg.memdesc<256x64xf16, #b2bcov_shared, #b2bcov_smem, mutable> + } {triton.warp_pipeline.stage = "global_load_and_store"} + + %ld1 = scf.execute_region -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2bcov_mma, kWidth = 4}>> no_inline { + %sub = ttg.memdesc_subslice %s1[0, 0] : !ttg.memdesc<256x64xf16, #b2bcov_shared, #b2bcov_smem, mutable> -> !ttg.memdesc<256x16xf16, #b2bcov_shared, #b2bcov_smem, mutable, 256x64> + %v = ttg.local_load %sub : !ttg.memdesc<256x16xf16, #b2bcov_shared, #b2bcov_smem, mutable, 256x64> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2bcov_mma, kWidth = 4}>> + scf.yield %v : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2bcov_mma, kWidth = 4}>> + } {triton.warp_pipeline.stage = "lds_load"} + + scf.execute_region no_inline { + tt.store %gptr, %v0 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "compute"} + + scf.yield %a1, %s1 : tensor<256x256xf32, #b2bcov_mma>, !ttg.memdesc<256x64xf16, #b2bcov_shared, #b2bcov_smem, mutable> + } {triton.warp_pipeline.pipelined_for} + + // Loop 2: stage0 reads the SAME LDS buffer, stage1 is compute-only. + // Cross-pipeline dep (a0 writes → b0 reads) is covered by loop 1's + // barrier between stage1 and stage2. 
+ %r2:2 = scf.for %j = %lb to %ub step %step + iter_args(%a2 = %r1#0, %s2 = %r1#1) + -> (tensor<256x256xf32, #b2bcov_mma>, !ttg.memdesc<256x64xf16, #b2bcov_shared, #b2bcov_smem, mutable>) : i32 { + %ld2 = scf.execute_region -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2bcov_mma, kWidth = 4}>> no_inline { + %sub2 = ttg.memdesc_subslice %s2[0, 0] : !ttg.memdesc<256x64xf16, #b2bcov_shared, #b2bcov_smem, mutable> -> !ttg.memdesc<256x16xf16, #b2bcov_shared, #b2bcov_smem, mutable, 256x64> + %v2 = ttg.local_load %sub2 : !ttg.memdesc<256x16xf16, #b2bcov_shared, #b2bcov_smem, mutable, 256x64> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2bcov_mma, kWidth = 4}>> + scf.yield %v2 : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #b2bcov_mma, kWidth = 4}>> + } {triton.warp_pipeline.stage = "epilogue_lds_load"} + + scf.execute_region no_inline { + tt.store %gptr, %v0 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "epilogue_compute"} + + scf.yield %a2, %s2 : tensor<256x256xf32, #b2bcov_mma>, !ttg.memdesc<256x64xf16, #b2bcov_shared, #b2bcov_smem, mutable> + } {triton.warp_pipeline.pipelined_for} + + ttg.local_dealloc %smem : !ttg.memdesc<256x64xf16, #b2bcov_shared, #b2bcov_smem, mutable> + tt.return + } +} + +// CHECK-LABEL: tt.func @back_to_back_dep_covered_elimination +// Pre-barrier and phase shift for loop 1. +// CHECK: ttg.barrier local +// CHECK: amdg.cond_barrier +// CHECK: scf.for +// Loop 1 has 3 stages; barrier between stage1→stage2 is LOCAL (covers dep). +// CHECK: ttg.barrier local +// CHECK: scf.yield +// Cross-pipeline dep IS covered by loop 1's internal barrier → +// boundary barriers eliminated, phase carries over. +// CHECK-NOT: amdg.cond_barrier +// CHECK-NOT: ttg.barrier local +// CHECK: scf.for +// Post-loop reconverge for loop 2. 
+// CHECK: amdg.cond_barrier +// CHECK: tt.return + +// ----- + +// ---- Adjacent-stage LDS dependency: barrier must be LOCAL ---- +// +// 3-stage loop pipeline where stage0 writes LDS and stage1 reads it. +// Stage2 has no LDS access. +// +// The distance-2+ analysis only checks pairs separated by ≥2 clusters, +// so it never examines (stage0, stage1) directly. Without the adjacent- +// stage check, the barrier between stage0 and stage1 would be emitted as +// a plain s_barrier, and ModuleMembarAnalysis would later insert a +// redundant ttg.barrier local inside the pipeline — breaking timing. +// +// With the adjacent-stage check: +// bars[0] (wrap-around) = false (a2 no LDS, a0 writes — no conflict) +// bars[1] (a0→a1) = true (a0 writes, a1 reads — RAW) +// bars[2] (a1→a2) = true (a1→a0 WAR via distance-2) +// +// Expected inside the loop body: +// stage0 ops (local_store) +// ttg.barrier local (bars[1] — adjacent dep) +// stage1 ops (local_load) +// ttg.barrier local (bars[2] — distance-2 dep) +// stage2 ops (global store) +// rocdl.s.barrier (bars[0] — wrap-around, no LDS dep) +// scf.yield + +#adj_blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#adj_mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [2, 4], instrShape = [16, 16, 16], isTransposed = true}> +#adj_dot = #ttg.dot_op<{opIdx = 0, parent = #adj_mma, kWidth = 4}> +#adj_shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> +#adj_smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func @adjacent_stage_lds_dep( + %lb: i32, %ub: i32, %step: i32, + %acc: tensor<256x16xf16, #adj_dot>, + %ptr: tensor<256x64x!tt.ptr, #adj_blocked>, + %gptr: !tt.ptr) { + + %smem = ttg.local_alloc : () -> !ttg.memdesc<256x64xf16, #adj_shared, #adj_smem, mutable> + %v0 = arith.constant 0.0 : f32 + + // The local_load 
result must be carried as an iter_arg so it is not + // DCE'd — otherwise the barrier between stage0 and stage1 would merge + // with the barrier between stage1 and stage2. + %r:3 = scf.for %i = %lb to %ub step %step + iter_args(%a = %acc, %s = %smem, %prev = %acc) + -> (tensor<256x16xf16, #adj_dot>, !ttg.memdesc<256x64xf16, #adj_shared, #adj_smem, mutable>, tensor<256x16xf16, #adj_dot>) : i32 { + + // Stage 0: writes LDS + %st = scf.execute_region -> !ttg.memdesc<256x64xf16, #adj_shared, #adj_smem, mutable> no_inline { + %data = tt.load %ptr : tensor<256x64x!tt.ptr, #adj_blocked> + ttg.local_store %data, %s : tensor<256x64xf16, #adj_blocked> -> !ttg.memdesc<256x64xf16, #adj_shared, #adj_smem, mutable> + scf.yield %s : !ttg.memdesc<256x64xf16, #adj_shared, #adj_smem, mutable> + } {triton.warp_pipeline.stage = "global_load_and_store"} + + // Stage 1: reads LDS — RAW dep with stage 0 + %ld = scf.execute_region -> tensor<256x16xf16, #adj_dot> no_inline { + %sub = ttg.memdesc_subslice %s[0, 0] : !ttg.memdesc<256x64xf16, #adj_shared, #adj_smem, mutable> -> !ttg.memdesc<256x16xf16, #adj_shared, #adj_smem, mutable, 256x64> + %v = ttg.local_load %sub : !ttg.memdesc<256x16xf16, #adj_shared, #adj_smem, mutable, 256x64> -> tensor<256x16xf16, #adj_dot> + scf.yield %v : tensor<256x16xf16, #adj_dot> + } {triton.warp_pipeline.stage = "lds_load"} + + // Stage 2: compute-only — no LDS access + scf.execute_region no_inline { + tt.store %gptr, %v0 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "compute"} + + scf.yield %a, %s, %ld : tensor<256x16xf16, #adj_dot>, !ttg.memdesc<256x64xf16, #adj_shared, #adj_smem, mutable>, tensor<256x16xf16, #adj_dot> + } {triton.warp_pipeline.pipelined_for} + + ttg.local_dealloc %smem : !ttg.memdesc<256x64xf16, #adj_shared, #adj_smem, mutable> + tt.return + } +} + +// CHECK-LABEL: tt.func @adjacent_stage_lds_dep +// CHECK: scf.for +// +// Stage 0 ops (local_store). 
+// CHECK: ttg.local_store +// +// Barrier between stage0→stage1 is LOCAL (adjacent RAW: write→read). +// CHECK: rocdl.sched.barrier +// CHECK-NEXT: ttg.barrier local +// CHECK-NEXT: rocdl.sched.barrier +// +// Stage 1 ops (local_load). +// CHECK: ttg.local_load +// +// Barrier between stage1→stage2 is LOCAL (distance-2 WAR: a1 reads, a0 writes). +// CHECK: rocdl.sched.barrier +// CHECK-NEXT: ttg.barrier local +// CHECK-NEXT: rocdl.sched.barrier +// +// Stage 2 ops (global store). +// CHECK: tt.store +// +// Wrap-around barrier is s_barrier only (a2 has no LDS, a0 writes — no dep). +// CHECK: rocdl.sched.barrier +// CHECK-NEXT: rocdl.s.barrier +// CHECK-NEXT: rocdl.sched.barrier +// +// CHECK: scf.yield + +// ----- + +// ---- Back-to-back: cross-pipeline dep in a later flat stage (b_1) ---- +// +// This test exercises `collectNextPipelineClusters` when the next pipeline is +// a flat (unrolled) sequence of more than one stage. Before the fix, only +// the first stage (b_0) was collected, so a cross-pipeline dependency +// involving a later stage (b_1, b_2, …) was missed and the boundary barriers +// were wrongly eliminated. +// +// Layout: +// Loop A (2 stages): a_0 tt.store (no LDS) +// a_1 ttg.local_load (READS LDS) +// Flat B (2 stages): b_0 tt.store (no LDS) +// b_1 ttg.local_store (WRITES the same LDS buffer) +// +// A's circular analysis finds no intersecting pair (a_1's read does not +// conflict with itself or with a_0), so all of A's bars are non-LOCAL. +// In particular the wrap-around bars[0] is FALSE, so it cannot seed +// coverage for the merged boundary slot. +// +// Cross-pipeline dep: (a_1, b_1) WAR at merged distance 2, barrierLoc = K = 2 +// (the boundary). No other slot on the path from a_1 to b_1 is LOCAL, so +// the analysis must flag the boundary and preserve the post-loop +// cond_barrier, prelude ttg.barrier local, and phase-shift cond_barrier. 
+// +// Before the collectNextPipelineClusters fix, the boundary barriers would +// have been removed (false negative) because only b_0 was collected, making +// b_1 invisible to the cross-pipeline analysis. + +#crossb_blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#crossb_mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [2, 4], instrShape = [16, 16, 16], isTransposed = true}> +#crossb_shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> +#crossb_smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func @cross_pipeline_dep_in_b1( + %lb: i32, %ub: i32, %step: i32, + %acc: tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #crossb_mma, kWidth = 4}>>, + %ptr: tensor<256x64x!tt.ptr, #crossb_blocked>, + %gptr: !tt.ptr, + %dst: tensor<256x16x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #crossb_mma, kWidth = 4}>>) { + + %smem = ttg.local_alloc : () -> !ttg.memdesc<256x64xf16, #crossb_shared, #crossb_smem, mutable> + %v0 = arith.constant 0.0 : f32 + %v1 = arith.constant 1.0 : f32 + + // Loop A: stage 0 no LDS, stage 1 reads %smem. The loaded value is + // threaded through iter_args + used after the loop so the execute_region + // (and its ttg.local_load) survives DCE before the redundant-barrier pass. 
+ %final = scf.for %i = %lb to %ub step %step + iter_args(%cur = %acc) + -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #crossb_mma, kWidth = 4}>> : i32 { + scf.execute_region no_inline { + tt.store %gptr, %v0 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "a_compute"} + + %ld = scf.execute_region -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #crossb_mma, kWidth = 4}>> no_inline { + %sub = ttg.memdesc_subslice %smem[0, 0] : !ttg.memdesc<256x64xf16, #crossb_shared, #crossb_smem, mutable> -> !ttg.memdesc<256x16xf16, #crossb_shared, #crossb_smem, mutable, 256x64> + %v = ttg.local_load %sub : !ttg.memdesc<256x16xf16, #crossb_shared, #crossb_smem, mutable, 256x64> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #crossb_mma, kWidth = 4}>> + scf.yield %v : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #crossb_mma, kWidth = 4}>> + } {triton.warp_pipeline.stage = "a_load"} + + scf.yield %ld : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #crossb_mma, kWidth = 4}>> + } {triton.warp_pipeline.pipelined_for} + + // Flat B: b_0 no LDS (masks the bug), b_1 writes the same %smem (dep). + scf.execute_region no_inline { + tt.store %gptr, %v1 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "b_nolds"} + + scf.execute_region no_inline { + %data = tt.load %ptr : tensor<256x64x!tt.ptr, #crossb_blocked> + ttg.local_store %data, %smem : tensor<256x64xf16, #crossb_blocked> -> !ttg.memdesc<256x64xf16, #crossb_shared, #crossb_smem, mutable> + scf.yield + } {triton.warp_pipeline.stage = "b_lds"} + + // Use %final after flat B so the loop's iter_arg result is observed and + // the local_load execute_region survives DCE — without breaking the + // back-to-back boundary between loop A and flat B. 
+ tt.store %dst, %final : tensor<256x16x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #crossb_mma, kWidth = 4}>> + + ttg.local_dealloc %smem : !ttg.memdesc<256x64xf16, #crossb_shared, #crossb_smem, mutable> + tt.return + } +} + +// CHECK-LABEL: tt.func @cross_pipeline_dep_in_b1 +// Pre-barrier and phase shift for loop A. +// CHECK: ttg.barrier local +// CHECK: amdg.cond_barrier +// CHECK: scf.for +// Loop body: a_0 (tt.store), internal s_barrier, a_1 (local_load). +// CHECK: tt.store +// CHECK: rocdl.s.barrier +// CHECK: ttg.local_load +// Boundary barriers between loop A and flat B are KEPT because (a_1, b_1) +// is a cross-pipeline WAR dep on %smem and no LOCAL barrier on the path +// a_1 → boundary → b_0 → b_1 covers it (A's wrap-around is not LOCAL). +// CHECK: amdg.cond_barrier +// CHECK: ttg.barrier local +// CHECK: amdg.cond_barrier +// Flat B stages: b_0 (tt.store), internal s_barrier, b_1 (local_store). +// CHECK: tt.store +// CHECK: ttg.local_store +// Reconverge cond_barrier for flat B. +// CHECK: amdg.cond_barrier +// CHECK: tt.return + +// ----- + +// ---- Back-to-back: cross-pipeline dep where placement falls inside A ---- +// +// Companion to @cross_pipeline_dep_in_b1. Where that test puts the +// uncovered cross-pipeline pair at distance == 2 (so the placement falls at +// boundary slot K), this one engineers a pair at distance == K from `a_0` +// to `b_0` so the placement falls at slot K-1 — *inside* A's body. +// isCrossPipelineSafe must still flag this as unsafe: the explicit +// cross-pipeline-pair sweep walks (src, barrierLoc] for coverage and finds +// no LOCAL slot in A (loopBars[1..K-1] are all false). +// +// Layout: +// Loop A (2 stages): a_0 ttg.local_load (READS LDS) +// a_1 tt.store (no LDS) +// Flat B (2 stages): b_0 ttg.local_store (WRITES the same LDS buffer) +// b_1 tt.store (no LDS) +// +// A's circular analysis: a_0 read-read with itself, no intersection with a_1; +// loopBars = [false, false] and the wrap-around is non-LOCAL. 
+// +// Cross-pipeline dep (a_0, b_0) WAR on %smem at merged distance K=2 → +// barrierLoc = dst-1 = 1. isCovered(0, 1) walks slot 1 (loopBars[1]=false) +// and returns false; the pair is intersected → unsafe. Boundary barriers +// must be kept. + +#crossa_blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#crossa_mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [2, 4], instrShape = [16, 16, 16], isTransposed = true}> +#crossa_shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> +#crossa_smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func @cross_pipeline_dep_in_a0( + %lb: i32, %ub: i32, %step: i32, + %acc: tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #crossa_mma, kWidth = 4}>>, + %ptr: tensor<256x64x!tt.ptr, #crossa_blocked>, + %gptr: !tt.ptr, + %dst: tensor<256x16x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #crossa_mma, kWidth = 4}>>) { + + %smem = ttg.local_alloc : () -> !ttg.memdesc<256x64xf16, #crossa_shared, #crossa_smem, mutable> + %v0 = arith.constant 0.0 : f32 + + // Loop A: stage 0 reads %smem (threaded through iter_args so the + // local_load survives DCE), stage 1 no LDS. 
+ %final = scf.for %i = %lb to %ub step %step + iter_args(%cur = %acc) + -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #crossa_mma, kWidth = 4}>> : i32 { + %ld = scf.execute_region -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #crossa_mma, kWidth = 4}>> no_inline { + %sub = ttg.memdesc_subslice %smem[0, 0] : !ttg.memdesc<256x64xf16, #crossa_shared, #crossa_smem, mutable> -> !ttg.memdesc<256x16xf16, #crossa_shared, #crossa_smem, mutable, 256x64> + %v = ttg.local_load %sub : !ttg.memdesc<256x16xf16, #crossa_shared, #crossa_smem, mutable, 256x64> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #crossa_mma, kWidth = 4}>> + scf.yield %v : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #crossa_mma, kWidth = 4}>> + } {triton.warp_pipeline.stage = "a_load"} + + scf.execute_region no_inline { + tt.store %gptr, %v0 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "a_compute"} + + scf.yield %ld : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #crossa_mma, kWidth = 4}>> + } {triton.warp_pipeline.pipelined_for} + + // Flat B: b_0 writes %smem (the dep), b_1 no LDS. + scf.execute_region no_inline { + %data = tt.load %ptr : tensor<256x64x!tt.ptr, #crossa_blocked> + ttg.local_store %data, %smem : tensor<256x64xf16, #crossa_blocked> -> !ttg.memdesc<256x64xf16, #crossa_shared, #crossa_smem, mutable> + scf.yield + } {triton.warp_pipeline.stage = "b_lds"} + + scf.execute_region no_inline { + tt.store %gptr, %v0 : !tt.ptr + scf.yield + } {triton.warp_pipeline.stage = "b_nolds"} + + tt.store %dst, %final : tensor<256x16x!tt.ptr, #ttg.dot_op<{opIdx = 0, parent = #crossa_mma, kWidth = 4}>> + + ttg.local_dealloc %smem : !ttg.memdesc<256x64xf16, #crossa_shared, #crossa_smem, mutable> + tt.return + } +} + +// CHECK-LABEL: tt.func @cross_pipeline_dep_in_a0 +// Pre-barrier and phase shift for loop A. +// CHECK: ttg.barrier local +// CHECK: amdg.cond_barrier +// CHECK: scf.for +// Loop body: a_0 (local_load), internal s_barrier, a_1 (tt.store). 
+// CHECK: ttg.local_load +// CHECK: rocdl.s.barrier +// CHECK: tt.store +// Boundary barriers between loop A and flat B must be KEPT. The (a_0, b_0) +// WAR on %smem at merged distance K places at slot K-1 (inside A); the +// cross-pipeline-pair sweep finds no LOCAL slot in (0, K-1] (loopBars[1] is +// false because A's intra-cluster barrier is just s_barrier) and reports +// the pair as uncovered. +// CHECK: amdg.cond_barrier +// CHECK: ttg.barrier local +// CHECK: amdg.cond_barrier +// Flat B stages: b_0 (local_store), internal s_barrier, b_1 (tt.store). +// CHECK: ttg.local_store +// CHECK: tt.store +// Reconverge cond_barrier for flat B. +// CHECK: amdg.cond_barrier +// CHECK: tt.return + +// ----- + +// ---- LDS effect nested inside scf.if must be detected ---- +// +// Stage 0 wraps its ttg.local_store inside an scf.if, so the effect is not +// visible on the top-level op. buildBlockInfoFromBlock must walk +// recursively to discover it; otherwise the cross-cluster RAW (stage0 +// writes, stage1 reads) is missed and the cluster barriers degrade from +// ttg.barrier local to plain rocdl.s.barrier — leaving the LDS race +// uncovered. 
+ +#nest_blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#nest_mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [2, 4], instrShape = [16, 16, 16], isTransposed = true}> +#nest_dot = #ttg.dot_op<{opIdx = 0, parent = #nest_mma, kWidth = 4}> +#nest_shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> +#nest_smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func @nested_lds_effect_in_if( + %lb: i32, %ub: i32, %step: i32, + %cond: i1, + %acc: tensor<256x16xf16, #nest_dot>, + %ptr: tensor<256x64x!tt.ptr, #nest_blocked>) { + + %smem = ttg.local_alloc : () -> !ttg.memdesc<256x64xf16, #nest_shared, #nest_smem, mutable> + + %r:2 = scf.for %i = %lb to %ub step %step + iter_args(%a = %acc, %s = %smem) + -> (tensor<256x16xf16, #nest_dot>, !ttg.memdesc<256x64xf16, #nest_shared, #nest_smem, mutable>) : i32 { + + // Stage 0: conditionally writes LDS via scf.if. The ttg.local_store + // sits inside the if body, so a flat scan of the cluster body would + // miss it. + %st = scf.execute_region -> !ttg.memdesc<256x64xf16, #nest_shared, #nest_smem, mutable> no_inline { + scf.if %cond { + %data = tt.load %ptr : tensor<256x64x!tt.ptr, #nest_blocked> + ttg.local_store %data, %s : tensor<256x64xf16, #nest_blocked> -> !ttg.memdesc<256x64xf16, #nest_shared, #nest_smem, mutable> + } + scf.yield %s : !ttg.memdesc<256x64xf16, #nest_shared, #nest_smem, mutable> + } {triton.warp_pipeline.stage = "cond_store"} + + // Stage 1: reads LDS — RAW with the conditional write in stage 0. 
+ %ld = scf.execute_region -> tensor<256x16xf16, #nest_dot> no_inline { + %sub = ttg.memdesc_subslice %s[0, 0] : !ttg.memdesc<256x64xf16, #nest_shared, #nest_smem, mutable> -> !ttg.memdesc<256x16xf16, #nest_shared, #nest_smem, mutable, 256x64> + %v = ttg.local_load %sub : !ttg.memdesc<256x16xf16, #nest_shared, #nest_smem, mutable, 256x64> -> tensor<256x16xf16, #nest_dot> + scf.yield %v : tensor<256x16xf16, #nest_dot> + } {triton.warp_pipeline.stage = "lds_load"} + + scf.yield %ld, %s : tensor<256x16xf16, #nest_dot>, !ttg.memdesc<256x64xf16, #nest_shared, #nest_smem, mutable> + } {triton.warp_pipeline.pipelined_for} + + ttg.local_dealloc %smem : !ttg.memdesc<256x64xf16, #nest_shared, #nest_smem, mutable> + tt.return + } +} + +// CHECK-LABEL: tt.func @nested_lds_effect_in_if +// CHECK: scf.for +// Stage 0 with the nested scf.if + local_store. +// CHECK: scf.if +// CHECK: ttg.local_store +// Cluster barrier between stage 0 and stage 1 is LOCAL (nested write seen). +// CHECK: rocdl.sched.barrier +// CHECK-NEXT: ttg.barrier local +// CHECK-NEXT: rocdl.sched.barrier +// Stage 1 reads LDS. +// CHECK: ttg.local_load +// Wrap-around barrier is also LOCAL (stage1 read vs stage0 write next iter). +// CHECK: rocdl.sched.barrier +// CHECK-NEXT: ttg.barrier local +// CHECK-NEXT: rocdl.sched.barrier +// CHECK: scf.yield diff --git a/test/TritonGPU/amd/amd-warp-pipeline-invalid.mlir b/test/TritonGPU/amd/amd-warp-pipeline-invalid.mlir new file mode 100644 index 000000000000..65c4bdd00512 --- /dev/null +++ b/test/TritonGPU/amd/amd-warp-pipeline-invalid.mlir @@ -0,0 +1,117 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-warp-pipeline -verify-diagnostics + +// Loops are not allowed inside a warp_pipeline_stage region; see isLoopOp +// in WarpPipeliner.cpp for the rationale (no scheduling benefit, opaque to +// MemoryEffectOpInterface, also covers the "no nested warp pipelines" +// rule). 
Both the loop-form (createPipeline) and flat-form +// (createFlatPipeline) must reject loops between borders. + +// ---- Loop-form: scf.for inside a stage ---- + +tt.func @loop_form_for_in_cluster(%n: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + scf.for %i = %c0 to %n step %c1 { + %a = arith.addi %i, %c1 : index + rocdl.sched.barrier 0 {triton.warp_pipeline.border = "stage"} + + // expected-error @+1 {{loop op cannot appear inside a warp_pipeline_stage region}} + scf.for %j = %c0 to %n step %c1 { + scf.yield + } + + rocdl.sched.barrier 0 {triton.warp_pipeline.border = "stage"} + %b = arith.addi %a, %i : index + + scf.yield + } + + tt.return +} + +// ----- + +// ---- Loop-form: scf.while inside a stage ---- + +tt.func @loop_form_while_in_cluster(%n: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + scf.for %i = %c0 to %n step %c1 { + %a = arith.addi %i, %c1 : index + rocdl.sched.barrier 0 {triton.warp_pipeline.border = "stage"} + + // expected-error @+1 {{loop op cannot appear inside a warp_pipeline_stage region}} + scf.while (%w = %c0) : (index) -> index { + %cond = arith.cmpi slt, %w, %n : index + scf.condition(%cond) %w : index + } do { + ^bb0(%w: index): + %wn = arith.addi %w, %c1 : index + scf.yield %wn : index + } + + rocdl.sched.barrier 0 {triton.warp_pipeline.border = "stage"} + %b = arith.addi %a, %i : index + + scf.yield + } + + tt.return +} + +// ----- + +// ---- Loop-form: nested warp-pipelined scf.for is still a loop ---- +// +// Even an already-pipelined inner loop is rejected: nesting warp pipelines +// is a hard constraint, and the loop-op check enforces it for free. 
+ +tt.func @loop_form_nested_pipelined_for(%n: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + scf.for %i = %c0 to %n step %c1 { + %a = arith.addi %i, %c1 : index + rocdl.sched.barrier 0 {triton.warp_pipeline.border = "stage"} + + // expected-error @+1 {{loop op cannot appear inside a warp_pipeline_stage region}} + scf.for %j = %c0 to %n step %c1 { + scf.execute_region { + scf.yield + } {triton.warp_pipeline.stage = "inner"} + scf.yield + } {triton.warp_pipeline.pipelined_for} + + rocdl.sched.barrier 0 {triton.warp_pipeline.border = "stage"} + %b = arith.addi %a, %i : index + + scf.yield + } + + tt.return +} + +// ----- + +// ---- Flat-form: scf.for between flat borders ---- + +tt.func @flat_form_for_in_cluster(%n: index, %ptr: !tt.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %v0 = arith.constant 0.0 : f32 + + tt.store %ptr, %v0 : !tt.ptr + rocdl.sched.barrier 0 {triton.warp_pipeline.border = "stage0"} + + // expected-error @+1 {{loop op cannot appear inside a warp_pipeline_stage region}} + scf.for %j = %c0 to %n step %c1 { + scf.yield + } + + rocdl.sched.barrier 0 {triton.warp_pipeline.border = "stage1"} + tt.store %ptr, %v0 : !tt.ptr + + tt.return +} diff --git a/test/TritonGPU/amd/amd-warp-pipeline.mlir b/test/TritonGPU/amd/amd-warp-pipeline.mlir index 7607c1f612f7..e624cbd38daa 100644 --- a/test/TritonGPU/amd/amd-warp-pipeline.mlir +++ b/test/TritonGPU/amd/amd-warp-pipeline.mlir @@ -143,6 +143,66 @@ tt.func public @triple_buf_two_stages(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: tt.return } +// -- Flat (unrolled) pipeline: borders outside scf.for ---- +// +// Simulates a static_range epilogue that was unrolled at the Python level +// following a regular pipelined main loop. The flat backward walk must stop +// at the prior scf.for (loops are disallowed inside a stage) so the main +// loop is not absorbed into stage 0. 
+ +tt.func @flat_pipeline_example(%n: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // Pipelined main loop: gets the pipelined_for attribute and acts as a + // hard boundary for the flat epilogue's backward walk. + scf.for %i = %c0 to %n step %c1 { + %x = arith.addi %i, %c1 : index + rocdl.sched.barrier 0 {triton.warp_pipeline.border = "load"} + %y = arith.muli %x, %c1 : index + scf.yield + } + + // Stage 0 (ops before the first epilogue border) + %a = arith.addi %c0, %c1 : index + %a2 = arith.muli %a, %c1 : index + + rocdl.sched.barrier 0 {triton.warp_pipeline.border = "stage0_epi", triton.warp_pipeline.priority = 1 : i32} + + // Stage 1 + %b = arith.addi %a2, %c0 : index + %b2 = arith.muli %b, %c1 : index + + rocdl.sched.barrier 0 {triton.warp_pipeline.border = "stage1_epi", triton.warp_pipeline.priority = 0 : i32} + + tt.return +} + +// CHECK-LABEL: tt.func @flat_pipeline_example( +// Pipelined main loop forms its own warp pipeline (one execute_region per +// stage, then the pipelined_for attribute on the loop). +// CHECK: scf.for +// CHECK: scf.execute_region +// CHECK: scf.execute_region +// CHECK: triton.warp_pipeline.pipelined_for +// Flat epilogue execute_regions created from the borders. Crucially, they +// must NOT absorb the pipelined main loop above. +// CHECK: scf.execute_region +// CHECK: arith.addi +// CHECK: arith.muli +// CHECK: scf.yield +// CHECK: triton.warp_pipeline.priority = 1 +// CHECK-SAME: triton.warp_pipeline.stage = "stage0_epi" +// CHECK: scf.execute_region +// CHECK: arith.addi +// CHECK: arith.muli +// CHECK: scf.yield +// CHECK: triton.warp_pipeline.priority = 0 +// CHECK-SAME: triton.warp_pipeline.stage = "stage1_epi" +// Border markers must be erased: +// CHECK-NOT: rocdl.sched.barrier +// CHECK: tt.return + // -- Post-unroll IV remap is sunk past ignorable ops (FA-kernel pattern) ---- // The FA kernel body begins with async_wait. 
After MLIR loop unrolling, IV // remap ops (arith.addi/muli) land between the last border of iter N and the diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertWarpPipeline.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertWarpPipeline.cpp index 4fbd925fd6b1..59624ddde3d9 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertWarpPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertWarpPipeline.cpp @@ -53,36 +53,192 @@ namespace mlir::triton { namespace { -// construct a virtual block from each pipeline cluster -// block contains its buffer R/W information. +// Construct a virtual block describing a pipeline cluster's buffer R/W set. +// Walks recursively so that LDS effects inside nested non-loop regions +// (scf.if / tt.reduce / tt.scan / etc.) are accounted for. Loops (scf.for / +// scf.while) cannot legally appear inside a cluster, so this walk never has +// to reason about iteration-multiplied effects. static BlockInfo buildBlockInfoFromBlock(Block *block, Allocation *allocation) { - BlockInfo info; // running fact for this block - for (Operation &opRef : *block) { - Operation *op = &opRef; - if (auto mei = dyn_cast(op)) { - SmallVector> effs; - mei.getEffects(effs); - for (auto &eff : effs) { - if (Value v = eff.getValue()) { - for (auto bufId : allocation->getAllBufferIdsWithAliases(v)) { - if (bufId == Allocation::InvalidBufferId) - continue; - auto interval = allocation->getAllocatedInterval(bufId); - auto slice = AllocationSlice(v, interval, bufId); - if (isa(eff.getEffect())) - info.syncWriteSlices[slice].insert(op); - else if (isa(eff.getEffect())) - info.syncReadSlices[slice].insert(op); - } - } + BlockInfo info; + block->walk([&](MemoryEffectOpInterface mei) { + Operation *op = mei.getOperation(); + SmallVector> effs; + mei.getEffects(effs); + for (auto &eff : effs) { + Value v = eff.getValue(); + if (!v) + continue; + for (auto bufId : allocation->getAllBufferIdsWithAliases(v)) { + if (bufId == Allocation::InvalidBufferId) + 
continue; + auto interval = allocation->getAllocatedInterval(bufId); + auto slice = AllocationSlice(v, interval, bufId); + if (isa(eff.getEffect())) + info.syncWriteSlices[slice].insert(op); + else if (isa(eff.getEffect())) + info.syncReadSlices[slice].insert(op); } } - } + }); return info; } -static void emitClusterBarrier(PatternRewriter &r, Location loc, - bool needLocal) { +// Pre-existing barrier/wait ops that may legally appear at cluster +// boundaries (between stages or before/after a pipeline). Mirrors +// isPipelineIgnorable in WarpPipeliner.cpp plus the ROCDL-lowered forms that +// can appear after intermediate passes. +static bool isWarpPipelineIgnorableBarrier(Operation *op) { + return isa(op); +} + +// True if `exec` is a stage created by the warp-pipeline frontend. +static bool isPipelineStage(scf::ExecuteRegionOp exec) { + return exec && exec->hasAttr("triton.warp_pipeline.stage"); +} + +// dyn_cast + warp_pipeline.stage marker check. +// Returns null when `op` is not a pipeline stage. +static scf::ExecuteRegionOp getPipelineStage(Operation *op) { + auto exec = dyn_cast_or_null(op); + return isPipelineStage(exec) ? exec : nullptr; +} + +// Validate the body of a `pipelined_for` loop. After WarpPipeliner the body +// must consist of: a sequence of pipeline-stage execute_regions, optional +// pre-existing barrier/wait ops between (or before/after) those stages, and +// a terminator scf.yield -- nothing else. Emits an error and returns +// failure on any deviation. Side-effect free: leaves the IR untouched so +// callers can fail fast before mutating anything. 
+static LogicalResult validatePipelinedForBody(scf::ForOp forOp) { + std::map existingBarrierMap; + int numClusters = 0; + for (auto &op : *forOp.getBody()) { + if (auto exeOp = dyn_cast(op)) { + if (!isPipelineStage(exeOp)) + return op.emitError( + "non-warp-pipeline scf.execute_region inside pipelined_for body"); + ++numClusters; + } else if (isWarpPipelineIgnorableBarrier(&op)) { + if (existingBarrierMap.count(numClusters)) + return op.emitError("multiple pre-existing barriers between pipeline " + "stages; insert a dummy stage instead"); + existingBarrierMap[numClusters] = &op; + } else if (isa(op)) { + continue; + } else { + return op.emitError("unexpected op inside pipelined_for body; only " + "warp-pipeline stages and barrier/wait ops are " + "allowed"); + } + } + if (numClusters < 2) + return forOp.emitError( + "pipelined_for body must contain at least two pipeline stages"); + if (existingBarrierMap.count(0) && existingBarrierMap.count(numClusters)) + return forOp.emitError("pipelined_for body has both top-of-loop and " + "bottom-of-loop pre-existing barriers"); + return success(); +} + +// Pairwise LDS-dependency analysis between pipeline clusters. +// +// `circular` selects the index topology used by the analysis: +// * true — the schedule wraps modulo N. Used by loop pipelines (scf.for) +// where the wrap-around represents iter-i feeding iter-(i+1). +// * false — the schedule is straight-line, indices stay in [0, N). Used +// by flat (unrolled) pipelines. +// The rest of this comment uses "circular" / "linear" exclusively, since +// the analysis only cares about topology and not about the source IR kind. +// +// LAYOUT +// ------ +// cluster: c0 c1 c2 ... c_{N-1} +// bars: b0 b1 b2 b3 ... b_{N-1} (b_i sits +// before c_i) +// +// * circular: b0 is the wrap-around barrier inside the loop body — +// sitting between c_{N-1} of one iteration and c0 of the next. 
+// * linear: b0 has no physical slot (no barrier exists before the first +// cluster), and the schedule never wraps around. +// +// GOAL +// ---- +// For every ordered pair (src, dst) whose LDS effects intersect, guarantee +// that the schedule has at least one LOCAL (ds_wait + s_barrier) barrier +// somewhere on the path src → dst. If no existing slot on the path is +// LOCAL, mark one as LOCAL. +// +// PLACEMENT CHOICE +// ---------------- +// When forced to place a LOCAL barrier we pick: +// dist == 1 → bars[dst] (the only slot between src and dst) +// dist > 1 → bars[dst - 1] (the second-rightmost slot on the path) +// The `dst - 1` choice is somewhat arbitrary — any slot in (src, dst] is +// correct for memory ordering — and is preserved here to match upstream +// behavior and existing tests. +// +// COVERAGE CHECK +// -------------- +// A pair is "covered" if any slot in (src, barrierLoc] is already LOCAL. +// Note that bars[dst] is intentionally NOT consulted when dist > 1; this +// mirrors the placement choice (we never look at, nor place into, the +// slot owned by the adjacent (dst-1, dst) pair). +// +// ITERATION ORDER +// --------------- +// We sweep `dist` from 1 up to `maxDist`: +// * circular: maxDist = N. dist == N is the self-loop (src == dst), +// which captures iter-i write vs iter-(i+1) read across the +// wrap-around when only one cluster touches the buffer. +// * linear: maxDist = N - 1. No wrap. +// Walking by increasing distance ensures the shorter-range LOCAL +// barriers we just placed are visible when checking longer-range pairs, +// skipping many redundant placements. +static void analyzePipelineDependencies(ArrayRef clusterInfo, + SmallVectorImpl &bars, + Allocation *allocation, bool circular) { + const int N = clusterInfo.size(); + const int maxDist = circular ? N : N - 1; + + // Modular wrap; a no-op in linear mode where indices stay in range. + auto wrap = [&](int i) -> int { return circular ? 
(i % N + N) % N : i; }; + + // Returns true if any barrier slot in (src, stop] is already LOCAL. + // The walk starts at `src + 1` and advances one slot at a time, wrapping + // modulo N in circular mode; it terminates as soon as it finds a LOCAL + // slot or reaches `stop`. + auto isCovered = [&](int src, int stop) -> bool { + for (int i = src + 1;; i++) { + const int idx = wrap(i); + if (bars[idx]) + return true; + if (idx == stop) + return false; + } + }; + + for (int dist = 1; dist <= maxDist; dist++) { + // In linear mode, src + dist must stay in range. In circular mode all + // src values are valid and dst wraps modulo N. + const int srcEnd = circular ? N : N - dist; + for (int src = 0; src < srcEnd; src++) { + const int dst = wrap(src + dist); + const int barrierLoc = (dist == 1) ? dst : wrap(dst - 1); + if (isCovered(src, barrierLoc)) + continue; + if (!clusterInfo[src].isIntersected( + clusterInfo[dst], mlir::triton::AMD::membarFilter, allocation)) + continue; + bars[barrierLoc] = true; + LDBG("cluster " << src << " need fence to " << dst + << " placing barrier at " << barrierLoc); + } + } +} + +static void emitClusterBarrier(OpBuilder &r, Location loc, bool needLocal) { ROCDL::SchedBarrier::create(r, loc, 0); if (needLocal) mlir::triton::gpu::BarrierOp::create(r, loc, triton::gpu::AddrSpace::Local); @@ -91,7 +247,7 @@ static void emitClusterBarrier(PatternRewriter &r, Location loc, ROCDL::SchedBarrier::create(r, loc, 0); } -static void emitClusterPriority(PatternRewriter &r, Location loc, +static void emitClusterPriority(OpBuilder &r, Location loc, Operation *clusterOp, bool anyHasPriority) { if (auto intAttr = clusterOp->getAttrOfType( "triton.warp_pipeline.priority")) { @@ -102,6 +258,54 @@ static void emitClusterPriority(PatternRewriter &r, Location loc, } } +// Wrap a pre-existing barrier op (e.g. async_wait) with sched_barriers so the +// backend scheduler cannot move ops across it, and emit the cluster's +// priority just before the barrier. 
Used in place of inserting a fresh +// cluster barrier when one already exists at the cluster boundary. +static void wrapExistingBarrier(OpBuilder &b, Location loc, + Operation *clusterOp, + Operation *existingBarrier, + bool anyHasPriority) { + b.setInsertionPoint(existingBarrier); + emitClusterPriority(b, loc, clusterOp, anyHasPriority); + ROCDL::SchedBarrier::create(b, loc, 0); + b.setInsertionPointAfter(existingBarrier); + ROCDL::SchedBarrier::create(b, loc, 0); +} + +// Emit pre-barrier, thread-ID partitioning, and phase-shift cond_barrier. +// Returns warpLow (for reconverge) and warpHigh (consumed by phase shift). +static std::pair +emitPipelinePrelude(OpBuilder &b, Location loc, int threadsPerPipelineGroup) { + // Flush any pending shared-memory (LDS) dependencies before entering the + // warp-pipelined region. Without this barrier ModuleMembarAnalysis may + // later insert a barrier inside the first pipeline stage, which would + // break the carefully tuned pipeline timing. + mlir::triton::gpu::BarrierOp::create(b, loc, triton::gpu::AddrSpace::Local); + + auto i32ty = b.getIntegerType(32); + auto workIDX = ROCDL::ThreadIdXOp::create(b, loc, i32ty); + auto constZero = arith::ConstantIntOp::create(b, loc, 0, 32); + auto constWarpSize = + arith::ConstantIntOp::create(b, loc, threadsPerPipelineGroup, 32); + auto warpIDX = arith::DivSIOp::create(b, loc, workIDX, constWarpSize); + auto warpLow = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, + warpIDX, constZero); + auto warpHigh = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, + warpIDX, constZero); + mlir::triton::amdgpu::CondBarrierOp::create(b, loc, warpHigh); + + return {warpLow, warpHigh}; +} + +// Emit priority reset and reconverge cond_barrier after a pipeline. 
+static void emitPipelinePostlude(OpBuilder &b, Location loc, + bool anyHasPriority, Value warpLow) { + if (anyHasPriority) + ROCDL::SetPrioOp::create(b, loc, 0); + mlir::triton::amdgpu::CondBarrierOp::create(b, loc, warpLow); +} + class ConvertPipelinedForPattern : public OpRewritePattern { public: ConvertPipelinedForPattern(MLIRContext *ctx, ModuleAllocation &moduleAlloc, @@ -115,53 +319,32 @@ class ConvertPipelinedForPattern : public OpRewritePattern { // Only handle loops that the frontend marked with pipelined_for. if (!forOp->getAttr("triton.warp_pipeline.pipelined_for")) return rewriter.notifyMatchFailure(forOp, "no pipelined_for"); - forOp->removeAttr("triton.warp_pipeline.pipelined_for"); - // Look up allocation info as in original pass. + // Look up allocation info as in original pass. Bail out *before* we + // mutate the IR so a soft match-failure cannot leave a half-converted + // loop behind (no marker, no barriers). auto func = forOp->getParentOfType(); Allocation *allocation = moduleAllocation.getFuncData(func); if (!allocation) return rewriter.notifyMatchFailure(forOp, "no Allocation for function"); - if (failed(emitPipelinedFor(rewriter, forOp.getLoc(), forOp, allocation, - threadsPerPipelineGroup))) - return failure(); - + forOp->removeAttr("triton.warp_pipeline.pipelined_for"); + emitPipelinedFor(rewriter, forOp.getLoc(), forOp, allocation, + threadsPerPipelineGroup); return success(); } private: - LogicalResult emitPipelinedFor(PatternRewriter &b, Location loc, - scf::ForOp forOp, Allocation *allocation, - int threadsPerPipelineGroup) const { - // 1. Insert conditional branch first, + void emitPipelinedFor(PatternRewriter &b, Location loc, scf::ForOp forOp, + Allocation *allocation, + int threadsPerPipelineGroup) const { + // 1. Pre-barrier, thread partitioning, and phase shift. b.setInsertionPoint(forOp); - // Set barrier before starting the loop. 
This resolves any outstanding - // synchronization before beginning the specialized asymmetric - // synchronization. - mlir::triton::gpu::BarrierOp::create(b, loc, triton::gpu::AddrSpace::Local); - - // Insert condbarrier::second_half before starting the loop - // FIXME : correctly calculate numbers per the arch - auto i32ty = b.getIntegerType(32); - auto workIDX = ROCDL::ThreadIdXOp::create(b, loc, i32ty); - auto constZero = arith::ConstantIntOp::create(b, loc, 0, 32); - auto constWarpSize = - arith::ConstantIntOp::create(b, loc, threadsPerPipelineGroup, 32); - auto warpIDX = arith::DivSIOp::create(b, loc, workIDX, constWarpSize); - auto warpLow = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, - warpIDX, constZero); - auto warpHigh = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, - warpIDX, constZero); - - mlir::triton::amdgpu::CondBarrierOp::create(b, loc, warpHigh); - - // 2. Collect existing barrier information. - // Scanning the loop body and classifying each consecutive block of - // operations into a pipeline cluster (one cluster per execute_region). - // While doing this, we also detect any pre-existing barriers located - // between clusters. These barriers may come from prefetch patterns, and - // must be preserved, but only at valid cluster boundaries. + auto [warpLow, warpHigh] = + emitPipelinePrelude(b, loc, threadsPerPipelineGroup); + + // 2. Walk the (already-validated) body once to collect clusters and any + // pre-existing inter-cluster barriers (e.g. from prefetch patterns). SmallVector clusterBlocks; SmallVector clusterOps; SmallVector bars; @@ -170,30 +353,14 @@ class ConvertPipelinedForPattern : public OpRewritePattern { for (auto &op : *forOp.getBody()) { if (auto exeOp = dyn_cast(op)) { - // Fail conversion with executeRegion from unkown source. 
- if (exeOp->getAttr("triton.warp_pipeline.stage") == nullptr) - return failure(); exeOp.setNoInline(false); clusterOps.push_back(&op); clusterBlocks.push_back(&exeOp->getRegion(0).front()); bars.push_back(false); - } else if (isa(op)) { - int currCluster = clusterBlocks.size(); - // Reject if multiple barriers appear without an intervening cluster. - // This is functionally valid but may cause unpredictable timing. Users - // should insert a dummy cluster explicitly if a pipeline bubble is - // required. - // Also only allow ops which waits local memory, - // e.g., s_barrier is NOT allowed. - if (existingBarrierMap.find(currCluster) != existingBarrierMap.end()) - return failure(); - existingBarrierMap[currCluster] = &op; - } else if (auto yieldOp = dyn_cast(op)) { + } else if (isWarpPipelineIgnorableBarrier(&op)) { + existingBarrierMap[clusterBlocks.size()] = &op; + } else if (isa(op)) { terminatorOp = &op; - } else { // Fail conversion if any other op found outside of the cluster. - return failure(); } } @@ -212,54 +379,19 @@ class ConvertPipelinedForPattern : public OpRewritePattern { // but sometimes required by memory prefetching pattern. auto topBar = existingBarrierMap.find(0); auto bottomBar = existingBarrierMap.find(numClusters); + bool hasTopBarrier = topBar != existingBarrierMap.end(); if (bottomBar != existingBarrierMap.end()) { - if (topBar != existingBarrierMap.end()) - return failure(); // Unreachable + // validatePipelinedForBody guarantees we cannot have both top and + // bottom barriers, so rotating bottom -> 0 is unambiguous. + assert(!hasTopBarrier && + "validatePipelinedForBody should have rejected this"); existingBarrierMap[0] = bottomBar->second; existingBarrierMap.erase(bottomBar); } - // 3. Performing pairwise dependency analysis between clusters. For each - // src → next pair (with wrap-around), we check whether their memory - // intervals overlap. If so, a fence/barrier must be inserted at the - // boundary cluster (barrierLoc). 
The analysis is expressed as a - // circular traversal so that pipeline stages form a ring. - // • `bars[i] = true` marks that a new cluster barrier must be inserted - // before cluster i. - // • Existing barriers override or satisfy required fences, so we do not - // insert duplicates. - for (int offset = 0; offset < numClusters; offset++) { - for (int src = 0; src < numClusters; src++) { - const int next = (src + 2 + offset) % numClusters; - const int barrierLoc = (src + 1 + offset) % numClusters; - LDBG("Inspecting src:" << src << " to next:" << next); - // Check if any existing barrier sits between src and barrierIdx - auto isSynced = [&]() -> bool { - for (int idx = (src + 1) % numClusters; idx != src; - idx = (idx + 1) % numClusters) { - if (bars[idx]) - return true; - if (idx == barrierLoc) - break; - } - return false; - }; - // Skip if dependency is already resolved. - if (isSynced()) { - LDBG("already synced"); - continue; - } - const bool needFence = clusterInfo[src].isIntersected( - clusterInfo[next], mlir::triton::AMD::membarFilter, allocation); - // insert fence/barrier in front of this cluster - LDBG("need fence?: " << needFence); - if (needFence) { - bars[barrierLoc] = true; - LDBG("cluster " << src << " need fence to " << next - << " placing barrier at " << barrierLoc); - } - } - } + // 3. Circular dependency analysis (wrap-around for loop pipelines). + analyzePipelineDependencies(clusterInfo, bars, allocation, + /*circular=*/true); // 4. Materializing final cluster-scope barriers. For each cluster index: // • If there is a pre-existing barrier at that location, we wrap it with @@ -273,21 +405,26 @@ class ConvertPipelinedForPattern : public OpRewritePattern { // the first cluster barrier must be inserted just before the loop’s // terminator, forming the wrap-around dependency. for (int i = 0; i < numClusters; i++) { + if (i == 0 && !hasTopBarrier) { + // Prime the first iteration's priority. 
The loop-carried cluster-0 + // barrier sits at the bottom of the loop body, so it only controls + // the next iteration. + b.setInsertionPoint(forOp); + emitClusterPriority(b, loc, clusterOps[i], anyHasPriority); + } + if (auto exBar = existingBarrierMap.find(i); exBar != existingBarrierMap.end()) { - auto exBarOp = exBar->second; - b.setInsertionPoint(exBarOp); - emitClusterPriority(b, loc, clusterOps[i], anyHasPriority); - ROCDL::SchedBarrier::create(b, loc, 0); - b.setInsertionPointAfter(exBarOp); - ROCDL::SchedBarrier::create(b, loc, 0); + // FIXME: If bars[i] is true, wrapping a non-LOCAL pre-existing + // barrier is not enough to satisfy LDS ordering. For now we rely on + // the producer to place such barriers only where no local fence is + // needed. + wrapExistingBarrier(b, loc, clusterOps[i], exBar->second, + anyHasPriority); } else { b.setInsertionPoint(clusterOps[i]); // The first one wraps back to the last of the loop - if (i == 0 && topBar == existingBarrierMap.end()) { - // Extra setprio needed before the loop for the first cluster - b.setInsertionPoint(forOp); - emitClusterPriority(b, loc, clusterOps[i], anyHasPriority); + if (i == 0 && !hasTopBarrier) { // inserts just before yield (=End of the loop). b.setInsertionPoint(terminatorOp); } @@ -296,12 +433,9 @@ class ConvertPipelinedForPattern : public OpRewritePattern { } } - // Insert condbarrier and priority reset after the loop. + // 5. Post-loop priority reset and reconverge. b.setInsertionPointAfter(forOp); - if (anyHasPriority) - ROCDL::SetPrioOp::create(b, loc, 0); - mlir::triton::amdgpu::CondBarrierOp::create(b, loc, warpLow); - return success(); + emitPipelinePostlude(b, loc, anyHasPriority, warpLow); } ModuleAllocation &moduleAllocation; @@ -320,7 +454,7 @@ class InlineWarpPipelineExecuteRegionPattern return rewriter.notifyMatchFailure(exec, "explicit no_inline"); // Only inline the stages created by the warp-pipeline frontend. 
- if (!exec->getAttr("triton.warp_pipeline.stage")) + if (!isPipelineStage(exec)) return rewriter.notifyMatchFailure(exec, "not a warp-pipeline stage"); // Make sure this pattern is applied after transforming pipelined forOp @@ -346,6 +480,454 @@ class InlineWarpPipelineExecuteRegionPattern } }; +// Process a flat (non-loop) sequence of warp-pipeline execute_regions. +// Unlike the loop case there is no wrap-around: dependencies are strictly +// linear from the first stage to the last. +// +// Emitted IR: +// ttg.barrier local (pre-barrier) +// +// cond_barrier(warpHigh) (phase shift) +// [s_setprio P0] +// execute_region { stage 0 } +// [s_setprio P1] sched+barrier (cluster barrier) +// execute_region { stage 1 } +// ... +// [s_setprio 0] +// cond_barrier(warpLow) (reconverge) +// +static void emitPipelinedFlat(SmallVector &clusterOps, + Allocation *allocation, + int threadsPerPipelineGroup) { + Location loc = clusterOps.front().getLoc(); + OpBuilder b(clusterOps.front().getContext()); + int numClusters = clusterOps.size(); + + // 1. Pre-barrier and phase shift before the first execute_region. + b.setInsertionPoint(clusterOps.front()); + auto [warpLow, warpHigh] = + emitPipelinePrelude(b, loc, threadsPerPipelineGroup); + + // 2. Collect cluster info. + SmallVector clusterBlocks; + SmallVector bars(numClusters, false); + + for (auto exec : clusterOps) { + exec.setNoInline(false); + clusterBlocks.push_back(&exec->getRegion(0).front()); + } + + SmallVector clusterInfo; + for (auto *cb : clusterBlocks) + clusterInfo.push_back(buildBlockInfoFromBlock(cb, allocation)); + + bool anyHasPriority = llvm::any_of(clusterOps, [](scf::ExecuteRegionOp op) { + return op->hasAttr("triton.warp_pipeline.priority"); + }); + + // 3. Linear dependency analysis (no wrap-around for flat pipelines). + analyzePipelineDependencies(clusterInfo, bars, allocation, + /*circular=*/false); + + // 4. Materialize cluster barriers. 
+ // Cluster 0 gets only its priority (inserted after cond_barrier above). + // Clusters 1..N get priority + cluster barrier, unless a pre-existing + // barrier op (e.g., async_wait) already exists between the clusters — + // in that case, wrap it with sched_barriers instead of adding a new one. + emitClusterPriority(b, loc, clusterOps[0], anyHasPriority); + + for (int i = 1; i < numClusters; i++) { + Operation *existingBarrier = nullptr; + for (Operation *op = clusterOps[i - 1]->getNextNode(); + op && op != clusterOps[i].getOperation(); op = op->getNextNode()) { + if (isWarpPipelineIgnorableBarrier(op)) { + existingBarrier = op; + break; + } + } + + if (existingBarrier) { + // FIXME: If bars[i] is true, wrapping a non-LOCAL pre-existing barrier + // is not enough to satisfy LDS ordering. For now we rely on the + // producer to place such barriers only where no local fence is needed. + wrapExistingBarrier(b, loc, clusterOps[i], existingBarrier, + anyHasPriority); + } else { + b.setInsertionPoint(clusterOps[i]); + emitClusterPriority(b, loc, clusterOps[i], anyHasPriority); + emitClusterBarrier(b, loc, /*needLocal=*/bars[i]); + } + } + + // 5. Post-sequence reconverge. + b.setInsertionPointAfter(clusterOps.back()); + emitPipelinePostlude(b, loc, anyHasPriority, warpLow); +} + +// Walk the module for flat warp-pipeline execute_region sequences +// (produced by WarpPipeliner::createFlatPipeline) and emit phase-shift +// barriers around them. +static void processUnrolledPipelineRegions(ModuleOp m, + ModuleAllocation &moduleAllocation, + int threadsPerPipelineGroup) { + m.walk([&](triton::FuncOp funcOp) { + Allocation *allocation = moduleAllocation.getFuncData(funcOp); + if (!allocation) + return; + + // NOTE: We only iterate the function's top-level blocks; flat-pipeline + // execute_regions inside nested non-loop regions (e.g. scf.if bodies) + // are not collected. WarpPipeliner's flat-pipeline frontend has the + // same scope, so the two stay in sync. 
+ for (Block &block : funcOp.getBody()) { + // Collect contiguous sequences of flat warp-pipeline execute_regions, + // splitting at any non-ignorable, non-pipeline op. + SmallVector> sequences; + SmallVector current; + + for (auto &op : block) { + if (auto exec = getPipelineStage(&op)) { + current.push_back(exec); + continue; + } + if (isWarpPipelineIgnorableBarrier(&op)) + continue; + if (!current.empty()) { + sequences.push_back(std::move(current)); + current.clear(); + } + } + if (!current.empty()) + sequences.push_back(std::move(current)); + + for (auto &seq : sequences) { + if (seq.size() < 2) + continue; + LDBG("processing flat pipeline with " << seq.size() << " stages"); + emitPipelinedFlat(seq, allocation, threadsPerPipelineGroup); + } + } + }); +} + +// Return true if `op` is intra-pipeline glue between two clusters — the +// sequence emitted by emitClusterBarrier/emitClusterPriority and any +// pre-existing barrier op that emitPipelinedFlat wraps with sched_barriers. +// Defined in terms of isWarpPipelineIgnorableBarrier so the two sets stay in +// sync; the extra cases below are the ones we emit ourselves. +static bool isIntraPipelineGlue(Operation *op) { + return isWarpPipelineIgnorableBarrier(op) || + isa(op); +} + +// Walk backward from `exec` past `sched_barrier` / `s_setprio` and check +// whether the first non-glue op is a LOCAL `triton::gpu::BarrierOp`. +// Any other barrier kind (s_barrier, async_wait, …) is treated as +// non-LOCAL for the purposes of LDS-dependency coverage. +static bool hasLocalBarrierBefore(Operation *exec) { + for (Operation *scan = exec->getPrevNode(); scan; + scan = scan->getPrevNode()) { + if (isa(scan)) + continue; + if (auto barrier = dyn_cast(scan)) + return barrier.hasLocal(); + return false; + } + return false; +} + +// Collect execute_region clusters and their materialized barrier flags from +// a converted pipelined for-loop body. 
After ConvertPipelinedForPattern the +// loop body contains: [priority] [barrier] execute_region ... [barrier] yield. +// bars[0] corresponds to the wrap-around barrier (before yield); bars[i] for +// i > 0 is the barrier immediately preceding cluster i. +// Returns false if the loop body doesn't match the expected pattern. +static bool collectLoopClusters(scf::ForOp forOp, + SmallVectorImpl &blocks, + SmallVectorImpl &bars) { + Operation *yieldOp = forOp.getBody()->getTerminator(); + if (!yieldOp) + return false; + for (auto &op : *forOp.getBody()) { + if (auto exec = getPipelineStage(&op)) { + blocks.push_back(&exec->getRegion(0).front()); + bars.push_back(false); + } + } + if (blocks.empty()) + return false; + + int K = blocks.size(); + // bars[0]: wrap-around barrier immediately before the yield. + // Pattern: [s_setprio] sched_barrier (ttg.barrier_local|s_barrier) + // sched_barrier yield + Operation *op = yieldOp->getPrevNode(); + if (op && isa(op)) { + op = op->getPrevNode(); + if (auto barrier = dyn_cast_or_null(op)) + bars[0] = barrier.hasLocal(); + } + + // bars[1..K-1]: barrier immediately preceding each cluster's execute_region. + for (int i = 1; i < K; i++) + bars[i] = hasLocalBarrierBefore(blocks[i]->getParentOp()); + return true; +} + +// Collect execute_region clusters and their preceding barrier flags from a +// flat (unrolled) pipeline starting at `firstExec`. After emitPipelinedFlat +// the sequence looks like: +// exec { b_0 } [s_setprio] sched_barrier (barrier) sched_barrier exec { b_1 } +// ... +// bars[0] is always false (no barrier before the first cluster); bars[i] for +// i > 0 is the barrier between b_{i-1} and b_i. 
+static bool collectFlatClusters(scf::ExecuteRegionOp firstExec, + SmallVectorImpl &blocks, + SmallVectorImpl &bars) { + if (!isPipelineStage(firstExec)) + return false; + blocks.push_back(&firstExec->getRegion(0).front()); + bars.push_back(false); + + for (Operation *op = firstExec->getNextNode(); op; op = op->getNextNode()) { + if (auto exec = getPipelineStage(op)) { + blocks.push_back(&exec->getRegion(0).front()); + bars.push_back(hasLocalBarrierBefore(op)); + continue; + } + // Walk past cluster barriers / priority / pre-existing barriers that + // emitPipelinedFlat may have wrapped with sched_barriers. Anything + // else (e.g. cond_barrier postlude, unrelated ops) terminates the + // flat sequence. + if (isIntraPipelineGlue(op)) + continue; + break; + } + return true; +} + +// Dispatch to collectLoopClusters / collectFlatClusters based on the kind of +// the next pipeline. The resulting bars follow the same convention as +// collectLoopClusters: bars[0] is either a wrap-around (loop) or false (flat); +// bars[i>0] is the barrier preceding cluster i. +static bool collectNextPipelineClusters(Operation *startOp, + SmallVectorImpl &blocks, + SmallVectorImpl &bars) { + if (auto forOp = dyn_cast(startOp)) + return collectLoopClusters(forOp, blocks, bars); + if (auto exec = dyn_cast(startOp)) + return collectFlatClusters(exec, blocks, bars); + return false; +} + +// Check whether merging two pipelines is safe to do without inserting a +// barrier at the boundary. Enumerates *only* cross-pipeline pairs +// (a_i, b_j) and verifies each intersected pair already has a LOCAL barrier +// on its path in the merged schedule. +// +// Why not reuse analyzePipelineDependencies on the merged sequence? +// The merged linear analysis would also visit intra-A and intra-B pairs, +// and may try to flip an internal slot (e.g. between a_0 and a_1) when +// A's IR has only a non-LOCAL pre-existing barrier (such as +// amdg.async_tdm_wait). 
Such slots are A's own responsibility — A +// already accepted that wait as sufficient for intra-warp ordering — and +// re-flipping them is a false positive that prevents elimination on +// otherwise safe kernels. Restricting the sweep to cross-pipeline pairs +// sidesteps the ambiguity entirely. +// +// Cross-warp concurrency vs intra-warp ordering: +// With a one-stage phase offset the only truly concurrent cross-warp +// pair at the boundary is (a_{K-1}, b_0). All other (a_i, b_j) pairs +// execute sequentially within a single warp. Both kinds, however, +// require a LOCAL barrier on AMD: the concurrent pair needs cross-warp +// sync, and the sequential pair needs ds_wait to order async ds_read / +// ds_write within a warp (pre-existing async_tdm_wait does *not* +// guarantee ds completion ordering in general). So the coverage check +// uses LOCAL-only mergedBars uniformly across all pairs. +// +// Layout of mergedBars (linear, LOCAL-only): +// i < K A's internal LOCAL barriers (loopBars[i]). +// i == K boundary seed; set to loopBars[0] because A's wrap-around +// physically sits at the bottom of A's loop body and, when +// LOCAL, is the most recent LDS sync the merged schedule +// inherits as it crosses into B. +// i > K B's internal LOCAL barriers (nextBars[i - K]). nextBars[0] +// is skipped (flat B has no slot before b_0; loop B's +// wrap-around lives inside B's body, irrelevant here). 
+static bool isCrossPipelineSafe(ArrayRef loopBlocks, + ArrayRef loopBars, + ArrayRef nextBlocks, + ArrayRef nextBars, + Allocation *allocation) { + int K = loopBlocks.size(); + int M = nextBlocks.size(); + assert(!loopBars.empty() && + "expected at least one cluster in the prior loop"); + + SmallVector mergedInfo; + for (auto *b : loopBlocks) + mergedInfo.push_back(buildBlockInfoFromBlock(b, allocation)); + for (auto *b : nextBlocks) + mergedInfo.push_back(buildBlockInfoFromBlock(b, allocation)); + + SmallVector mergedBars; + mergedBars.reserve(K + M); + for (bool b : loopBars) + mergedBars.push_back(b); + mergedBars.push_back(loopBars[0]); // boundary, seeded by A's wrap-around + for (int i = 1; i < M; i++) + mergedBars.push_back(nextBars[i]); + + // True if any slot in (src, stop] is LOCAL. Linear topology, no wrap. + auto isCovered = [&](int src, int stop) { + for (int i = src + 1; i <= stop; i++) + if (mergedBars[i]) + return true; + return false; + }; + + // Sweep cross-pipeline pairs only. Placement choice mirrors + // analyzePipelineDependencies (dist == 1 → dst, dist > 1 → dst - 1). + for (int i = 0; i < K; i++) { + for (int j = 0; j < M; j++) { + int src = i, dst = K + j; + int dist = dst - src; + int barrierLoc = (dist == 1) ? dst : dst - 1; + if (isCovered(src, barrierLoc)) + continue; + if (!mergedInfo[src].isIntersected( + mergedInfo[dst], mlir::triton::AMD::membarFilter, allocation)) + continue; + LDBG("cross-pipeline LDS dep (a_" + << i << ", b_" << j << ") uncovered at slot " << barrierLoc); + return false; + } + } + return true; +} + +// Eliminate redundant conditional barriers between consecutive warp-pipelined +// regions. When two pipelines are back-to-back with no intervening +// operations, the post-loop reconverge (cond_barrier warpLow) and the +// pre-pipeline phase shift (cond_barrier warpHigh) cancel out — the phase +// from the first pipeline naturally carries over. 
+// +// The prelude's ttg.barrier local (see emitPipelinePrelude) exists to flush +// pending LDS state so ModuleMembarAnalysis won't insert barriers inside +// pipeline stages. When the post-loop cond_barrier is immediately followed +// by this barrier and cross-pipeline dependency analysis confirms no LDS +// hazard at the boundary, the barrier is also redundant. +// +// When the two pipelines merge, the phase offset causes stages from different +// pipelines to execute concurrently (e.g., warp0 runs b0 while warp1 runs +// a_{K-1}). The cross-pipeline analysis checks all pairs (a_i, b_j) for LDS +// conflicts, accounting for barriers already placed by each pipeline's own +// dependency analysis. +// +// The "next pipeline" can be either another scf.for or a flat (unrolled) +// pipeline represented as a sequence of scf.execute_region ops. +// TODO: This could be generalized to flat-to-loop / flat-to-flat boundaries, +// but those cases cannot reuse a prior loop's wrap-around barrier as the +// boundary seed and are not expected to matter for common codegen. 
+// +// Before: After: +// scf.for { loop 1 } scf.for { loop 1 } +// [s_setprio 0] [s_setprio 0] +// cond_barrier(warpLow) ← erase (dead, cleaned later) +// ttg.barrier local ← erase [s_setprio P] +// scf.for / execute_region { pipeline 2 } +// cond_barrier(warpHigh) ← erase +// [s_setprio P] +// scf.for / execute_region { pipeline 2 } +// +static void eliminateRedundantCondBarriers(ModuleOp m, + ModuleAllocation &moduleAllocation) { + SmallVector toErase; + + m.walk([&](triton::FuncOp funcOp) { + Allocation *allocation = moduleAllocation.getFuncData(funcOp); + if (!allocation) + return; + + for (Block &block : funcOp.getBody()) { + SmallVector condBarriers; + for (auto &op : block) + if (auto cb = dyn_cast(&op)) + condBarriers.push_back(cb); + + for (size_t i = 0; i + 1 < condBarriers.size(); i++) { + auto postLoopCB = condBarriers[i]; + auto preLoopCB = condBarriers[i + 1]; + + // The post-loop cond_barrier must be preceded by a scf.for + // (possibly with an intervening s_setprio reset). + Operation *prev = postLoopCB->getPrevNode(); + if (prev && isa(prev)) + prev = prev->getPrevNode(); + if (!isa_and_nonnull(prev)) { + LDBG("post-loop cond_barrier not preceded by scf.for; skipping"); + continue; + } + auto prevFor = cast(prev); + + // The pre-loop cond_barrier must be followed by a warp-pipelined + // scf.for or a flat pipeline execute_region (possibly with an + // intervening s_setprio). + Operation *next = preLoopCB->getNextNode(); + if (next && isa(next)) + next = next->getNextNode(); + bool nextIsPipeline = + isa_and_nonnull(next) || getPipelineStage(next); + if (!nextIsPipeline) { + LDBG("pre-loop cond_barrier not followed by a warp-pipeline; " + "skipping"); + continue; + } + + // The post-loop cond_barrier must be immediately followed by the + // prelude's ttg.barrier local — this proves no operations were + // inserted between the two pipelines. 
+ auto preBarrier = + dyn_cast_or_null(postLoopCB->getNextNode()); + if (!preBarrier || !preBarrier.hasLocal()) { + LDBG("post-loop cond_barrier not immediately followed by prelude " + "ttg.barrier local; skipping"); + continue; + } + + // Cross-pipeline LDS dependency analysis. When the phase carries + // over, stages from different pipelines execute concurrently at the + // boundary. We must verify that no uncovered LDS conflict exists. + SmallVector loopBlocks, nextBlocks; + SmallVector loopBars, nextBars; + if (!collectLoopClusters(prevFor, loopBlocks, loopBars)) { + LDBG("could not collect prior loop's clusters; skipping"); + continue; + } + if (!collectNextPipelineClusters(next, nextBlocks, nextBars)) { + LDBG("could not collect next pipeline's clusters; skipping"); + continue; + } + if (!isCrossPipelineSafe(loopBlocks, loopBars, nextBlocks, nextBars, + allocation)) { + LDBG("cross-pipeline LDS dependency at boundary — keeping barriers"); + continue; + } + + LDBG("eliminating redundant barriers between back-to-back pipelines"); + toErase.push_back(postLoopCB); + toErase.push_back(preBarrier); + toErase.push_back(preLoopCB); + i++; + } + } + }); + + for (auto *op : llvm::reverse(toErase)) + op->erase(); +} + struct ConvertWarpPipeline : public mlir::triton::impl::ConvertWarpPipelineBase { @@ -376,6 +958,20 @@ struct ConvertWarpPipeline // stages at different times. int threadsPerPipelineGroup = targetInfo.getWarpSize() * 4; + // Up-front structural validation: catch malformed pipelined_for bodies + // before any rewrite mutates the IR. Errors are emitted at the + // offending op; we bail out hard rather than producing half-converted + // IR. 
+ bool malformed = false; + m.walk([&](scf::ForOp forOp) { + if (!forOp->getAttr("triton.warp_pipeline.pipelined_for")) + return; + if (failed(validatePipelinedForBody(forOp))) + malformed = true; + }); + if (malformed) + return signalPassFailure(); + RewritePatternSet patternFor(&getContext()); RewritePatternSet patternInline(&getContext()); patternFor.add(&getContext(), moduleAllocation, @@ -383,9 +979,20 @@ struct ConvertWarpPipeline patternInline.add(&getContext()); if (failed(applyPatternsGreedily(m, std::move(patternFor)))) - signalPassFailure(); + return signalPassFailure(); + + // Flat (unrolled) pipeline regions are still wrapped in execute_regions + // with no_inline=true from WarpPipeliner. Process them before inlining. + processUnrolledPipelineRegions(m, moduleAllocation, + threadsPerPipelineGroup); + + // Must run after patternFor and flat processing (all regions converted, + // barriers inserted) but before patternInline (inlining execute_regions + // would flatten the IR and obscure the cond_barrier adjacency we rely on). + eliminateRedundantCondBarriers(m, moduleAllocation); + if (failed(applyPatternsGreedily(m, std::move(patternInline)))) - signalPassFailure(); + return signalPassFailure(); } }; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/WarpPipeliner.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/WarpPipeliner.cpp index 40d566f40945..431b00a1bf61 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/WarpPipeliner.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/WarpPipeliner.cpp @@ -26,6 +26,60 @@ namespace mlir { #define GEN_PASS_DEF_TRITONAMDGPUWARPPIPELINE #include "TritonAMDGPUTransforms/Passes.h.inc" +// Ops that may appear between pipeline stages but never inside one. Pre- +// existing memory-fence/wait ops at cluster boundaries are tolerated so that +// prefetch patterns continue to work; encountering one mid-cluster is treated +// as malformed input by the callers. 
+static bool canSitBetweenStages(Operation *op) { + return isa(op); +} + +// True if `op` carries the cluster-end marker emitted by the frontend. +static bool isPipelineBorder(Operation *op) { + return op->hasAttr("triton.warp_pipeline.border"); +} + +// True if `op` is a structured loop (scf.for / scf.while). Pipeline clusters +// are straight-line scheduling units, so loops remain boundaries instead of +// being absorbed. This also rejects nested warp-pipelined scf.for ops. +static bool isLoopOp(Operation *op) { + return isa(op); +} + +// Outcome of attempting to build a pipeline from a region. +// NotApplicable: no border markers were present (the region opted out). +// Created: a pipeline was successfully materialized. +// Malformed: border markers were present but the pipeline could not be +// built; an error has been emitted at the offending op. +enum class PipelineResult { NotApplicable, Created, Malformed }; + +// Read (cluster-name, priority) from a border marker op. Priority defaults +// to -1 when the marker doesn't carry the optional priority attribute. +static std::pair readBorderMarker(Operation *op) { + StringAttr clusterStr = + op->getAttrOfType("triton.warp_pipeline.border"); + int priority = -1; + if (auto intAttr = + op->getAttrOfType("triton.warp_pipeline.priority")) + priority = intAttr.getInt(); + return {clusterStr, priority}; +} + +// If `cluster` is empty, materialize a dummy SchedBarrier so the cluster is +// non-empty. This lets users deliberately request a pipeline bubble by +// emitting two consecutive border markers with no body between them. 
+static void addDummyOpIfEmptyCluster(OpBuilder &b, Location loc, + Operation *insertBefore, + SmallVectorImpl &cluster) { + if (!cluster.empty()) + return; + b.setInsertionPoint(insertBefore); + auto dummyOp = ROCDL::SchedBarrier::create(b, loc, 0); + dummyOp->setAttr("triton.warp_pipeline.empty_cluster", b.getUnitAttr()); + cluster.push_back(dummyOp); +} + // Create a scf.execute_region op representing a pipeline cluster. static void createClusterOp(OpBuilder &b, Location loc, SmallVector &ops, @@ -109,26 +163,10 @@ static void createClusterOp(OpBuilder &b, Location loc, return; } -// Ops that may appear before or after a stage but not inside one. -// Barrier/wait still require an explicit border op to split clusters. -static bool canSitBetweenStages(Operation *op) { - return isa(op); -} - -// Sink pure-scalar ops past adjacent ignorable ops so they join the next -// cluster. After loop unrolling, scalar IV-remap ops (arith.addi/muli) land -// between borders and the ignorable that starts the next iteration (FA -// pattern); without this, WarpPipeliner sees scalars as an incomplete -// cluster when it hits the ignorable and bails out. Single forward pass, -// O(N). -// -// `pending` only accumulates pure scalars and is cleared at any other op, -// so it forms a closed SSA DAG: any use of a pending scalar by the trailing -// ignorable run must be a direct operand, so the dependency check is a -// simple operand scan. -static void sinkPureScalarsPastIgnorables(Block &blk) { - // Accumulates a run of consecutive pure scalars that might be sunk. +// Move pure scalar IV-remap ops after adjacent inter-stage barriers/waits so +// they become part of the next stage. If a barrier/wait uses one of those +// scalars, leave the run in place to preserve SSA. 
+static void sinkPureScalarsIntoNextStage(Block &blk) { SmallVector pending; auto consumesPending = [&](Operation *user) { return llvm::any_of(user->getOperands(), [&](Value v) { @@ -142,15 +180,11 @@ static void sinkPureScalarsPastIgnorables(Block &blk) { op = next; continue; } - // Non-scalar op: try to sink pending past an ignorable run, then reset. if (canSitBetweenStages(op) && !pending.empty()) { - // Extend `anchor` to the last ignorable in the consecutive run. Operation *anchor = op; while (anchor->getNextNode() && canSitBetweenStages(anchor->getNextNode())) anchor = anchor->getNextNode(); - // Abort the sink if any ignorable in [op..anchor] consumes a pending - // scalar -- moving its producer past it would break SSA. bool conflict = false; for (Operation *ign = op; !conflict; ign = ign->getNextNode()) { conflict = consumesPending(ign); @@ -158,7 +192,6 @@ static void sinkPureScalarsPastIgnorables(Block &blk) { break; } if (!conflict) { - // Skip past the moved scalars on the next iteration. next = anchor->getNextNode(); // Reverse iteration + moveAfter(anchor) preserves source order: // each earlier-inserted scalar is pushed right by later inserts. @@ -166,75 +199,80 @@ static void sinkPureScalarsPastIgnorables(Block &blk) { s->moveAfter(anchor); } } - // Pending is always cleared at a non-scalar op: the run is broken, - // either by a successful sink or by an op that anchors them in place. pending.clear(); op = next; } } -// Turns a partitioned region into the warp-pipelined clusters -static LogicalResult createPipeline(OpBuilder &b, Location loc, - scf::ForOp forOp) { +// Turns a partitioned region into the warp-pipelined clusters. Returns +// NotApplicable when the loop has no border markers (user opted out), Created +// on success, or Malformed when border markers are present but the loop body +// cannot be split into a valid pipeline (an error is emitted in that case). 
+static PipelineResult createPipeline(OpBuilder &b, Location loc, + scf::ForOp forOp) { Block &blk = *forOp.getBody(); + + // Opt-in gate: if the loop body has no borders, the user did not request + // warp-pipelining for this loop and we must leave it untouched. + if (llvm::none_of(blk, [](Operation &op) { return isPipelineBorder(&op); })) + return PipelineResult::NotApplicable; + SmallVector cluster; SmallVector> clusterMarkers; SmallVector> clusters; auto ctx = forOp.getContext(); - auto isBorder = [](Operation *op) { - return op->hasAttr("triton.warp_pipeline.border"); - }; - - sinkPureScalarsPastIgnorables(blk); + sinkPureScalarsIntoNextStage(blk); // One pass over the body; collect clusters split by explicit borders. for (Operation &opRef : llvm::make_early_inc_range(blk)) { Operation *op = &opRef; - if (isBorder(op)) { // Wrap-up one cluster at a border. - StringAttr clusterStr = - op->getAttrOfType("triton.warp_pipeline.border"); - int priority = -1; - if (auto intAttr = - op->getAttrOfType("triton.warp_pipeline.priority")) { - priority = intAttr.getInt(); - } - clusterMarkers.push_back({clusterStr, priority}); - if (cluster.empty()) { - // This allows user to deliberately insert a pipeline bubble with a - // cluster only contains a dummy operation. - b.setInsertionPoint(op); - auto dummyOp = ROCDL::SchedBarrier::create(b, loc, 0); - dummyOp->setAttr("triton.warp_pipeline.empty_cluster", b.getUnitAttr()); - cluster.push_back(dummyOp); - } + if (isPipelineBorder(op)) { // Wrap up one cluster at a border. + clusterMarkers.push_back(readBorderMarker(op)); + addDummyOpIfEmptyCluster(b, loc, op, cluster); clusters.push_back(std::move(cluster)); cluster.clear(); - op->erase(); // remove the marker + op->erase(); // Remove the marker. continue; } if (canSitBetweenStages(op)) { - // Ignorable ops may appear before or after a stage, but not inside it. - // If encountered while building an execute_region, reject warp-pipeline. 
- if (!cluster.empty()) - return failure(); + // Barrier / async_wait family ops belong between stages, + // never inside one. Encountering one while a cluster is being built + // means the user inserted it inside a warp_pipeline_stage region. + if (!cluster.empty()) { + op->emitError("barrier or wait op cannot appear inside a " + "warp_pipeline_stage region"); + return PipelineResult::Malformed; + } continue; } - if (isa(op)) // End of the loop + if (isLoopOp(op)) { + // Loops are not permitted inside a stage; see isLoopOp for rationale. + op->emitError("loop op cannot appear inside a warp_pipeline_stage " + "region; to pipeline loop iterations, place " + "warp_pipeline_stage blocks inside the loop body"); + return PipelineResult::Malformed; + } + if (isa(op)) // End of the loop. break; - // Keep collecting ops for a cluster. + // Keep collecting ops for the current cluster. cluster.push_back(op); } - if (!cluster.empty()) { // create the last cluster if needed. + if (!cluster.empty()) { // Create the last cluster if needed. clusters.push_back(std::move(cluster)); auto clusterStr = StringAttr::get(ctx, "last_cluster"); clusterMarkers.push_back({clusterStr, -1}); } - // no pipeline clusters detected if 1 or 0 chunk found - if (clusters.size() < 2) - return failure(); + // We only reach here when at least one border existed; a single cluster + // means the borders are degenerate (e.g. a lone trailing border with no + // operations after it). Treat as malformed user input. + if (clusters.size() < 2) { + forOp->emitError( + "warp_pipeline_stage borders did not produce at least two stages"); + return PipelineResult::Malformed; + } // Materialize each cluster as an execute_region. 
int totalStages = clusters.size(); @@ -249,7 +287,114 @@ static LogicalResult createPipeline(OpBuilder &b, Location loc, forOp->setAttr("triton.warp_pipeline.pipelined_for", b.getUnitAttr()); LDBG("[warp-pipeline] total_stages=" << totalStages << "\n"); - return success(); + return PipelineResult::Created; +} + +// Create a pipelined region from flat (non-loop) border markers in a block. +// This handles the case where a loop was unrolled at the Python level +// (e.g. via static_range) but the body still has warp_pipeline_stage +// annotations producing border markers. The grouping logic mirrors +// createPipeline but without a loop wrapper. +// +// Returns NotApplicable when the block has no border markers, Created when +// a flat pipeline was materialized, or Malformed when borders are present +// but a valid pipeline could not be built (an error is emitted in that case). +static PipelineResult createFlatPipeline(OpBuilder &b, Block &block) { + // 1. Find all border markers in this block. + SmallVector allBorders; + for (auto &op : block) + if (isPipelineBorder(&op)) + allBorders.push_back(&op); + + // No borders at all means the block did not opt into flat pipelining. + if (allBorders.empty()) + return PipelineResult::NotApplicable; + + // A single border cannot form a 2-stage pipeline; treat as malformed input + // since the user did opt in (the lone border would otherwise leak through + // unprocessed). + if (allBorders.size() < 2) { + allBorders.front()->emitError( + "warp_pipeline_stage requires at least two borders to form a flat " + "pipeline"); + return PipelineResult::Malformed; + } + + Location loc = allBorders.front()->getLoc(); + Operation *firstBorder = allBorders.front(); + Operation *lastBorder = allBorders.back(); + + // 2. For flat pipelines, stage 0 may include the ops immediately before the + // first border. Stop at ops that must stay outside this pipeline. 
+ Operation *regionStart = firstBorder; + for (Operation *op = firstBorder->getPrevNode(); op; op = op->getPrevNode()) { + if (isLoopOp(op) || isa(op) || + canSitBetweenStages(op)) + break; + regionStart = op; + } + + // 3. Sweep forward from regionStart, splitting ops into clusters at each + // border. Mirrors createPipeline's main loop, but bounded by lastBorder + // instead of scf.yield. + SmallVector cluster; + SmallVector> clusterMarkers; + SmallVector> clusters; + + for (auto it = Block::iterator(regionStart); it != block.end();) { + Operation *op = &*it; + ++it; + + if (isPipelineBorder(op)) { + clusterMarkers.push_back(readBorderMarker(op)); + addDummyOpIfEmptyCluster(b, loc, op, cluster); + clusters.push_back(std::move(cluster)); + cluster.clear(); + + bool isLast = (op == lastBorder); + op->erase(); + if (isLast) + break; + continue; + } + + if (canSitBetweenStages(op)) { + // Same rule as createPipeline: barriers/waits cannot live inside a + // stage. + if (!cluster.empty()) { + op->emitError("barrier or wait op cannot appear inside a " + "warp_pipeline_stage region"); + return PipelineResult::Malformed; + } + continue; + } + + if (isLoopOp(op)) { + // Same rule as createPipeline: loops cannot live inside a stage. + op->emitError("loop op cannot appear inside a warp_pipeline_stage " + "region; to pipeline loop iterations, place " + "warp_pipeline_stage blocks inside the loop body"); + return PipelineResult::Malformed; + } + + cluster.push_back(op); + } + + // 4. The bounded sweep should produce at least two clusters. 
+ if (clusters.size() < 2) { + mlir::emitError( + loc, "warp_pipeline_stage borders did not produce at least two stages"); + return PipelineResult::Malformed; + } + + for (auto &&[stageOps, marker] : llvm::zip(clusters, clusterMarkers)) { + if (stageOps.empty()) + continue; + createClusterOp(b, loc, stageOps, marker); + } + + LDBG("[warp-pipeline] flat pipeline with " << clusters.size() << " stages"); + return PipelineResult::Created; } struct TritonAMDGPUWarpPipelinePass @@ -259,13 +404,39 @@ struct TritonAMDGPUWarpPipelinePass void runOnOperation() override { ModuleOp m = getOperation(); OpBuilder builder(m); + bool malformed = false; for (auto funcOp : m.getOps()) { funcOp.walk([&](scf::ForOp forOp) { Location loc = forOp.getLoc(); - if (createPipeline(builder, loc, forOp).failed()) - LDBG("Failed warp-pipelining"); + switch (createPipeline(builder, loc, forOp)) { + case PipelineResult::NotApplicable: + LDBG("scf.for has no warp_pipeline_stage borders; skipping"); + break; + case PipelineResult::Created: + break; + case PipelineResult::Malformed: + malformed = true; + break; + } }); + + // Process remaining border markers in flat (non-loop) code. Only the + // function's top-level blocks are visited; borders inside nested + // non-loop regions (e.g. scf.if bodies) are not handled here. 
+ for (Block &block : funcOp.getBody()) { + switch (createFlatPipeline(builder, block)) { + case PipelineResult::NotApplicable: + break; + case PipelineResult::Created: + break; + case PipelineResult::Malformed: + malformed = true; + break; + } + } } + if (malformed) + signalPassFailure(); } }; diff --git a/third_party/amd/python/examples/gluon/f16_gemm_warp_pipeline_gfx1250.py b/third_party/amd/python/examples/gluon/f16_gemm_warp_pipeline_gfx1250.py index a4e1081e6e91..70e463c2c567 100644 --- a/third_party/amd/python/examples/gluon/f16_gemm_warp_pipeline_gfx1250.py +++ b/third_party/amd/python/examples/gluon/f16_gemm_warp_pipeline_gfx1250.py @@ -12,7 +12,6 @@ create_shared_layouts, create_tensor_descriptors, issue_loads, - issue_wmma, lds_load, issue_wmma_compute, ) @@ -22,7 +21,6 @@ create_shared_layouts, create_tensor_descriptors, issue_loads, - issue_wmma, lds_load, issue_wmma_compute, ) @@ -68,13 +66,13 @@ def gemm_tdm_pipelined_warp_pipelined_kernel(a_ptr, b_ptr, c_ptr, # consumer = 0 accumulator = ttgl.zeros((BLOCK_M, BLOCK_N), dtype=c_ptr.type.element_ty, layout=WMMA_LAYOUT) - # Triple buffering - # prefetch 2, the other one is overlapped. - for _ in ttgl.static_range(2): + # Prefetch NUM_BUFFERS - 1 tiles; the main loop produces one tile for + # each tile it consumes, and the epilogue drains the prefetched tail. 
+ for _ in ttgl.static_range(NUM_BUFFERS - 1): producer = issue_loads(producer, a_desc, b_desc, 0, 0, a_buffer, b_buffer, BLOCK_K, NUM_BUFFERS, TRANSPOSE_B) # Wait for the first prefetch - ttgl.amd.gfx1250.tdm.async_wait(1 * 2) + ttgl.amd.gfx1250.tdm.async_wait((NUM_BUFFERS - 2) * 2) for _ in range(0, ttgl.cdiv(K, BLOCK_K) - (NUM_BUFFERS - 1)): with ttgl.amd.warp_pipeline_stage("stage0", priority=1): consumer, a, b = lds_load(consumer, a_buffer, OPERAND_LAYOUT_A, b_buffer, OPERAND_LAYOUT_B, NUM_BUFFERS, @@ -87,10 +85,12 @@ def gemm_tdm_pipelined_warp_pipelined_kernel(a_ptr, b_ptr, c_ptr, # accumulator = issue_wmma_compute(a, b, accumulator) for i in ttgl.static_range(NUM_BUFFERS - 1): - # Warp-pipeline ended, wait for the ones to be consumed here. + with ttgl.amd.warp_pipeline_stage("stage0_epilogue", priority=1): + consumer, a, b = lds_load(consumer, a_buffer, OPERAND_LAYOUT_A, b_buffer, OPERAND_LAYOUT_B, NUM_BUFFERS, + TRANSPOSE_B) ttgl.amd.gfx1250.tdm.async_wait((NUM_BUFFERS - 1 - i) * 2) - consumer, accumulator = issue_wmma(consumer, a_buffer, OPERAND_LAYOUT_A, b_buffer, OPERAND_LAYOUT_B, - accumulator, (NUM_BUFFERS - 2 - i) * 2, NUM_BUFFERS, TRANSPOSE_B) + with ttgl.amd.warp_pipeline_stage("stage1_epilogue", priority=0): + accumulator = issue_wmma_compute(a, b, accumulator) offs_cm = pid_m * BLOCK_M + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, WMMA_LAYOUT)) offs_cn = pid_n * BLOCK_N + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, WMMA_LAYOUT))