From 5973989de87606055afbfb3116ebdd22931b95f9 Mon Sep 17 00:00:00 2001 From: Sam Ginzburg Date: Tue, 5 Nov 2024 20:30:07 -0800 Subject: [PATCH 1/4] [AMD] Fix issue with rank=1 in tryFitCvtIntoLDS --- .../amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp index db3223f119da..1d7c6545cf61 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp @@ -96,7 +96,15 @@ class OptimizeAMDLDSUsage auto dstEnc = dstType.getEncoding(); auto ctx = srcEnc.getContext(); - auto rank = srcType.getShape().size(); + auto rank = srcType.getRank(); + + // TODO: This function currently crashes with rank == 1 and seems to assume + // rank == 2. + // This should probably be generalized to other ranks, but for now + // check the rank. + if (rank != 2) + return; + unsigned numWarps = triton::gpu::getNumWarpsPerCTA(srcEnc); auto warpSize = triton::gpu::getWarpSize(srcEnc); @@ -109,11 +117,14 @@ class OptimizeAMDLDSUsage // Create a list of temporary layouts SmallVector elemsPerThread(rank, 1); SmallVector threadsPerWarp(rank, 1); + threadsPerWarp[rank - 1] = warpSize / 8; threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1]; + auto layoutCTA = triton::gpu::getCTALayout(srcEnc); auto order = triton::gpu::getOrder(srcEnc); SmallVector dummyWarpsPerCTA(rank, 1); + auto baseFallbackLayout = triton::gpu::BlockedEncodingAttr::get( ctx, elemsPerThread, threadsPerWarp, dummyWarpsPerCTA, order, layoutCTA); From 3a8f97bfd8e57c2bc8afaca03edd95c0cc8ffd1d Mon Sep 17 00:00:00 2001 From: Sam Ginzburg Date: Wed, 6 Nov 2024 07:15:41 -0800 Subject: [PATCH 2/4] a better fix, passes tests --- .../amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp | 11 +++-------- .../amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp | 6 ++++-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp index 1d7c6545cf61..be551aca99eb 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp @@ -98,13 +98,6 @@ class OptimizeAMDLDSUsage auto ctx = srcEnc.getContext(); auto rank = srcType.getRank(); - // TODO: This function currently crashes with rank == 1 and seems to assume - // rank == 2. - // This should probably be generalized to other ranks, but for now - // check the rank. - if (rank != 2) - return; - unsigned numWarps = triton::gpu::getNumWarpsPerCTA(srcEnc); auto warpSize = triton::gpu::getWarpSize(srcEnc); @@ -119,7 +112,9 @@ class OptimizeAMDLDSUsage SmallVector threadsPerWarp(rank, 1); threadsPerWarp[rank - 1] = warpSize / 8; - threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1]; + // Skip this if the rank is 1 + if (rank > 1) + threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1]; auto layoutCTA = triton::gpu::getCTALayout(srcEnc); auto order = triton::gpu::getOrder(srcEnc); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp index dfa8e06e247c..05a618f690ed 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp @@ -68,9 +68,11 @@ Attribute createTmpLayout(Attribute layout, ArrayRef warpsPerCTA) { ctx, src.getOpIdx(), createTmpLayout(src.getParent(), warpsPerCTA), src.getKWidth()); } - if (auto src = dyn_cast(layout)) + if (auto src = dyn_cast(layout)) { + auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent()); return triton::gpu::SliceEncodingAttr::get( - ctx, src.getDim(), createTmpLayout(src.getParent(), warpsPerCTA)); + ctx, src.getDim(), createTmpLayout(src.getParent(), parentWarpsPerCTA)); + } assert("Encountered unsupported layout"); return Attribute(); } From f1b904c8d1aebb4cf6472cfd0f8af8000255c5eb Mon Sep 17 00:00:00 2001 From: Sam Ginzburg Date: Wed, 6 Nov 2024 14:17:02 -0800 Subject: [PATCH 3/4] add lit test and incorporate feedback --- test/TritonGPU/amd/optimize-lds-usage.mlir | 17 +++++++++++++++++ .../lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp | 10 +++++++--- .../TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp | 1 + 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/test/TritonGPU/amd/optimize-lds-usage.mlir b/test/TritonGPU/amd/optimize-lds-usage.mlir index 0ec2ad5382ec..b670e5c70a93 100644 --- a/test/TritonGPU/amd/optimize-lds-usage.mlir +++ b/test/TritonGPU/amd/optimize-lds-usage.mlir @@ -118,3 +118,20 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war tt.return } } + +// Checks that optimization do not crash on 1d tensor +// CHECK-LABEL: convert_1d +// CHECK: triton_gpu.local_alloc +// CHECK-NEXT: triton_gpu.convert_layout +// CHECK-NEXT: triton_gpu.local_load +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @convert_1d(%arg0: tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>, %arg1: tensor<128x128xf32, #mma>) attributes {noinline = false} { + %alloc = triton_gpu.local_alloc %arg1 : (tensor<128x128xf32, #mma>) -> !tt.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory> + %1 = triton_gpu.convert_layout %arg0 : tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<128xf32, #blocked> + %load = triton_gpu.local_load %alloc : !tt.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf32, #mma> + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp index be551aca99eb..4a0a7fed22b0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp @@ -111,10 +111,14 @@ class OptimizeAMDLDSUsage SmallVector elemsPerThread(rank, 1); SmallVector threadsPerWarp(rank, 1); - threadsPerWarp[rank - 1] = warpSize / 8; - // Skip this if the rank is 1 - if (rank > 1) + // Special case for rank == 1 + if (rank == 1) { + threadsPerWarp[0] = warpSize; + } else { + assert(rank > 1); + threadsPerWarp[rank - 1] = warpSize / 8; threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1]; + } auto layoutCTA = triton::gpu::getCTALayout(srcEnc); auto order = triton::gpu::getOrder(srcEnc); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp index 05a618f690ed..f207960f4cb8 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp @@ -69,6 +69,7 @@ Attribute createTmpLayout(Attribute layout, ArrayRef warpsPerCTA) { src.getKWidth()); } if (auto src = dyn_cast(layout)) { + // TODO: think of a way to construct slice layouts based on warpsPerCTA argument auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent()); return triton::gpu::SliceEncodingAttr::get( ctx, src.getDim(), createTmpLayout(src.getParent(), parentWarpsPerCTA)); From 6e49d28e2692f27e8cf26bc695b55d94732bfb3d Mon Sep 17 00:00:00 2001 From: Sam Ginzburg Date: Wed, 6 Nov 2024 15:16:34 -0800 Subject: [PATCH 4/4] address comments and add test --- test/TritonGPU/amd/optimize-lds-usage.mlir | 2 ++ third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/test/TritonGPU/amd/optimize-lds-usage.mlir b/test/TritonGPU/amd/optimize-lds-usage.mlir index b670e5c70a93..5cd34aab27d2 100644 --- a/test/TritonGPU/amd/optimize-lds-usage.mlir +++ b/test/TritonGPU/amd/optimize-lds-usage.mlir @@ -119,6 +119,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war } } +// ----- + // Checks that optimization do not crash on 1d tensor // CHECK-LABEL: convert_1d // CHECK: triton_gpu.local_alloc diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp index f207960f4cb8..fb0bfb656ef4 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp @@ -69,7 +69,8 @@ Attribute createTmpLayout(Attribute layout, ArrayRef warpsPerCTA) { src.getKWidth()); } if (auto src = dyn_cast(layout)) { - // TODO: think of a way to construct slice layouts based on warpsPerCTA argument + // TODO: think of a way to construct slice layouts based on warpsPerCTA + // argument auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent()); return triton::gpu::SliceEncodingAttr::get( ctx, src.getDim(), createTmpLayout(src.getParent(), parentWarpsPerCTA));