From e4b15112495bff8a78fdd791d800c9f9debd5240 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Wed, 27 Nov 2024 19:00:07 +0000 Subject: [PATCH 1/3] [BACKEND] Fix inline asm bug for multiple packed <32bit output --- .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 5 ++--- test/Conversion/tritongpu_to_llvm.mlir | 13 +++++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 632ccf10848c..5869ab36f08b 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -325,13 +325,12 @@ struct ElementwiseInlineAsmOpConversion // asmResults is a flat struct; pack its values into // [return_value][op.getPackedElement()]. SmallVector> ret(op->getNumResults()); + int structIdx = 0; for (int i = 0; i < op->getNumResults(); i++) { - int structIdx = 0; for (int j = 0; j < op.getPackedElement(); j++) { Value val; if (asmRetTypes.size() > 1) { - val = - extract_val(asmResults, i * op.getPackedElement() + structIdx++); + val = extract_val(asmResults, structIdx++); } else { val = asmResults; } diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 3f2fd578da82..ae1d67100304 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1897,3 +1897,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +// CHECK: inline_asm_pack +#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + // check specifically for the case where asm has two results, pack > 1, and the result bitwidth is < 32 + tt.func public @inline_asm_pack(%80: tensor<64x64xi8, #blocked>) attributes {noinline = false} { + %83:2 = tt.elementwise_inline_asm "" {constraints = "=r,=r,=r,=r,r", packed_element = 4 : i32, pure = true} %80 : tensor<64x64xi8, #blocked> -> tensor<64x64xbf16, #blocked>, tensor<64x64xbf16, #blocked> + tt.return + } +} + From 5705e0892f5dc1a45d9675581eb649fee3af30f6 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Wed, 27 Nov 2024 19:07:41 +0000 Subject: [PATCH 2/3] Check for output signature in test --- test/Conversion/tritongpu_to_llvm.mlir | 1 + 1 file changed, 1 insertion(+) diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index ae1d67100304..ef3857a790b9 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1905,6 +1905,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { // check specifically for the case where asm has two results, pack > 1, and the result bitwidth is < 32 tt.func public @inline_asm_pack(%80: tensor<64x64xi8, #blocked>) attributes {noinline = false} { + // CHECK: llvm.inline_asm asm_dialect {{.*}} (vector<4xi8>) -> !llvm.struct<(vector<2xbf16>, vector<2xbf16>, vector<2xbf16>, vector<2xbf16>)> %83:2 = tt.elementwise_inline_asm "" {constraints = "=r,=r,=r,=r,r", packed_element = 4 : i32, pure = true} %80 : tensor<64x64xi8, #blocked> -> tensor<64x64xbf16, #blocked>, tensor<64x64xbf16, #blocked> tt.return } From 71d528b581f64c4c43b6c4532d1bf47a0afbb076 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Wed, 27 Nov 2024 19:28:07 +0000 Subject: [PATCH 3/3] Run pre-commit --- test/Conversion/tritongpu_to_llvm.mlir | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index ef3857a790b9..07681e6f7b94 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1910,4 +1910,3 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } -