diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index b59e66f7b1fd..20be929ae7b4 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -59,6 +59,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::registerTritonAMDGPUOptimizeEpilogue(); mlir::registerTritonAMDGPUReorderInstructions(); mlir::registerTritonAMDGPUStreamPipeline(); + mlir::registerTritonAMDGPUStreamPipelineV2(); // TODO: register Triton & TritonGPU passes registry.insert> &replacements); +// Append the given |newOperands| to the |forOp|'s yield op. +void appendToForOpYield(scf::ForOp forOp, ArrayRef newOperands); + Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, IRMapping &mapping); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index e18d9312daa8..eb15f03bda91 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -51,18 +51,6 @@ struct LoadInfo { } // namespace -// Replace the ForOp's yield with a new one with the given operands appended. -static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { - // Fix up the yield op. - Operation *yieldOp = forOp.getBody()->getTerminator(); - SmallVector operands(yieldOp->getOperands()); - operands.append(newOperands.begin(), newOperands.end()); - - OpBuilder builder(yieldOp); - builder.create(yieldOp->getLoc(), operands); - yieldOp->erase(); -} - static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, Value insertIdx, Value extractIdx, tt::CoarseSchedule &schedule, @@ -1041,7 +1029,7 @@ createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, if (phase) newYieldOperands.push_back(phase); // Patch the yield with the updated counters. 
- appendToYield(forOp, newYieldOperands); + appendToForOpYield(forOp, newYieldOperands); return allocs; } diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index be4e486a248b..eaf0a7e2a148 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -627,6 +627,16 @@ scf::IfOp replaceIfOpWithNewSignature( return newIf; } +void appendToForOpYield(scf::ForOp forOp, ArrayRef newOperands) { + Operation *yieldOp = forOp.getBody()->getTerminator(); + SmallVector operands(yieldOp->getOperands()); + operands.append(newOperands.begin(), newOperands.end()); + + OpBuilder builder(yieldOp); + builder.create(yieldOp->getLoc(), operands); + yieldOp->erase(); +} + Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, IRMapping &mapping) { Operation *newOp = rewriter.clone(*op, mapping); diff --git a/test/TritonGPU/amd/amd-stream-pipeline.mlir b/test/TritonGPU/amd/amd-stream-pipeline.mlir deleted file mode 100644 index 4b2de3336413..000000000000 --- a/test/TritonGPU/amd/amd-stream-pipeline.mlir +++ /dev/null @@ -1,44 +0,0 @@ -// RUN: triton-opt %s -split-input-file --tritonamdgpu-stream-pipeline | FileCheck %s - -// CHECK-LABEL: @check_stream_pipeline_epilogue -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [1, 0]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @check_stream_pipeline_epilogue(%Aptr: tensor<32x32x!tt.ptr, #blocked>, %Bptr : tensor<32x32x!tt.ptr, #blocked>, %arg4 : i32, %arg5 : i1) { - %cst_0 = arith.constant dense<16> : tensor<32x32xi32, #blocked> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> - %cst_5 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - // CHECK: scf.for {{.*}} = %[[LB:.*]] to %[[UB:.*]] step %[[STEP:.*]] iter_args({{.*}}) - %36:3 = scf.for %arg9 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg10 = %cst_5, %arg12 = %Aptr, %arg13 = %Bptr) -> (tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr, #blocked>, tensor<32x32x!tt.ptr, #blocked>) : i32 { - %61 = arith.muli %arg9, %arg4 : i32 - %62 = arith.cmpi slt, %arg4, %61 : i32 - %63 = tt.splat %62 : i1 -> tensor<32x32xi1, #blocked> - // This load will not be pipelined - %66 = tt.load %arg12, %63 : tensor<32x32x!tt.ptr, #blocked> - // This load will be pipelined - %70 = tt.load %arg13 : tensor<32x32x!tt.ptr, #blocked> - %71 = triton_gpu.convert_layout %66 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %72 = triton_gpu.convert_layout %70 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %73 = tt.dot %71, %72, %arg10 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> - // This scf.if will make load at %66 non-pipelineable - %74 = scf.if %arg5 -> (tensor<32x32xf32, #blocked>){ - scf.yield %66 : tensor<32x32xf32, #blocked> - } else { - scf.yield %cst_2: tensor<32x32xf32, #blocked> - } - %75 = tt.addptr %arg12, %cst_0 : 
tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> - %76 = tt.addptr %arg13, %cst_0 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> - scf.yield %73, %75, %76 : tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr, #blocked>, tensor<32x32x!tt.ptr, #blocked> - } - // CHECK: %[[C1:.*]] = arith.constant 1 : i32 - // CHECK: %[[t0:.*]] = arith.subi %[[UB:.*]], %[[C1]] - // CHECK: %[[t1:.*]] = arith.subi %[[t0]], %[[LB]] - // CHECK: %[[t2:.*]] = arith.divui %[[t1]], %[[STEP]] - // CHECK: %[[t3:.*]] = arith.muli %[[t2]], %[[STEP]] - // CHECK: %[[PPLUB:.*]] = arith.addi %[[LB]], %[[t3]] - // CHECK: arith.muli %[[PPLUB]], {{.*}} - tt.return - } -} diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir new file mode 100644 index 000000000000..657da5f31346 --- /dev/null +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -0,0 +1,161 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2=num_stages=2 -canonicalize | FileCheck %s + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @load_two_users + tt.func @load_two_users(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { + %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_0 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %c0_i32 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, 
#triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: triton_gpu.local_store + // CHECK: scf.for + // CHECK: tt.dot + // CHECK: tt.dot + // CHECK: tt.load + // CHECK: triton_gpu.local_store + // CHECK: scf.yield + %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { + %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> + %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %20 = triton_gpu.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> + %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> + %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory, mutable> + %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } + tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } +} + +// ----- + +// CHECK-LABEL: tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de +// CHECK-NOT: triton_gpu.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1> + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: 
!tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c64_i32 : i32 + %2 = tt.get_program_id y : i32 + %3 = tt.load %arg3 : !tt.ptr + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %5 = tt.splat %1 : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %6 = arith.addi %5, %4 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %8 = tt.splat %3 : i64 -> tensor<64x1xi64, #blocked> + %9 = arith.extsi %7 : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> + %10 = arith.addi %8, %9 : tensor<64x1xi64, #blocked> + %11 = arith.extsi %arg5 : i32 to i64 + %12 = tt.splat %11 : i64 -> tensor<64x1xi64, #blocked> + %13 = arith.muli %10, %12 : tensor<64x1xi64, #blocked> + %14 = arith.muli %2, %arg5 : i32 + %15 = arith.extsi %14 : i32 to i64 + %16 = tt.splat %15 : i64 -> tensor<64x1xi64, #blocked> + %17 = arith.addi %13, %16 : tensor<64x1xi64, #blocked> + %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %21 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %22 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked> + %23 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> + %24 = arith.muli %20, %22 : tensor<1x64xi32, #blocked> + %25 = arith.muli %21, %23 : tensor<1x64xi32, #blocked1> + %26 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x64xi64, #blocked> + %27 = arith.extsi %24 : tensor<1x64xi32, #blocked> to tensor<1x64xi64, #blocked> + %28 = arith.extsi %25 : tensor<1x64xi32, #blocked1> to tensor<1x64xi64, #blocked1> + %29 = tt.broadcast %27 : tensor<1x64xi64, #blocked> -> tensor<64x64xi64, #blocked> + %30 = arith.addi %26, %29 : tensor<64x64xi64, #blocked> + %31 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %32 = tt.expand_dims %31 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %33 = tt.splat %3 : i64 -> tensor<32x1xi64, #blocked1> + %34 = arith.extsi %32 : tensor<32x1xi32, #blocked1> to tensor<32x1xi64, #blocked1> + %35 = arith.addi %33, %34 : tensor<32x1xi64, #blocked1> + %36 = tt.splat %11 : i64 -> tensor<32x1xi64, #blocked1> + %37 = arith.muli %35, %36 : tensor<32x1xi64, #blocked1> + %38 = tt.splat %15 : i64 -> tensor<32x1xi64, #blocked1> + %39 = arith.addi %37, %38 : tensor<32x1xi64, #blocked1> + %40 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x64xi64, #blocked1> + %41 = tt.broadcast %28 : tensor<1x64xi64, #blocked1> -> tensor<32x64xi64, #blocked1> + %42 = arith.addi %40, %41 : tensor<32x64xi64, #blocked1> + %43 = tt.make_range {end = 32 : i32, 
start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %45 = tt.expand_dims %43 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %47 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked1> + %48 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked> + %49 = arith.muli %45, %47 : tensor<1x32xi32, #blocked1> + %50 = arith.muli %46, %48 : tensor<1x32xi32, #blocked> + %51 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x32xi64, #blocked1> + %52 = arith.extsi %49 : tensor<1x32xi32, #blocked1> to tensor<1x32xi64, #blocked1> + %53 = arith.extsi %50 : tensor<1x32xi32, #blocked> to tensor<1x32xi64, #blocked> + %54 = tt.broadcast %52 : tensor<1x32xi64, #blocked1> -> tensor<32x32xi64, #blocked1> + %55 = arith.addi %51, %54 : tensor<32x32xi64, #blocked1> + %56 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked> + %57 = tt.addptr %56, %30 : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi64, #blocked> + %58 = tt.splat %arg1 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked1> + %59 = tt.addptr %58, %42 : tensor<32x64x!tt.ptr, #blocked1>, tensor<32x64xi64, #blocked1> + %60 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked1> + %61 = tt.addptr %60, %55 : tensor<32x32x!tt.ptr, #blocked1>, tensor<32x32xi64, #blocked1> + %62 = tt.load %57 : tensor<64x64x!tt.ptr, #blocked> + %63 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst) -> (tensor<64x32xf32, #mma>) : i32 { + %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> + %71 = triton_gpu.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable> + %73 = tt.trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory, mutable> + %74 = triton_gpu.local_load %73 : !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> + %77 = triton_gpu.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %78 = triton_gpu.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + scf.yield %79 : tensor<64x32xf32, #mma> + } + %64 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x32xi64, #blocked> + %65 = tt.broadcast %53 : tensor<1x32xi64, #blocked> -> tensor<64x32xi64, #blocked> + %66 = arith.addi %64, %65 : tensor<64x32xi64, #blocked> + %67 = 
tt.splat %arg4 : !tt.ptr -> tensor<64x32x!tt.ptr, #blocked> + %68 = tt.addptr %67, %66 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi64, #blocked> + %69 = triton_gpu.convert_layout %63 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked> + tt.store %68, %69 : tensor<64x32x!tt.ptr, #blocked> + tt.return + } +} // end module diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 15c3dadd0dff..f7a1e8127cbf 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -1,5 +1,6 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s --check-prefixes=COMMON,CHECK // RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 | FileCheck %s --check-prefix=CHECK-NOCANON +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2=num_stages=2 -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 @@ -55,6 +56,44 @@ // CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], // CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] + +// AMD-LABEL: tt.func @matmul_loop +// AMD: %[[LOCAL_ALLOC_10:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_11:.*]] = triton_gpu.local_alloc +// AMD: %[[CMPI_12:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_13:.*]] = tt.splat %[[CMPI_12]] +// AMD: %[[LOAD_14:.*]] = tt.load %{{.*}}, %[[SPLAT_13]] +// AMD: %[[SPLAT_15:.*]] = tt.splat %[[CMPI_12]] +// AMD: %[[LOAD_16:.*]] = tt.load %{{.*}}, %[[SPLAT_15]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_17:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_10]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_14]], %[[MEMDESC_SUBVIEW_17]] +// AMD: %[[MEMDESC_SUBVIEW_18:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_11]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_16]], %[[MEMDESC_SUBVIEW_18]] +// AMD: %{{.*}}:6 = scf.for %[[ARG5:[a-z0-9]*]] = +// AMD-SAME: iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_17]], %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_18]]) +// AMD: %[[SUBI_20:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %[[CMPI_21:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_20]] +// AMD: %[[LOCAL_LOAD_25:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %[[ARG12]] +// AMD: %[[MULF_29:.*]] = arith.mulf %[[LOCAL_LOAD_27]], %{{.*}} +// AMD: %[[DOT_30:.*]] = tt.dot %[[LOCAL_LOAD_25]], %[[MULF_29]], %[[ARG8]] +// AMD: %[[ADDPTR_31:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// AMD: %[[ADDPTR_32:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// AMD: %[[SPLAT_33:.*]] = tt.splat %[[CMPI_21]] +// AMD: %[[LOAD_34:.*]] = tt.load %[[ADDPTR_31]], %[[SPLAT_33]] +// AMD: %[[SPLAT_35:.*]] = tt.splat %[[CMPI_21]] +// AMD: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_32]], %[[SPLAT_35]], %{{.*}} +// AMD: %[[ADDI_37:.*]] = arith.addi %[[ARG10]], %{{.*}} +// AMD: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}} +// AMD: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_10]][%[[SELECT_39]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_34]], %[[MEMDESC_SUBVIEW_40]] +// AMD: 
%[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_11]][%[[SELECT_39]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_36]], %[[MEMDESC_SUBVIEW_41]] +// AMD: scf.yield %[[ADDPTR_31]], %[[ADDPTR_32]], %[[DOT_30]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_10]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_11]] + module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, @@ -146,6 +185,28 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] // CHECK: triton_gpu.async_copy_global_to_local // CHECK scf.yield + +// AMD-LABEL: tt.func @matmul_loop_nested +// AMD: scf.for +// AMD-COUNT-2: triton_gpu.local_alloc +// AMD-COUNT-2: tt.load +// AMD: %[[SUBVIEW0:.*]] = triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW0]] +// AMD: %[[SUBVIEW1:.*]] = triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW1]] +// AMD: %[[FOR:.*]]:6 = scf.for +// AMD-COUNT-2: triton_gpu.local_load +// AMD: tt.dot +// AMD-COUNT-2: tt.addptr +// AMD-COUNT-2: tt.load +// AMD: %[[SUBVIEW0:.*]] = triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW0]] +// AMD: %[[SUBVIEW1:.*]] = triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW1]] +// AMD: scf.yield +// AMD-COUNT-2: triton_gpu.local_dealloc +// AMD: scf.yield %[[FOR]]#2 + tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C>{ @@ -216,6 +277,32 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]] // CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_B]] + +// AMD-LABEL: tt.func @matmul_loop_single_pipeline +// AMD: %[[LOAD_10:.*]] = tt.load +// AMD: %[[CONVERT_LAYOUT_11:.*]] = triton_gpu.convert_layout %[[LOAD_10]] +// AMD: %[[LOCAL_ALLOC_12:.*]] = triton_gpu.local_alloc +// AMD: %[[CMPI_13:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_13]] +// AMD: %[[LOAD_15:.*]] = tt.load %{{.*}}, %[[SPLAT_14]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_16:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_12]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_15]], %[[MEMDESC_SUBVIEW_16]] +// AMD: %{{.*}}:4 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %[[MEMDESC_SUBVIEW_16]]) +// AMD: %[[SUBI_18:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %[[CMPI_19:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_18]] +// AMD: %[[LOCAL_LOAD_23:.*]] = triton_gpu.local_load %[[ARG10]] +// AMD: %[[DOT_25:.*]] = tt.dot %[[CONVERT_LAYOUT_11]], %[[LOCAL_LOAD_23]], %[[ARG7]] +// AMD: %[[ADDPTR_26:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// AMD: %[[SPLAT_27:.*]] = tt.splat %[[CMPI_19]] +// AMD: %[[LOAD_28:.*]] = tt.load %[[ADDPTR_26]], %[[SPLAT_27]], %{{.*}} +// AMD: %[[ADDI_29:.*]] = arith.addi %[[ARG9]], %{{.*}} +// AMD: %[[CMPI_30:.*]] = arith.cmpi slt, %[[ADDI_29]], %{{.*}} +// AMD: %[[SELECT_31:.*]] = arith.select 
%[[CMPI_30]], %[[ADDI_29]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_32:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_12]][%[[SELECT_31]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_28]], %[[MEMDESC_SUBVIEW_32]] +// AMD: scf.yield %[[ADDPTR_26]], %[[DOT_25]], %[[SELECT_31]], %[[MEMDESC_SUBVIEW_32]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_12]] + tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -268,6 +355,62 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_2]] // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} + +// AMD-LABEL: tt.func @indirect_bmm_scalar +// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc +// AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] +// AMD: %[[LOAD_5:.*]] = tt.load %{{.*}}, %[[CMPI_2]] +// AMD: %[[MULI_6:.*]] = arith.muli %{{.*}}, %[[LOAD_5]] +// AMD: %[[SPLAT_7:.*]] = tt.splat %[[MULI_6]] +// AMD: %[[ADDPTR_8:.*]] = tt.addptr %{{.*}}, %[[SPLAT_7]] +// AMD: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_8]], %[[SPLAT_9]] +// AMD: %[[CMPI_11:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_13:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_11]] +// AMD: %[[LOAD_15:.*]] = tt.load %[[ADDPTR_12]], %[[SPLAT_14]] +// AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_13]], %[[CMPI_11]] +// AMD: %[[MULI_17:.*]] = arith.muli %{{.*}}, %[[LOAD_16]] +// AMD: %[[SPLAT_18:.*]] = tt.splat %[[MULI_17]] +// AMD: %[[ADDPTR_19:.*]] = tt.addptr %{{.*}}, %[[SPLAT_18]] +// AMD: %[[SPLAT_20:.*]] = tt.splat %[[CMPI_11]] +// AMD: %[[LOAD_21:.*]] = tt.load %[[ADDPTR_19]], %[[SPLAT_20]] +// AMD: %[[MEMDESC_SUBVIEW_22:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_4]], %[[MEMDESC_SUBVIEW_22]] +// AMD: %[[MEMDESC_SUBVIEW_23:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_10]], %[[MEMDESC_SUBVIEW_23]] +// AMD: %{{.*}}:8 = scf.for %[[ARG6:[a-z0-9]*]] = +// AMD-SAME: iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %[[ADDPTR_12]], %[[ARG9:.*]] = %[[ADDPTR_13]], %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_22]], %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_23]], %[[ARG14:.*]] = %[[LOAD_15]], %[[ARG15:.*]] = %[[LOAD_21]]) +// AMD: %[[SUBI_25:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %[[CMPI_26:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_25]] +// AMD: %[[LOCAL_LOAD_30:.*]] = triton_gpu.local_load %[[ARG12]] +// AMD: %[[LOCAL_LOAD_31:.*]] = triton_gpu.local_load %[[ARG13]] +// AMD: %[[DOT_34:.*]] = tt.dot %[[LOCAL_LOAD_30]], %[[LOCAL_LOAD_31]], %[[ARG7]] +// AMD: %[[ADDPTR_35:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_36:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[SPLAT_37:.*]] = tt.splat %[[CMPI_26]] +// AMD: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_35]], %[[SPLAT_37]] +// AMD: %[[LOAD_39:.*]] = tt.load %[[ADDPTR_36]], %[[CMPI_26]] +// AMD: %[[MULI_40:.*]] = arith.muli %{{.*}}, %[[LOAD_39]] +// AMD: %[[SPLAT_41:.*]] = 
tt.splat %[[MULI_40]] +// AMD: %[[ADDPTR_42:.*]] = tt.addptr %{{.*}}, %[[SPLAT_41]] +// AMD: %[[SPLAT_43:.*]] = tt.splat %[[CMPI_26]] +// AMD: %[[LOAD_44:.*]] = tt.load %[[ADDPTR_42]], %[[SPLAT_43]] +// AMD: %[[ADDI_45:.*]] = arith.addi %[[ARG11]], %{{.*}} +// AMD: %[[CMPI_46:.*]] = arith.cmpi slt, %[[ADDI_45]], %{{.*}} +// AMD: %[[SELECT_47:.*]] = arith.select %[[CMPI_46]], %[[ADDI_45]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_48:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_47]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[ARG14]], %[[MEMDESC_SUBVIEW_48]] +// AMD: %[[MEMDESC_SUBVIEW_49:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_47]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[ARG15]], %[[MEMDESC_SUBVIEW_49]] +// AMD: scf.yield %[[DOT_34]], %[[ADDPTR_35]], %[[ADDPTR_36]], %[[SELECT_47]], %[[MEMDESC_SUBVIEW_48]], %[[MEMDESC_SUBVIEW_49]], %[[LOAD_38]], %[[LOAD_44]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_0]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_1]] + tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -293,7 +436,7 @@ tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr - } + } {tt.num_stages = 3 : i32} tt.return %79#0 : tensor<16x16xf32, #C> } @@ -313,6 +456,54 @@ tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[IND_BUFFER_0]] + +// AMD-LABEL: tt.func @indirect_bmm_scalar_dist_one +// AMD: %[[LOAD_0:.*]] = tt.load %{{.*}} +// AMD: %[[ADDPTR_1:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[LOCAL_ALLOC_2:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_3:.*]] = triton_gpu.local_alloc +// AMD: %[[CMPI_4:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_5:.*]] = tt.splat %[[CMPI_4]] +// AMD: %[[LOAD_6:.*]] = tt.load %{{.*}}, %[[SPLAT_5]] +// AMD: %[[LOAD_7:.*]] = tt.load %[[ADDPTR_1]], %[[CMPI_4]] +// AMD: %[[MULI_8:.*]] = arith.muli %{{.*}}, %[[LOAD_0]] +// AMD: %[[SPLAT_9:.*]] = tt.splat %[[MULI_8]] +// AMD: %[[ADDPTR_10:.*]] = tt.addptr %{{.*}}, %[[SPLAT_9]] +// AMD: %[[SPLAT_11:.*]] = tt.splat %[[CMPI_4]] +// AMD: %[[LOAD_12:.*]] = tt.load %[[ADDPTR_10]], %[[SPLAT_11]] +// AMD: %[[ADDPTR_13:.*]] = tt.addptr %[[ADDPTR_1]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_14:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_2]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_6]], %[[MEMDESC_SUBVIEW_14]] +// AMD: %[[MEMDESC_SUBVIEW_15:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_3]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_12]], %[[MEMDESC_SUBVIEW_15]] +// AMD: %{{.*}}:7 = scf.for %[[ARG6:[a-z0-9]*]] = +// AMD-SAME: iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[ADDPTR_13]], %[[ARG10:.*]] = %[[LOAD_7]], %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_14]], %[[ARG14:.*]] = %[[MEMDESC_SUBVIEW_15]]) +// AMD: %[[SUBI_17:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %[[CMPI_18:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_17]] +// AMD: %[[LOCAL_LOAD_22:.*]] = triton_gpu.local_load %[[ARG13]] +// AMD: %[[LOCAL_LOAD_23:.*]] = 
triton_gpu.local_load %[[ARG14]] +// AMD: %[[DOT_26:.*]] = tt.dot %[[LOCAL_LOAD_22]], %[[LOCAL_LOAD_23]], %[[ARG7]] +// AMD: %[[ADDPTR_27:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[SPLAT_28:.*]] = tt.splat %[[CMPI_18]] +// AMD: %[[LOAD_29:.*]] = tt.load %[[ADDPTR_27]], %[[SPLAT_28]] +// AMD: %[[LOAD_30:.*]] = tt.load %[[ARG9]], %[[CMPI_18]] +// AMD: %[[MULI_31:.*]] = arith.muli %{{.*}}, %[[ARG10]] +// AMD: %[[SPLAT_32:.*]] = tt.splat %[[MULI_31]] +// AMD: %[[ADDPTR_33:.*]] = tt.addptr %{{.*}}, %[[SPLAT_32]] +// AMD: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_18]] +// AMD: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_33]], %[[SPLAT_34]] +// AMD: %[[ADDPTR_36:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[ADDI_37:.*]] = arith.addi %[[ARG12]], %{{.*}} +// AMD: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}} +// AMD: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_2]][%[[SELECT_39]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_29]], %[[MEMDESC_SUBVIEW_40]] +// AMD: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_3]][%[[SELECT_39]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_35]], %[[MEMDESC_SUBVIEW_41]] +// AMD: scf.yield %[[DOT_26]], %[[ADDPTR_27]], %[[ADDPTR_36]], %[[LOAD_30]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_2]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_3]] + tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -365,6 +556,61 @@ tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.yield + +// AMD-LABEL: tt.func @indirect_bmm_vector +// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc +// AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] +// AMD: %[[CMPI_5:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_6:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_7:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_8:.*]] = tt.load %{{.*}}, %[[SPLAT_7]] +// AMD: %[[EXPAND_DIMS_9:.*]] = tt.expand_dims %[[LOAD_4]] {axis = 1 : i32} +// AMD: %[[BROADCAST_10:.*]] = tt.broadcast %[[EXPAND_DIMS_9]] +// AMD: %[[MULI_11:.*]] = arith.muli %{{.*}}, %[[BROADCAST_10]] +// AMD: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %[[MULI_11]] +// AMD: %[[SPLAT_13:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_14:.*]] = tt.load %[[ADDPTR_12]], %[[SPLAT_13]] +// AMD: %[[SPLAT_15:.*]] = tt.splat %[[CMPI_5]] +// AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_6]], %[[SPLAT_15]] +// AMD: %[[MEMDESC_SUBVIEW_17:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_8]], %[[MEMDESC_SUBVIEW_17]] +// AMD: %[[MEMDESC_SUBVIEW_18:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_14]], %[[MEMDESC_SUBVIEW_18]] +// AMD: %{{.*}}:7 = scf.for %[[ARG6:[a-z0-9]*]] = +// AMD-SAME: iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[ADDPTR_6]], %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_17]], %[[ARG13:.*]] = 
%[[MEMDESC_SUBVIEW_18]], %[[ARG14:.*]] = %[[LOAD_16]]) +// AMD: %[[SUBI_20:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %[[CMPI_21:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_20]] +// AMD: %[[SUBI_22:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %[[CMPI_23:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_22]] +// AMD: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %[[ARG12]] +// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[ARG13]] +// AMD: %[[DOT_31:.*]] = tt.dot %[[LOCAL_LOAD_27]], %[[LOCAL_LOAD_28]], %[[ARG7]] +// AMD: %[[ADDPTR_32:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_33:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_23]] +// AMD: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_32]], %[[SPLAT_34]] +// AMD: %[[EXPAND_DIMS_36:.*]] = tt.expand_dims %[[ARG14]] {axis = 1 : i32} +// AMD: %[[BROADCAST_37:.*]] = tt.broadcast %[[EXPAND_DIMS_36]] +// AMD: %[[MULI_38:.*]] = arith.muli %{{.*}}, %[[BROADCAST_37]] +// AMD: %[[ADDPTR_39:.*]] = tt.addptr %{{.*}}, %[[MULI_38]] +// AMD: %[[SPLAT_40:.*]] = tt.splat %[[CMPI_23]] +// AMD: %[[LOAD_41:.*]] = tt.load %[[ADDPTR_39]], %[[SPLAT_40]] +// AMD: %[[SPLAT_42:.*]] = tt.splat %[[CMPI_21]] +// AMD: %[[LOAD_43:.*]] = tt.load %[[ADDPTR_33]], %[[SPLAT_42]] +// AMD: %[[ADDI_44:.*]] = arith.addi %[[ARG11]], %{{.*}} +// AMD: %[[CMPI_45:.*]] = arith.cmpi slt, %[[ADDI_44]], %{{.*}} +// AMD: %[[SELECT_46:.*]] = arith.select %[[CMPI_45]], %[[ADDI_44]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_47:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_46]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_35]], %[[MEMDESC_SUBVIEW_47]] +// AMD: %[[MEMDESC_SUBVIEW_48:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_46]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_41]], %[[MEMDESC_SUBVIEW_48]] +// AMD: scf.yield %[[DOT_31]], %[[ADDPTR_32]], %[[ADDPTR_33]], %[[SELECT_46]], %[[MEMDESC_SUBVIEW_47]], %[[MEMDESC_SUBVIEW_48]], %[[LOAD_43]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_0]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_1]] + tt.func @indirect_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -392,16 +638,16 @@ tt.func @indirect_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i3 %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> - } + } {tt.num_stages = 3 : i32} tt.return %79#0 : tensor<16x16xf32, #C> } -// CHECK-LABEL: tt.func @post_load_inv -// CHECK: scf.for -// CHECK-DAG: %[[IV:.*]] = arith.index_cast -// CHECK: %[[NEXT_IV:.*]] = arith.addi %[[IV]], %c1_i32 : i32 -// CHECK: arith.index_cast -// CHECK-NOT: arith.addi %[[NEXT_IV]] +// COMMON-LABEL: tt.func @post_load_inv +// COMMON: scf.for +// COMMON-DAG: %[[IV:.*]] = arith.index_cast +// COMMON: %[[NEXT_IV:.*]] = arith.addi %[[IV]], %c1_i32 : i32 +// COMMON: arith.index_cast +// COMMON-NOT: arith.addi %[[NEXT_IV]] tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, @@ -452,11 +698,12 @@ tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.return %85#0 : tensor<32x32xf32, #C> } -// CHECK-LABEL: tt.func @cross_iter_dep +// COMMON-LABEL: tt.func 
@cross_iter_dep // TODO: enable pipelining with distance of 2 -// CHECK-NOT: triton_gpu.async_commit_group -// CHECK: scf.for -// CHECK: scf.yield +// COMMON-NOT: triton_gpu.async_commit_group +// COMMON: scf.for +// COMMON: scf.yield + tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, @@ -509,14 +756,14 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.return %119#0 : tensor<32x32xf32, #C> } -// CHECK-LABEL: tt.func @dep_arg_two_uses -// CHECK: tt.expand_dims -// CHECK: tt.expand_dims -// CHECK: tt.expand_dims %arg5 -// CHECK-NEXT: tt.expand_dims %arg5 -// CHECK: %[[PTR0:.*]] = tt.splat %arg6 -// CHECK: %[[PTR1:.*]] = tt.addptr %[[PTR0]] -// CHECK-NEXT: tt.load %[[PTR1]] +// COMMON-LABEL: tt.func @dep_arg_two_uses +// COMMON: tt.expand_dims +// COMMON: tt.expand_dims +// COMMON: tt.expand_dims %arg5 +// COMMON-NEXT: tt.expand_dims %arg5 +// COMMON: %[[PTR0:.*]] = tt.splat %arg6 +// COMMON: %[[PTR1:.*]] = tt.addptr %[[PTR0]] +// COMMON-NEXT: tt.load %[[PTR1]] tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -583,7 +830,7 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, #shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [0, 1], hasLeadingOffset = false}> #shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { -// CHECK-LABEL: tt.func @load_two_users_incompatible_layouts +// COMMON-LABEL: tt.func @load_two_users_incompatible_layouts tt.func @load_two_users_incompatible_layouts(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> %cst_0 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> @@ -611,8 +858,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> // check that the load didn't get pipelined. - // CHECK-NOT: alloc - // CHECK: scf.for + // COMMON-NOT: alloc + // COMMON: scf.for %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> @@ -644,6 +891,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: triton_gpu.async_copy_global_to_local // CHECK: triton_gpu.async_commit_group // CHECK: scf.yield + +// AMD-LABEL: tt.func public @nested_loops +// AMD: scf.for +// AMD: triton_gpu.local_alloc +// AMD-NOT: triton_gpu.local_alloc +// AMD: scf.for +// AMD: scf.yield +// AMD-DIS: scf.yield + // // The following code has the structure: // @@ -657,14 +913,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // } // ``` // -// Only the outer for should be pipelined. 
The regression this tests -// causes an assertion to fail while pipelining the outer `for`, in -// particular while predicating the operations scheduled to be emitted -// in the prologue. -// -// We check that there is no allocation before the first occurrence of -// scf.for because that would mean that the first load `%a = load()` -// would be pipelined. +// For CUDA, we pipeline the inner loop first then pipeline the outer +// loop to prefetch the async copy after the inner loop. +// For HIP, we only pipeline the inner loop for now. #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { @@ -735,6 +986,39 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_4]] // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} + +// AMD-DIS: #[[$SHARED_LAYOUT:shared.*]] = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +// AMD-LABEL: tt.func @indirect_load_shared_layout +// AMD: %{{.*}}:7 = scf.for %[[ARG6:[a-z0-9]*]] = +// AMD-SAME: iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}, %[[ARG14:.*]] = %{{.*}}) +// AMD: %[[SUBI_20:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %[[CMPI_21:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_20]] +// AMD: %[[SUBI_22:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %[[CMPI_23:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_22]] +// AMD: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %[[ARG12]] +// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[ARG13]] +// AMD: %[[DOT_31:.*]] = tt.dot %[[LOCAL_LOAD_27]], %[[LOCAL_LOAD_28]], %[[ARG7]] +// AMD: %[[ADDPTR_32:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_33:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_23]] +// AMD: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_32]], %[[SPLAT_34]] +// AMD: %[[EXPAND_DIMS_36:.*]] = tt.expand_dims %[[ARG14]] {axis = 1 : i32} +// AMD: %[[BROADCAST_37:.*]] = tt.broadcast %[[EXPAND_DIMS_36]] +// AMD: %[[MULI_38:.*]] = arith.muli %{{.*}}, %[[BROADCAST_37]] +// AMD: %[[ADDPTR_39:.*]] = tt.addptr %{{.*}}, %[[MULI_38]] +// AMD: %[[SPLAT_40:.*]] = tt.splat %[[CMPI_23]] +// AMD: %[[LOAD_41:.*]] = tt.load %[[ADDPTR_39]], %[[SPLAT_40]] +// AMD: %[[SPLAT_42:.*]] = tt.splat %[[CMPI_21]] +// AMD: %[[LOAD_43:.*]] = tt.load %[[ADDPTR_33]], %[[SPLAT_42]] +// AMD: %[[ADDI_44:.*]] = arith.addi %[[ARG11]], %{{.*}} +// AMD: %[[CMPI_45:.*]] = arith.cmpi slt, %[[ADDI_44]], %{{.*}} +// AMD: %[[SELECT_46:.*]] = arith.select %[[CMPI_45]], %[[ADDI_44]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_47:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_46]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_35]], %[[MEMDESC_SUBVIEW_47]] +// AMD: %[[MEMDESC_SUBVIEW_48:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_46]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_41]], %[[MEMDESC_SUBVIEW_48]] +// AMD: scf.yield %[[DOT_31]], %[[ADDPTR_32]], %[[ADDPTR_33]], %[[SELECT_46]], %[[MEMDESC_SUBVIEW_47]], %[[MEMDESC_SUBVIEW_48]], %[[LOAD_43]] + #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = 
[4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}> @@ -769,7 +1053,7 @@ tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibilit %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> - } + } {tt.num_stages = 3 : i32} tt.return %79#0 : tensor<16x16xf32, #C> } } @@ -784,6 +1068,16 @@ tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibilit // CHECK: triton_gpu.async_copy_global_to_local // CHECK: triton_gpu.memdesc_subview // CHECK: tt.return + +// AMD-LABEL: @kernel_yield_constant +// AMD: tt.load +// AMD: triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store +// AMD: scf.for +// AMD: tt.load +// AMD: triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store +// AMD: tt.return #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { @@ -840,6 +1134,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[B1BUFFER:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]]] // CHECK: triton_gpu.async_copy_global_to_local {{.*}}, %[[B1BUFFER]] // CHECK: scf.for + +// AMD-LABEL: tt.func public @add_kernel +// AMD: %[[LOAD_11:.*]] = tt.load %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[LOAD_13:.*]] = tt.load %[[ADDPTR_12]], %{{.*}} +// AMD: %[[ADDI_14:.*]] = arith.addi %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_15:.*]] = tt.splat %[[ADDI_14]] +// AMD: %[[ADDI_16:.*]] = arith.addi %[[SPLAT_15]], %{{.*}} +// AMD: %[[CMPI_17:.*]] = arith.cmpi slt, %[[ADDI_16]], %{{.*}} +// AMD: %[[ADDPTR_18:.*]] = tt.addptr %{{.*}}, %[[ADDI_16]] +// AMD: %[[LOAD_19:.*]] = tt.load %[[ADDPTR_18]], %[[CMPI_17]] +// AMD: %[[ADDPTR_20:.*]] = tt.addptr %{{.*}}, %[[ADDI_16]] +// AMD: %[[LOAD_21:.*]] = tt.load %[[ADDPTR_20]], %[[CMPI_17]] +// AMD: scf.for #blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { @@ -865,7 +1173,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %15 = arith.addf %12, %14 : tensor<1024xf32, #blocked> %16 = tt.addptr %6, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> tt.store %16, %15, %10 : tensor<1024x!tt.ptr, #blocked> - }{tt.num_stages = 3 : i32} + } {tt.num_stages = 3 : i32} tt.return } } @@ -906,6 +1214,17 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[COMMIT_2:.*]] = triton_gpu.async_commit_group %[[ASYNC_COPY_5]] // CHECK: scf.yield %[[COMMIT_1]], %[[COMMIT_2]] // CHECK: triton_gpu.local_dealloc %[[BUFFER_1]] + 
+// AMD-LABEL: tt.func public @nested_loops +// AMD-NOT: triton_gpu.local_alloc +// AMD: scf.for +// AMD: triton_gpu.local_alloc +// AMD: scf.for +// AMD: triton_gpu.local_load +// AMD: tt.dot +// AMD: triton_gpu.local_store +// AMD: scf.yield +// AMD: triton_gpu.local_dealloc #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 2], instrShape = [16, 8]}> #shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> @@ -1019,7 +1338,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // This test triggered some failure in the verifier, so we only // included a simple check for the kernel name. -// CHECK-LABEL: @load_convert_layout +// COMMON-LABEL: @load_convert_layout #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> @@ -1060,7 +1379,7 @@ tt.func @load_convert_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i3 %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> - } + } {tt.num_stages = 3 : i32} tt.return %79#0 : tensor<16x16xf32, #C> } } @@ -1070,7 +1389,7 @@ tt.func @load_convert_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i3 // This test captured some ICE in MatmulLoopPipeline pass, so we only // included a simple check for the kernel name. -// CHECK-LABEL: @matmul_indirect_pipeline +// COMMON-LABEL: @matmul_indirect_pipeline #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 1], instrShape = [16, 8]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { @@ -1106,15 +1425,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : %23 = tt.dot %21, %22, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %24 = triton_gpu.convert_layout %23 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %11, %24 : tensor<32x32x!tt.ptr, #blocked> - } + } {tt.num_stages = 3 : i32} tt.return } } // ----- -// CHECK-LABEL: @dont_pipeline_128x1 -// CHECK-NOT: local_load{{.*}}128x1 +// COMMON-LABEL: @dont_pipeline_128x1 +// COMMON-NOT: local_load{{.*}}128x1 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { @@ -1156,8 +1475,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // Check that the dependencies across ops of different nesting does not cause crash or // incorrect schedule that fails to pipeline. 
-// CHECK-LABEL: @matmul_nested_ops -// CHECK: triton_gpu.local_load +// COMMON-LABEL: @matmul_nested_ops +// COMMON: triton_gpu.local_load #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> @@ -1227,8 +1546,8 @@ tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: dot_prologue_epilogue - // CHECK: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} + // COMMON-LABEL: dot_prologue_epilogue + // COMMON: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} tt.func @dot_prologue_epilogue(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> %cst2 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> @@ -1251,17 +1570,17 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - // CHECK: %[[C0:.*]] = arith.constant 0 : i32 - // CHECK: scf.for %[[IND_VAR:.*]] = %[[C0]] - // CHECK-NOT load - // CHECK: %[[CND:.*]] = arith.cmpi slt, %[[IND_VAR]], %[[EXT]] - // CHECK: scf.if %[[CND]] - // CHECK: dot - // CHECK: scf.if %[[CND]] - // CHECK: arith.mulf - // CHECK: scf.yield - // CHECK-NOT: tt.addptr - // CHECK: scf.yield + // COMMON: %[[C0:.*]] = arith.constant 0 : i32 + // COMMON: scf.for %[[IND_VAR:.*]] = %[[C0]] + // COMMON-NOT: load + // COMMON: %[[CND:.*]] = arith.cmpi slt, %[[IND_VAR]], %[[EXT]] + // COMMON: scf.if %[[CND]] + // COMMON: dot + // COMMON: scf.if %[[CND]] + // COMMON: arith.mulf + // COMMON: scf.yield + // COMMON-NOT: tt.addptr + // COMMON: scf.yield %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> %cnd = arith.cmpi slt, %arg3, %ext : i32 @@ -1365,6 +1684,18 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[B:.*]] = triton_gpu.local_load // CHECK: arith.select {{.*}}, %[[B]], %[[CONSTANT]] +// AMD-LABEL: @masked_add_kernel +// AMD: %[[CONSTANT:.*]] = arith.constant dense<0xFF800000> +// AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: scf.for +// AMD: arith.select +// AMD: arith.addf +// AMD: %[[A:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: %[[B:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] + #blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { 
tt.func public @masked_add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index e7a9753b2145..293ee924f05e 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -8,6 +8,8 @@ namespace mlir { std::unique_ptr createTritonAMDGPUStreamPipelinePass(); +std::unique_ptr createTritonAMDGPUStreamPipelineV2Pass(int numStages = 2); + std::unique_ptr createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(), int matrixInstructionSize = 0, diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index a818b1ac9da5..5a6df1827fe4 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -16,6 +16,25 @@ def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::Mod let dependentDialects = []; } +def TritonAMDGPUStreamPipelineV2 : Pass<"tritonamdgpu-stream-pipeline-v2", "mlir::ModuleOp"> { + let summary = "pipeline"; + + let description = [{ + Pipeline global loads through registers to shared memory while computing on the previous + tile + }]; + + let constructor = "mlir::createTritonAMDGPUStreamPipelineV2Pass()"; + + let dependentDialects = []; + + let options = [ + Option<"numStages", "num_stages", + "int32_t", /*default*/"2", + "Number of Pipeline stages"> + ]; +} + def TritonAMDGPUAccelerateMatmul : Pass<"tritonamdgpu-accelerate-matmul", "mlir::ModuleOp"> { let summary = "accelerate matmul"; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index d96860c3ef90..5bacc66a1161 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_triton_library(TritonAMDGPUTransforms OptimizeEpilogue.cpp ReorderInstructions.cpp StreamPipeline.cpp + StreamPipelineV2.cpp MfmaGroup.cpp DEPENDS diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp new file mode 100644 index 000000000000..a785cfd2ffec --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp @@ -0,0 +1,675 @@ +#include "TritonAMDGPUTransforms/Passes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// This file will create a schedule that will be handed over to the pipeline +// expander. +// Software pipeliners are usually separated into two pieces, one that creates a +// modulo schedule and an expander that rewrites the loop and emits a prologue +// and epilogue.
This pass first calls a helper that will pre-process the IR +// to create stream operations and create a modulo schedule. Then we call the +// expander to generate the prologue and new loop. +//===----------------------------------------------------------------------===// + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h.inc" + +#define DEBUG_TYPE "tritonamdgpu-stream-pipeline-v2" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace { + +struct LoadInfo { + // Shared layout is used for loads feeding into dot ops. + ttg::SharedEncodingAttr sharedEncoding = nullptr; + // The distance of this load's stage to its use's stage. + int distToUse = 0; + bool usedByDot = false; +}; + +} // namespace + +static void createStreamCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value extractIdx, tt::CoarseSchedule &schedule, + tt::CoarseSchedule::Cluster prefetchCluster, + llvm::MapVector &loadToInfo, + int numStages) { + OpBuilder builder(forOp); + Value zero = builder.create(forOp.getLoc(), 0, 32); + // Replace the load with insert/extract slice. + builder.setInsertionPoint(loadOp); + Location loc = loadOp.getLoc(); + Value src = loadOp.getPtr(); + Value mask = loadOp.getMask(); + + tt::MemDescType allocTy = cast(alloc.getType()); + SmallVector copyOffsets(allocTy.getRank(), zero); + Operation *copy = builder.clone(*loadOp); + + auto [stage, cluster] = schedule[loadOp]; + schedule.erase(loadOp); + schedule.insert(copy, stage, cluster); + + // Extract part. + SmallVector loadOffsets(allocTy.getRank(), zero); + loadOffsets[0] = extractIdx; + auto sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); + auto subviewTy = tt::MemDescType::get( + allocTy.getShape().drop_front(), allocTy.getElementType(), + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); + auto viewLoad = + builder.create(loc, subviewTy, alloc, loadOffsets); + auto storeOp = + builder.create(loc, copy->getResult(0), viewLoad); + // Clean up old local caches. + SmallVector allocsToErase; + for (Operation *user : loadOp->getUsers()) { + if (auto alloc = dyn_cast(user)) { + alloc.replaceAllUsesWith(viewLoad.getResult()); + allocsToErase.push_back(alloc); + } + } + for (auto alloc : allocsToErase) + alloc.erase(); + + auto sharedLoad = + builder.create(loc, loadOp.getType(), viewLoad); + auto result = sharedLoad->getResults(); + + // Create a select for non-zero other values. + Value other = loadOp.getOther(); + if (other && !isZeroConst(other)) { + auto select = builder.create( + loc, loadOp.getType(), mask, sharedLoad.getResult(), other); + result = select->getResults(); + } + + loadOp->replaceAllUsesWith(result); + + // Prefetch the load ahead of the dot stage if it is used by the dot. + if (loadToInfo[loadOp].usedByDot) { + assert(numStages >= 2 && "requires num_stages=2 at least"); + schedule.insert(storeOp, numStages - 2, prefetchCluster); + schedule.insert(viewLoad, numStages - 2, prefetchCluster); + } + loadOp.erase(); +} + +// If all the transitive uses of the given value are used by a convert to +// the same dot operand encoding, return the shared encoding that +// needs to be used to be compatible with users' layouts.
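+// For illustration: a use chain such as +//   %a = triton_gpu.convert_layout %loaded : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> +// qualifies (as do users that first place the value into a shared-memory memdesc, e.g. a local_alloc fed by the load), while users that imply conflicting shared encodings, or users of any other kind, make this return std::nullopt.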
+static std::optional +getSharedEncIfAllUsersAreDotEnc(Value val) { + ttg::SharedEncodingAttr attr; + for (Operation *user : val.getUsers()) { + ttg::SharedEncodingAttr tempAttr; + if (user->getNumResults() != 1) + return std::nullopt; + if (auto memDesc = + dyn_cast(user->getResult(0).getType())) { + // First time we find a shared encoding in the chain, save it and try to + // use it if it is compatible with the other users. + tempAttr = cast(memDesc.getEncoding()); + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value()) + return std::nullopt; + } else { + if (!isa(user)) + return std::nullopt; + auto dotOpEnc = dyn_cast( + cast(user->getResult(0).getType()).getEncoding()); + if (!dotOpEnc) + return std::nullopt; + auto srcTy = cast(val.getType()); + auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = ttg::getOrder(srcTy.getEncoding()); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + tempAttr = ttg::SharedEncodingAttr::get( + val.getContext(), dotOpEnc, srcTy.getShape(), + ttg::getOrder(srcTy.getEncoding()), + ttg::getCTALayout(srcTy.getEncoding()), + srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); + } + // Check that the shared encodings needed by the users are compatible. + if (!tempAttr || (attr != nullptr && attr != tempAttr)) + return std::nullopt; + attr = tempAttr; + } + return attr; +} + +// Create a map from load ops to their indirection levels and the final uses +// of the load op (another load op, or a dot op). +// +// Indirection level is "0" for the load op directly used by the dot op, +// "1" for the load op used by the load op used by the dot op, and so on. +static llvm::SmallVector> +loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { + llvm::SmallVector> + loadOpToIndLevelAndUse; + DenseSet seen; + + // Recursively visit the given op and its operands to discover all load ops + // and collect their indirection levels and uses. + std::function dfs = + [&](Operation *op, int distance, Operation *use) { + // Skip previously visited load ops. + if (!seen.insert(op).second) + return; + + if (isa(op)) { + // TODO: What if there are multiple uses at different distances? + loadOpToIndLevelAndUse.emplace_back(op, distance, use); + use = op; + ++distance; + } + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, distance, use); + } + } + }; + + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!op.hasTrait()) + continue; + seen.clear(); + dfs(&op, 0, &op); + } + + // If the loop has a numStages attribute, also consider pipelining other loads + // that are not directly used by dot ops. + if (forOp->hasAttr(tt::kNumStagesAttrName)) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + dfs(&op, 0, &op); + } + } + + return loadOpToIndLevelAndUse; +} + +// Goes through all load ops to identify those that can be pipelined and assign +// layout to them. +static llvm::MapVector +assignMemoryLayouts(llvm::SmallVector> + &loadOpToIndLevelAndUse, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + llvm::MapVector loadToInfo; + + for (auto &[op, dist, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(op)) + // TODO: We'd need to verify that the distance is the same.
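+ // (As written, a load that appears at more than one distance keeps whatever was recorded for its first entry; later entries are skipped here.)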
+ continue; + + LoadInfo loadInfo; + auto loadOp = cast(op); + assert(!isLoadFromTensorPtr(loadOp) && + "Block ptr should have been lowered before this pass."); + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) { + LDBG("Skip non-tensor load " << *loadOp); + continue; + } + + auto pointeeTy = + cast(tensorTy.getElementType()).getPointeeType(); + unsigned width = vec * pointeeTy.getIntOrFloatBitWidth(); + + // Limit shared memory sharing to width >= 32 bits. + LDBG("Load " << *loadOp << " has width " << width); + if (width < 32) { + LDBG("Skip width<32 load " << *loadOp); + continue; + } + + if (use->hasTrait()) { + // Only use shared memory when feeding into a dot op. + loadInfo.usedByDot = true; + loadInfo.sharedEncoding = + getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); + } else if (auto useOp = dyn_cast(use)) { + // The use of this loadOp is another loadOp. If the use is not in the + // loadToInfo already, it means that the use is not valid for pipelining + // for some reason. We should skip this loadOp, too. + // + // Note that we assume that the use of this loadOp has already + // been processed in a previous loop iteration. This assumption holds because + // of how loadOpsToIndirectionLevelAndUse recursively collects + // loadOpToIndLevelAndUse using DFS. + if (loadToInfo.count(useOp) == 0) { + continue; + } + } + + loadToInfo[op] = loadInfo; + } + + return loadToInfo; +} + +static llvm::MapVector +scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule, + DenseSet &rootUsers, int numStages) { + ModuleOp moduleOp = forOp->getParentOfType(); + tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + // Get all loads that are (transitively) used by dot ops and their distance + // to the dot op. + llvm::SmallVector> + loadOpToIndLevelAndUse = loadOpsToIndirectionLevelAndUse(forOp); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + if (loadOpToIndLevelAndUse.empty()) + return {}; + + // Check which loads are good for pipelining, and assign them memory layouts. + llvm::MapVector loadToInfo = + assignMemoryLayouts(loadOpToIndLevelAndUse, axisInfoAnalysis); + if (loadToInfo.empty()) + return {}; + + // Filter out load ops that cannot be pipelined. + int resize = 0; + for (int i = 0, e = loadOpToIndLevelAndUse.size(); i < e; ++i) { + auto [loadOp, distance, use] = loadOpToIndLevelAndUse[i]; + if (loadToInfo.count(loadOp) != 0) + loadOpToIndLevelAndUse[resize++] = loadOpToIndLevelAndUse[i]; + } + loadOpToIndLevelAndUse.resize(resize); + + // Calculate the stage distance between applicable loads. + int maxIndirectionLevel = -1; + for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) + maxIndirectionLevel = std::max(maxIndirectionLevel, dist); + + // The stage gap between chained loads--this allows us to "spread" loads + // with a non-one step in case the number of stages given by the user is + // large.
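+ // For example, with num_stages = 4 and one level of indirection (maxIndirectionLevel = 1), stagesBetweenLoads below is ceil(2 / 2) = 1: the load producing the addresses is scheduled in stage 0, the load feeding the dot in stage 1, and the dot itself (a root user) in stage num_stages - 1 = 3.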
+ assert(numStages >= 2 && "requires num_stages=2 at least"); + unsigned stagesBetweenLoads = + llvm::divideCeil(numStages - 2, maxIndirectionLevel + 1); + LDBG("stagesBetweenLoads = " << stagesBetweenLoads); + + // Put the root uses of the loads in the last stage. + tt::CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); + for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { + // Non-LoadOp(s) are the (final) root uses of all LoadOp(s). + if (!isa(use)) { + schedule.insert(use, numStages - 1, rootUsersCluster); + rootUsers.insert(use); + } + } + + // Create a cluster for load ops at each indirection level. + SmallVector loadsClusters; + for (int i = 0; i <= maxIndirectionLevel; i++) { + loadsClusters.push_back(schedule.clusters.newAtBack()); + } + // Assign stages to the loads. + for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { + int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; + schedule.insert(loadOp, stage, loadsClusters[indLevel]); + } + + // Calculate distance from the load to the use. + for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { + loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first; + } + + LLVM_DEBUG({ + LDBG("Chosen loads to pipeline:"); + for (const auto &[load, info] : loadToInfo) { + LDBG(" - load: " << *load); + LDBG(" distToUse: " << info.distToUse); + LDBG(" usedByDot: " << info.usedByDot); + } + }); + return loadToInfo; +} + +// Add dependencies of anchor ops to the coarse schedule. Schedule them to +// the same stage and ordering cluster as the anchor op. +static void scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule, + int numStages) { + SmallVector> + opsInOrder = schedule.getOpsInOrder(forOp); + // Schedule dependencies stage by stage. + for (int stage = 0; stage < numStages; ++stage) { + for (auto [op, stage_, cluster] : opsInOrder) { + if (stage_ != stage) + continue; + schedule.insertDepsOfOp(op, stage, cluster, false); + } + } +} + +// Find dependencies with distance of 1. They will go to the next stage, +// but in the cluster before the current op. +static void scheduleDistanceOneDependencies(scf::ForOp forOp, + tt::CoarseSchedule &schedule, + int numStages) { + auto getNestedOperands = [](Operation *op) { + SmallVector operands; + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + if (operand.getParentBlock()->getParentOp()->isAncestor(nestedOp)) + operands.push_back(operand); + } + }); + return operands; + }; + + // Mapping from the cluster to the cluster before it. + DenseMap + dist1Cluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + continue; + auto [stage, cluster] = schedule[&op]; + // Can't schedule past the last stage. + if (stage == numStages - 1) + continue; + for (Value operand : getNestedOperands(&op)) { + auto arg = dyn_cast(operand); + if (!arg || arg.getArgNumber() == 0 || arg.getOwner() != op.getBlock()) + continue; + auto yieldOp = op.getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (!defOp || schedule.count(defOp) != 0) + continue; + if (isa(defOp)) { + // Exception: schedule loads with a distance of 1 together with the + // current op. 
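+ // (Distance-1 dependencies reach the current op as block arguments fed by the previous iteration's scf.yield; defOp here is the operation defining the yielded value.)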
+ schedule.insertIfAbsent(defOp, stage, cluster); + schedule.insertDepsOfOp(defOp, stage, cluster, true); + } else { + if (dist1Cluster.count(&cluster) == 0) { + dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster); + } + schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]); + schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster], true); + } + } + } +} + +static void +scheduleRemainingToLastStage(scf::ForOp forOp, tt::CoarseSchedule &schedule, + tt::CoarseSchedule::Cluster afterPrologue, + int numStages) { + // Assign the rest of the ops to the last stage. + // Take care of the ordering of the ops - uses cannot be scheduled to the + // cluster before the definition. + DenseMap opToCluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) { + opToCluster[&op] = afterPrologue; + } + } + SmallVector queue; + for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { + // We really only care about the producers from the last stage. + // Others will be scheduled before these ops anyway. + if (stage == numStages - 1) { + queue.push_back(op); + } + } + while (!queue.empty()) { + Operation *op = queue.pop_back_val(); + for (auto user : op->getUsers()) { + if (opToCluster.count(user)) { + tt::CoarseSchedule::Cluster userCluster = opToCluster[user]; + tt::CoarseSchedule::Cluster opCluster = schedule[op].second; + if (*userCluster < *opCluster) { + opToCluster[user] = opCluster; + queue.push_back(user); + } + } + } + } + for (auto [op, cluster] : opToCluster) { + schedule.insert(op, numStages - 1, cluster); + } +} + +// Create an allocation that can hold distance number of loadOp shapes. +static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, + ttg::SharedEncodingAttr sharedEnc, unsigned distance) { + OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); + auto ty = cast(loadOp->getResultTypes()[0]); + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), distance); + Type memdescType = tt::MemDescType::get(bufferShape, ty.getElementType(), + sharedEnc, sharedMemorySpace, + /*mutableMemory=*/true); + return builder.create(loadOp->getLoc(), memdescType, + Value()); +} + +// Convert load ops into shared memory allocation loads and apply +// multi-buffering based on the required number of buffers. +static SmallVector +createStreamOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, + llvm::MapVector &loadToInfo, + int numStages) { + // Calculate the number of buffers needed for each load. + // TODO: Use the precise number of buffers needed by the particular load. 
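+ // For example, if the largest load-to-use distance is two stages, every pipelined load gets a two-deep shared-memory buffer and extractIdx cycles 0, 1, 0, 1, ... through the add/compare/select sequence emitted further down.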
+ int numBuffers = -1; + for (auto &[_, info] : loadToInfo) + numBuffers = std::max(numBuffers, info.distToUse); + LDBG("deduced shared memory buffer number = " << numBuffers); + + SmallVector allocs; + SmallVector> loadToAllocs; + for (auto &[loadOp, info] : loadToInfo) { + if (!info.sharedEncoding) + continue; + + Value alloc = createAlloc(forOp, loadOp, info.sharedEncoding, numBuffers); + assert(alloc && "Failed to create alloc for the async load."); + allocs.push_back(alloc); + loadToAllocs.emplace_back(loadOp, alloc); + } + + IRRewriter builder(forOp.getContext()); + builder.setInsertionPoint(forOp); + + Location loc = forOp.getLoc(); + Value minusOne = builder.create(loc, -1, 32); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + Value extractIdx = minusOne; + Value numBuffersVal = + builder.create(loc, numBuffers, 32); + + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Patch the loop to add the new loop carried dependencies. + scf::ForOp newForOp = + replaceForOpWithNewSignature(builder, forOp, {extractIdx}); + forOp.erase(); + forOp = newForOp; + + // Create one counter for the extract indices to avoid creating long + // live range. + extractIdx = newForOp.getBody()->getArgument(newOperandIndex); + + builder.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + extractIdx = builder.create(loc, extractIdx, one); + Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, + extractIdx, numBuffersVal); + extractIdx = builder.create(loc, cndExt, extractIdx, zero); + + // Create a cluster for prefetching global reads for the dot. + tt::CoarseSchedule::Cluster prefetchCluster = schedule.clusters.newAtBack(); + + for (auto &[op, alloc] : loadToAllocs) { + if (auto loadOp = dyn_cast(op)) + createStreamCopy(forOp, loadOp, alloc, extractIdx, schedule, + prefetchCluster, loadToInfo, numStages); + } + // Patch the yield with the updated counters. + appendToForOpYield(forOp, {extractIdx}); + + return allocs; +} + +static bool preprocessLoopAndBuildSchedule(scf::ForOp &forOp, int numStages, + tt::PipeliningOption &options) { + // Schedule the loads and root ops (dot ops) in the loop. This will give us + // a scaffold for the final schedule. + DenseSet rootUsers; + tt::CoarseSchedule coarseSchedule(numStages); + llvm::MapVector loadToInfo = + scheduleLoads(forOp, coarseSchedule, rootUsers, numStages); + if (loadToInfo.empty()) + return false; + + LLVM_DEBUG({ + LDBG("Coarse schedule loads only:"); + coarseSchedule.dump(); + }); + + // Convert the loads into shared memory allocations and loads from them. + SmallVector allocs = + createStreamOps(forOp, coarseSchedule, loadToInfo, numStages); + + LLVM_DEBUG({ + LDBG("Coarse schedule with stream loads:"); + coarseSchedule.dump(); + }); + + tt::CoarseSchedule::Cluster afterPrologue = coarseSchedule.clusters.begin(); + + scheduleDependencies(forOp, coarseSchedule, numStages); + LLVM_DEBUG({ + LDBG("Coarse schedule with dependencies:"); + coarseSchedule.dump(); + }); + + scheduleDistanceOneDependencies(forOp, coarseSchedule, numStages); + LLVM_DEBUG({ + LDBG("Coarse schedule with dist 1:"); + coarseSchedule.dump(); + }); + + scheduleRemainingToLastStage(forOp, coarseSchedule, afterPrologue, numStages); + LLVM_DEBUG({ + LDBG("Final coarse schedule:"); + coarseSchedule.dump(); + }); + + // Create the final schedule for the kernel loop. This will dictate the + // stages and order of operations to the pipeline expander. 
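+ // The result is a flat list of (operation, stage) pairs in the order the expander should emit them; it is handed over unchanged through options.getScheduleFn below.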
+ std::vector> schedule = + coarseSchedule.createFinalSchedule(forOp); + + // Fill out the pipeline options. + options.getScheduleFn = + [schedule](scf::ForOp, std::vector> &s) { + s = std::move(schedule); + }; + options.peelEpilogue = false; + options.predicateFn = tt::predicateOp; + options.supportDynamicLoops = true; + + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + // Explicitly deallocate created allocations. + for (auto alloc : allocs) + builder.create(forOp.getLoc(), alloc); + return true; +} + +// Return true if the preconditions for pipelining the loop are met. +static bool checkPrecondition(scf::ForOp forOp) { + // Skip loop with distance > 1 for now. + // TODO: relax the constraint in the expander. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [](Value operand) { return !operand.getDefiningOp(); })) + return false; + + // Don't pipeline outer loops. + auto hasNestedLoopInside = [forOp](Operation *op) { + if (op != forOp && isa(op)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }; + return !forOp->walk(hasNestedLoopInside).wasInterrupted(); +} + +static bool pipelineLoop(scf::ForOp forOp, int numStages) { + if (!checkPrecondition(forOp)) + return false; + + tt::PipeliningOption options; + if (!preprocessLoopAndBuildSchedule(forOp, numStages, options)) + return false; + LDBG("Loop before sending to expander:\n" << *forOp); + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + return succeeded(tt::pipelineForLoop(rewriter, forOp, options)); +} + +namespace { +struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { + PipelinePass() = default; + PipelinePass(int32_t numStages) { this->numStages = numStages; } + + void runOnOperation() override { + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stage <= 1. + if (getNumStagesOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + + for (scf::ForOp forOp : loops) + pipelineLoop(forOp, getNumStagesOrDefault(forOp)); + } + +private: + int getNumStagesOrDefault(scf::ForOp forOp) { + // Use the attribute attached to the loop if it exists, otherwise use the + // global control. + if (auto attr = forOp->getAttrOfType(tt::kNumStagesAttrName)) + return attr.getInt(); + return numStages; + } +}; +} // anonymous namespace + +std::unique_ptr +mlir::createTritonAMDGPUStreamPipelineV2Pass(int numStages) { + return std::make_unique(numStages); +} diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index bee5437555f4..a6ef2fec7c67 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -60,6 +60,8 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { mlir::createTritonAMDGPUReorderInstructionsPass); ADD_PASS_WRAPPER_0("add_stream_pipeline", mlir::createTritonAMDGPUStreamPipelinePass); + ADD_PASS_WRAPPER_1("add_stream_pipelinev2", + mlir::createTritonAMDGPUStreamPipelineV2Pass, int); } void addControlConstant(llvm::Module *module, const char *name,