From d3081d4006dc9f8b82196554b53ce7c1133e866e Mon Sep 17 00:00:00 2001 From: Jokeren Date: Wed, 27 Nov 2024 20:33:46 -0500 Subject: [PATCH 1/2] Update --- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 8 ++++-- test/TritonGPU/accelerate-matmul.mlir | 26 +++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index e5ccc5175840..3b29f73e1d7a 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -113,8 +113,12 @@ warpsPerTileV3(DotOp dotOp, const ArrayRef shape, int numWarps, const SmallVector &instrShape) { SetVector slices; mlir::getForwardSlice(dotOp.getResult(), &slices); - if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != - slices.end()) + // Contains a chained dot. We prefer to assign warps to one axis + // to facilitate use cases like flash attention, allowing reductions within + // the same warp. + if (llvm::find_if(slices, [](Operation *op) { + return op->hasTrait(); + }) != slices.end()) return {(unsigned)numWarps, 1}; // For MMAv3, the smallest indivisible unit of warp shape is (4, 1). diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 15704bd0c216..52db5349beb4 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -73,6 +73,32 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num- // ----- +// CHECK: #[[$MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 8]}> +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: chained_dot + tt.func public @chained_dot_wgmma( + %arg0: tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x64xf32, #[[$MMA]]> + %d = tt.dot %arg0, %arg1, %cst_0 : + tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> + %t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked> + %c = ttg.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x128xf32, #[[$MMA]]> + %r = tt.dot %c, %arg2, %cst_1 : + tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1> + tt.return %r : tensor<64x128xf32, #blocked1> + } +} + +// ----- + // CHECK: #[[$MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}> #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> From 003f56c7c7abd53cb926e245cc72d160ea548631 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Wed, 27 Nov 2024 21:00:51 -0500 Subject: [PATCH 2/2] Update --- test/Conversion/tritongpu_to_llvm.mlir | 4 +-- test/TritonGPU/accelerate-matmul.mlir | 7 ++--- test/TritonGPU/coalesce-async-copy.mlir | 36 ++++++++++++------------- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 4b8813516a08..2c53279b5cf6 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1901,8 +1901,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // ----- // CHECK: inline_asm_pack -#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // check specifically for the case where asm has two results, pack > 1, and the result bitwidth is < 32 tt.func public @inline_asm_pack(%80: tensor<64x64xi8, #blocked>) attributes {noinline = false} { // CHECK: llvm.inline_asm asm_dialect {{.*}} (vector<4xi8>) -> !llvm.struct<(vector<2xbf16>, vector<2xbf16>, vector<2xbf16>, vector<2xbf16>)> diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 52db5349beb4..17180a392440 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -73,7 +73,8 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num- // ----- -// CHECK: #[[$MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 8]}> +// CHECK: #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 32, 16]}> +// CHECK: #mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 64, 16]}> #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> @@ -85,12 +86,12 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- %arg2: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> { %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1> - // CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x64xf32, #[[$MMA]]> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x64xf32, #mma> %d = tt.dot %arg0, %arg1, %cst_0 : tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> %t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked> %c = ttg.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> - // CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x128xf32, #[[$MMA]]> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x128xf32, #mma1> %r = tt.dot %c, %arg2, %cst_1 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1> tt.return %r : tensor<64x128xf32, #blocked1> diff --git a/test/TritonGPU/coalesce-async-copy.mlir b/test/TritonGPU/coalesce-async-copy.mlir index 4707ddaca9cb..0190238da135 100644 --- a/test/TritonGPU/coalesce-async-copy.mlir +++ b/test/TritonGPU/coalesce-async-copy.mlir @@ -1,35 +1,35 @@ // RUN: triton-opt %s -split-input-file -tritongpu-coalesce-async-copy | FileCheck %s -// CHECK: #[[NEW_BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> -// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi1, #[[NEW_BLOCKED]]> -// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi8, #[[NEW_BLOCKED]]> -// CHECK: %{{.*}} = triton_gpu.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> +// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi1, #[[NEW_BLOCKED]]> +// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x16xi8, #[[NEW_BLOCKED]]> +// CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @async_copy_i8(%input: tensor<64x16x!tt.ptr, #blocked>, - %view: !triton_gpu.memdesc<64x16xi8, #shared, #triton_gpu.shared_memory, mutable>, + %view: !ttg.memdesc<64x16xi8, #shared, #ttg.shared_memory, mutable>, %mask: tensor<64x16xi1, #blocked>, %other: tensor<64x16xi8, #blocked>) { - %token = triton_gpu.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x16x!tt.ptr, #blocked> -> <64x16xi8, #shared, #triton_gpu.shared_memory, mutable> + %token = ttg.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x16x!tt.ptr, #blocked> -> <64x16xi8, #shared, #ttg.shared_memory, mutable> tt.return } } // ----- -// CHECK: #[[NEW_BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: %{{.*}} = triton_gpu.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> -// CHECK: %{{.*}} = triton_gpu.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +// CHECK: #[[NEW_BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: %{{.*}} = ttg.convert_layout %{{.*}} : {{.*}} -> tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> +// CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr, #[[NEW_BLOCKED]]> +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @async_copy_i8_no_mask_or_other(%input: tensor<64x16x!tt.ptr, #blocked>, - %view: !triton_gpu.memdesc<64x16xi8, #shared, #triton_gpu.shared_memory, mutable>) { - %token = triton_gpu.async_copy_global_to_local %input, %view : tensor<64x16x!tt.ptr, #blocked> -> <64x16xi8, #shared, #triton_gpu.shared_memory, mutable> + %view: !ttg.memdesc<64x16xi8, #shared, #ttg.shared_memory, mutable>) { + %token = ttg.async_copy_global_to_local %input, %view : tensor<64x16x!tt.ptr, #blocked> -> <64x16xi8, #shared, #ttg.shared_memory, mutable> tt.return } }