From 57a6988fd007d8ca4a5e61e81b865b345338617f Mon Sep 17 00:00:00 2001 From: Finlay Date: Tue, 28 May 2024 17:00:59 +0100 Subject: [PATCH 1/2] Remove redundant options from passes (#4015) The TritonGPUPipeline pass has unused pass options and the TritonGPUAccelerateMatmul pass option could instead be read from the module attributes, where the data already exists. The goal is to reduce redundancy. --------- Signed-off-by: Finlay Marno --- .../Dialect/TritonGPU/Transforms/Passes.td | 17 +------------- .../Dialect/TritonGPU/Transforms/Utility.h | 3 +++ .../TritonGPU/Transforms/AccelerateMatmul.cpp | 2 ++ lib/Dialect/TritonGPU/Transforms/Utility.cpp | 21 ++++++++++++++++++ python/src/passes.cc | 6 ++--- test/TritonGPU/accelerate-matmul.mlir | 22 +++++++------------ test/TritonGPU/loop-pipeline-hopper.mlir | 2 +- .../pipeline-hopper-remove-wait.mlir | 2 +- third_party/nvidia/backend/compiler.py | 4 ++-- 9 files changed, 41 insertions(+), 38 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index d98f918fde11..fdceb2cfe473 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -19,16 +19,7 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { let options = [ Option<"numStages", "num-stages", "int32_t", /*default*/"3", - "number of pipeline stages">, - Option<"numWarps", "num-warps", - "int32_t", /*default*/"4", - "number of warps per block">, - Option<"numCTAs", "num-ctas", - "int32_t", /*default*/"1", - "number of CTAs per CGA">, - Option<"computeCapability", "compute-capability", - "int32_t", /*default*/"80", - "device compute capability"> + "number of pipeline stages"> ]; } @@ -68,12 +59,6 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::Modul let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", "mlir::triton::TritonDialect"]; - - let options = [ - Option<"computeCapability", "compute-capability", - "int32_t", /*default*/"80", - "device compute capability"> - ]; } def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir::ModuleOp"> { diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index aa51bc5c4bc8..114c1814254c 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -169,6 +169,9 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, // operand and single result. bool isPureUnaryInlineAsm(Operation *op); +// read the compute capability from the module attributes +int getNVIDIAComputeCapability(Operation *module); + } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 821c8ab9c5bd..df84c4e628ea 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -387,6 +387,8 @@ class TritonGPUAccelerateMatmulPass MLIRContext *context = &getContext(); ModuleOp m = getOperation(); + auto computeCapability = getNVIDIAComputeCapability(m); + mlir::RewritePatternSet patterns(context); patterns.add(context, computeCapability); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index cc1818ca7b6e..1d6152417eef 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -8,6 +8,7 @@ #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -818,6 +819,26 @@ bool isPureUnaryInlineAsm(Operation *op) { inlineAsmOp.getPure(); } +int getNVIDIAComputeCapability(Operation *module) { + assert(module->hasAttr(triton::AttrTargetName) && + "Expected a target attribute on the module operation"); + + StringAttr targetAttr = + cast(module->getAttr(triton::AttrTargetName)); + + StringRef ref = targetAttr.strref(); + assert(ref.starts_with("cuda:") && + "expected target attribute to be prefixed with \"cuda:\""); + + StringRef capabilityStr = ref.drop_front(5); // drop the "cuda:" + int computeCapability; + bool parseError = capabilityStr.getAsInteger(10, computeCapability); + assert(!parseError && + "invalid compute capability string in target attribute"); + + return computeCapability; +} + namespace { /// Detect dead arguments in scf.for op by assuming all the values are dead and diff --git a/python/src/passes.cc b/python/src/passes.cc index ae48846104cc..513e811d28ad 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -49,11 +49,9 @@ void init_triton_passes_ttgpuir(py::module &&m) { ADD_PASS_WRAPPER_0("add_coalesce", createTritonGPUCoalesce); ADD_PASS_WRAPPER_0("add_optimize_thread_locality", createTritonGPUOptimizeThreadLocality); - ADD_PASS_OPTION_WRAPPER_4("add_pipeline", createTritonGPUPipeline, int, int, - int, int); + ADD_PASS_OPTION_WRAPPER_1("add_pipeline", createTritonGPUPipeline, int); ADD_PASS_WRAPPER_0("add_prefetch", createTritonGPUPrefetch); - ADD_PASS_OPTION_WRAPPER_1("add_accelerate_matmul", - createTritonGPUAccelerateMatmul, int); + ADD_PASS_WRAPPER_0("add_accelerate_matmul", createTritonGPUAccelerateMatmul); ADD_PASS_WRAPPER_0("add_reorder_instructions", createTritonGPUReorderInstructions); ADD_PASS_WRAPPER_0("add_f32_dot_tc", createTritonGPUF32DotTC); diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 6536fb3f9d44..8c4e85aa0859 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -1,6 +1,4 @@ -// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s -// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=89 | FileCheck %s --check-prefix=CHECK-89 -// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=80 | FileCheck %s --check-prefix=CHECK-80 +// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul | FileCheck %s // CHECK: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> // CHECK: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> @@ -49,24 +47,24 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // ----- -// CHECK-80: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> +// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-80-LABEL: chained_dot + // CHECK-LABEL: chained_dot tt.func public @chained_dot( %arg0: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, %arg1: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, %arg2: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> { %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1> - // CHECK-80: tt.dot {{.*}} -> tensor<64x64xf32, #[[$MMA]]> + // CHECK: tt.dot {{.*}} -> tensor<64x64xf32, #[[$MMA]]> %d = tt.dot %arg0, %arg1, %cst_0 : tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> %t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked> %c = triton_gpu.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> - // CHECK-80: tt.dot {{.*}} -> tensor<64x128xf32, #[[$MMA]]> + // CHECK: tt.dot {{.*}} -> tensor<64x128xf32, #[[$MMA]]> %r = tt.dot %c, %arg2, %cst_1 : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1> tt.return %r : tensor<64x128xf32, #blocked1> @@ -75,18 +73,18 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // ----- -// CHECK-89: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}> +// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}> #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-89-LABEL: fp8_dot + // CHECK-LABEL: fp8_dot tt.func public @fp8_dot( %arg0: tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, %arg1: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, %arg2: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x64xf32, #blocked> { %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> - // CHECK-89: tt.dot {{.*}} : tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$MMA]], kWidth = 4}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$MMA]], kWidth = 4}>> -> tensor<64x64xf32, #[[$MMA]]> + // CHECK: tt.dot {{.*}} : tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$MMA]], kWidth = 4}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$MMA]], kWidth = 4}>> -> tensor<64x64xf32, #[[$MMA]]> %d = tt.dot %arg0, %arg1, %cst_0 : tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> tt.return %d : tensor<64x64xf32, #blocked> @@ -97,8 +95,6 @@ module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 : // CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> // CHECK-DAG: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}> -// CHECK-80-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> -// CHECK-80-DAG: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}> #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> @@ -112,7 +108,6 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %1 = triton_gpu.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> %2 = triton_gpu.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #blocked1> // CHECK: tt.dot {{.*}} -> tensor<16x16xf32, #[[MMA]]> - // CHECK-80: tt.dot {{.*}} -> tensor<16x16xf32, #[[MMA]]> %3 = tt.dot %0, %1, %2, inputPrecision = tf32 : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<16x16xf32, #blocked1> %4 = triton_gpu.convert_layout %3 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<16x16xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16x16xf32, #blocked2> @@ -122,7 +117,6 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %9 = triton_gpu.convert_layout %cst : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> %10 = triton_gpu.convert_layout %cst : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #blocked3> // CHECK: tt.dot {{.*}} -> tensor<2x16x16xf32, #[[MMA1]]> - // CHECK-80: tt.dot {{.*}} -> tensor<2x16x16xf32, #[[MMA1]]> %11 = tt.dot %8, %9, %10, inputPrecision = tf32 : tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked3}>> * tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<2x16x16xf32, #blocked3> %12 = triton_gpu.convert_layout %11 : tensor<2x16x16xf32, #blocked3> -> tensor<2x16x16xf32, #blocked> tt.print ": " {hex = false} : %12 : tensor<2x16x16xf32, #blocked> diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index b3dc9d883433..5c9cac004c20 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=compute-capability=90 -canonicalize | FileCheck --dump-input-context=50 %s +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 diff --git a/test/TritonGPU/pipeline-hopper-remove-wait.mlir b/test/TritonGPU/pipeline-hopper-remove-wait.mlir index a3e002321f2c..1e3d4d96708b 100644 --- a/test/TritonGPU/pipeline-hopper-remove-wait.mlir +++ b/test/TritonGPU/pipeline-hopper-remove-wait.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -canonicalize -tritongpu-pipeline=compute-capability=90 -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -canonicalize -tritongpu-pipeline -canonicalize | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index a9f389ad6162..10c5f3e83b02 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -168,13 +168,13 @@ def make_ttgir(mod, metadata, opt, capability): nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_optimize_thread_locality(pm) - passes.ttgpuir.add_accelerate_matmul(pm, capability) + passes.ttgpuir.add_accelerate_matmul(pm) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) passes.common.add_cse(pm) if capability // 10 >= 8: passes.ttgpuir.add_combine_tensor_select_and_if(pm) - passes.ttgpuir.add_pipeline(pm, opt.num_stages, opt.num_warps, opt.num_ctas, capability) + passes.ttgpuir.add_pipeline(pm, opt.num_stages) if capability // 10 <= 8: passes.ttgpuir.add_prefetch(pm) passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) From 780bf5d84be16d99bcf70a51109bef9a60026e08 Mon Sep 17 00:00:00 2001 From: Manman Ren Date: Wed, 29 May 2024 17:45:16 -0700 Subject: [PATCH 2/2] Run prefetch on mma-v2 dots (#4000) This will enable prefetching for mma-v2 dots on H100. --------- Co-authored-by: Manman Ren --- lib/Dialect/TritonGPU/Transforms/Prefetch.cpp | 12 +++++++++++- third_party/nvidia/backend/compiler.py | 3 +-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index bb086e03ee22..85a95aaa7d5e 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -151,10 +151,20 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, LogicalResult Prefetcher::initialize() { Block *loop = forOp.getBody(); + auto getEncoding = [](Value v) { + return cast(v.getType()).getEncoding(); + }; + SmallVector dotsInFor; for (Operation &op : *loop) - if (auto dotOp = dyn_cast(op)) + if (auto dotOp = dyn_cast(op)) { + // bail out if there exist non v2 dots. + auto dstEnc = + dyn_cast(getEncoding(dotOp.getResult())); + if (!dstEnc || dstEnc.getVersionMajor() != 2) + return failure(); dotsInFor.push_back(dotOp); + } if (dotsInFor.empty()) return failure(); diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 10c5f3e83b02..6d7994923495 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -175,8 +175,7 @@ def make_ttgir(mod, metadata, opt, capability): if capability // 10 >= 8: passes.ttgpuir.add_combine_tensor_select_and_if(pm) passes.ttgpuir.add_pipeline(pm, opt.num_stages) - if capability // 10 <= 8: - passes.ttgpuir.add_prefetch(pm) + passes.ttgpuir.add_prefetch(pm) passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_reduce_data_duplication(pm)