From 1d738e1a5b7be62c7c670a926609fc49ad3662c8 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Thu, 31 Oct 2024 03:10:47 +0000 Subject: [PATCH 1/4] Fix order for Hopper mmaLayout --- lib/Dialect/TritonGPU/IR/Dialect.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 3b5316ecc0e3..639b9b243f52 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -288,7 +288,14 @@ SmallVector getOrder(Attribute layout) { auto distributedLayout = cast(layout); auto rank = distributedLayout.getWarpsPerCTA().size(); SmallVector order(rank); - std::iota(order.rbegin(), order.rend(), 0); + auto nvidiaMma = dyn_cast(layout); + if (nvidiaMma && nvidiaMma.isHopper()) { + // Hopper WGMMA requires that the 4 warps in a warp group be laid out along + // the M-dimension; so here we rasterize M-first + std::iota(order.begin(), order.end(), 0); + } else { + std::iota(order.rbegin(), order.rend(), 0); + } return order; } if (auto dotLayout = dyn_cast(layout)) { From d8745897d5ae06e72efa9494e2ece97d48e86844 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Fri, 1 Nov 2024 00:34:38 +0000 Subject: [PATCH 2/4] Fix bad rebase: use mmaBitwidth to compute numRep --- .../ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp index 6094a911189d..5ec6851373d3 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp @@ -659,7 +659,7 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, int kWidth = encoding.getKWidth(); auto numRep = mmaLayout.getMMAv2OrV3RepForOperand( - shapePerCTA, bitwidth, kWidth, encoding.getOpIdx()); + shapePerCTA, mmaBitwidth, kWidth, encoding.getOpIdx()); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); auto order = triton::gpu::getOrder(mmaLayout); From c2821a8f1db0030c32b8e8c8bd0f9116fadb7a24 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Fri, 1 Nov 2024 16:29:04 +0000 Subject: [PATCH 3/4] Revert getOrder and use getWarpOrder in SharedToDot --- lib/Dialect/TritonGPU/IR/Dialect.cpp | 9 +-------- .../SharedToDotOperandMMAv2OrV3.cpp | 4 ++-- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 639b9b243f52..3b5316ecc0e3 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -288,14 +288,7 @@ SmallVector getOrder(Attribute layout) { auto distributedLayout = cast(layout); auto rank = distributedLayout.getWarpsPerCTA().size(); SmallVector order(rank); - auto nvidiaMma = dyn_cast(layout); - if (nvidiaMma && nvidiaMma.isHopper()) { - // Hopper WGMMA requires that the 4 warps in a warp group be laid out along - // the M-dimension; so here we rasterize M-first - std::iota(order.begin(), order.end(), 0); - } else { - std::iota(order.rbegin(), order.rend(), 0); - } + std::iota(order.rbegin(), order.rend(), 0); return order; } if (auto dotLayout = dyn_cast(layout)) { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp index 5ec6851373d3..a26f2c93dc42 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp @@ -662,12 +662,12 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, shapePerCTA, mmaBitwidth, kWidth, encoding.getOpIdx()); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); - auto order = triton::gpu::getOrder(mmaLayout); + auto warpOrder = encoding.getWarpOrder(); Value warp = udiv(thread, i32_val(32)); Value lane = urem(thread, i32_val(32)); SmallVector multiDimWarpId = - delinearize(rewriter, loc, warp, warpsPerCTA, order); + delinearize(rewriter, loc, warp, warpsPerCTA, warpOrder); Value warpB = urem(multiDimWarpId[0], i32_val(shapePerCTA[0])); int warpsPerTile; Value warpM = urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / 16)); From 246dac8a2863a7d4ed4bc7c630833b936208d595 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Fri, 1 Nov 2024 20:02:57 +0000 Subject: [PATCH 4/4] Get mmaLayout's order instead of dotOp's --- .../ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp index a26f2c93dc42..8f1fcc1f7035 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp @@ -662,7 +662,7 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, shapePerCTA, mmaBitwidth, kWidth, encoding.getOpIdx()); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); - auto warpOrder = encoding.getWarpOrder(); + auto warpOrder = mmaLayout.getWarpOrder(); Value warp = udiv(thread, i32_val(32)); Value lane = urem(thread, i32_val(32));