Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 1 addition & 16 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,7 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
let options = [
Option<"numStages", "num-stages",
"int32_t", /*default*/"3",
"number of pipeline stages">,
Option<"numWarps", "num-warps",
"int32_t", /*default*/"4",
"number of warps per block">,
Option<"numCTAs", "num-ctas",
"int32_t", /*default*/"1",
"number of CTAs per CGA">,
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
"number of pipeline stages">
];
}

Expand Down Expand Up @@ -68,12 +59,6 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::Modul
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
"mlir::triton::TritonDialect"];

let options = [
Option<"computeCapability", "compute-capability",
"int32_t", /*default*/"80",
"device compute capability">
];
}

def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir::ModuleOp"> {
Expand Down
3 changes: 3 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
// operand and single result.
bool isPureUnaryInlineAsm(Operation *op);

// read the compute capability from the module attributes
int getNVIDIAComputeCapability(Operation *module);

} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
2 changes: 2 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,8 @@ class TritonGPUAccelerateMatmulPass
MLIRContext *context = &getContext();
ModuleOp m = getOperation();

auto computeCapability = getNVIDIAComputeCapability(m);

mlir::RewritePatternSet patterns(context);
patterns.add<BlockedToMMA>(context, computeCapability);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
Expand Down
12 changes: 11 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,20 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
LogicalResult Prefetcher::initialize() {
Block *loop = forOp.getBody();

auto getEncoding = [](Value v) {
return cast<TensorOrMemDesc>(v.getType()).getEncoding();
};

SmallVector<triton::DotOp> dotsInFor;
for (Operation &op : *loop)
if (auto dotOp = dyn_cast<triton::DotOp>(op))
if (auto dotOp = dyn_cast<triton::DotOp>(op)) {
// bail out if there exist non v2 dots.
auto dstEnc =
dyn_cast<NvidiaMmaEncodingAttr>(getEncoding(dotOp.getResult()));
if (!dstEnc || dstEnc.getVersionMajor() != 2)
return failure();
dotsInFor.push_back(dotOp);
}

if (dotsInFor.empty())
return failure();
Expand Down
21 changes: 21 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
Expand Down Expand Up @@ -818,6 +819,26 @@ bool isPureUnaryInlineAsm(Operation *op) {
inlineAsmOp.getPure();
}

// Read the NVIDIA compute capability (e.g. 80, 89, 90) from the module's
// target attribute. The attribute is expected to be a StringAttr of the
// form "cuda:<capability>".
//
// \p module must carry the target attribute; a missing attribute, a
// non-"cuda:" prefix, or a non-numeric suffix is a precondition violation
// (asserted in debug builds).
int getNVIDIAComputeCapability(Operation *module) {
  assert(module->hasAttr(triton::AttrTargetName) &&
         "Expected a target attribute on the module operation");

  StringAttr targetAttr =
      cast<StringAttr>(module->getAttr(triton::AttrTargetName));

  StringRef ref = targetAttr.strref();
  assert(ref.starts_with("cuda:") &&
         "expected target attribute to be prefixed with \"cuda:\"");

  StringRef capabilityStr = ref.drop_front(5); // drop the "cuda:"
  // Zero-initialize so release builds (where the assert below vanishes)
  // return a deterministic value instead of reading uninitialized memory
  // when parsing fails. StringRef::getAsInteger returns true on error.
  int computeCapability = 0;
  [[maybe_unused]] bool parseError =
      capabilityStr.getAsInteger(/*Radix=*/10, computeCapability);
  assert(!parseError &&
         "invalid compute capability string in target attribute");

  return computeCapability;
}

namespace {

/// Detect dead arguments in scf.for op by assuming all the values are dead and
Expand Down
6 changes: 2 additions & 4 deletions python/src/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,9 @@ void init_triton_passes_ttgpuir(py::module &&m) {
ADD_PASS_WRAPPER_0("add_coalesce", createTritonGPUCoalesce);
ADD_PASS_WRAPPER_0("add_optimize_thread_locality",
createTritonGPUOptimizeThreadLocality);
ADD_PASS_OPTION_WRAPPER_4("add_pipeline", createTritonGPUPipeline, int, int,
int, int);
ADD_PASS_OPTION_WRAPPER_1("add_pipeline", createTritonGPUPipeline, int);
ADD_PASS_WRAPPER_0("add_prefetch", createTritonGPUPrefetch);
ADD_PASS_OPTION_WRAPPER_1("add_accelerate_matmul",
createTritonGPUAccelerateMatmul, int);
ADD_PASS_WRAPPER_0("add_accelerate_matmul", createTritonGPUAccelerateMatmul);
ADD_PASS_WRAPPER_0("add_reorder_instructions",
createTritonGPUReorderInstructions);
ADD_PASS_WRAPPER_0("add_f32_dot_tc", createTritonGPUF32DotTC);
Expand Down
22 changes: 8 additions & 14 deletions test/TritonGPU/accelerate-matmul.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s
// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=89 | FileCheck %s --check-prefix=CHECK-89
// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=80 | FileCheck %s --check-prefix=CHECK-80
// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul | FileCheck %s

// CHECK: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}>
// CHECK: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
Expand Down Expand Up @@ -49,24 +47,24 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :

// -----

// CHECK-80: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}>
// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-80-LABEL: chained_dot
// CHECK-LABEL: chained_dot
tt.func public @chained_dot(
%arg0: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>,
%arg1: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>,
%arg2: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1>
// CHECK-80: tt.dot {{.*}} -> tensor<64x64xf32, #[[$MMA]]>
// CHECK: tt.dot {{.*}} -> tensor<64x64xf32, #[[$MMA]]>
%d = tt.dot %arg0, %arg1, %cst_0 :
tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked>
%t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked>
%c = triton_gpu.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
// CHECK-80: tt.dot {{.*}} -> tensor<64x128xf32, #[[$MMA]]>
// CHECK: tt.dot {{.*}} -> tensor<64x128xf32, #[[$MMA]]>
%r = tt.dot %c, %arg2, %cst_1 :
tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1>
tt.return %r : tensor<64x128xf32, #blocked1>
Expand All @@ -75,18 +73,18 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :

// -----

// CHECK-89: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}>
// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-89-LABEL: fp8_dot
// CHECK-LABEL: fp8_dot
tt.func public @fp8_dot(
%arg0: tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>,
%arg1: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>,
%arg2: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x64xf32, #blocked> {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked>
// CHECK-89: tt.dot {{.*}} : tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$MMA]], kWidth = 4}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$MMA]], kWidth = 4}>> -> tensor<64x64xf32, #[[$MMA]]>
// CHECK: tt.dot {{.*}} : tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$MMA]], kWidth = 4}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$MMA]], kWidth = 4}>> -> tensor<64x64xf32, #[[$MMA]]>
%d = tt.dot %arg0, %arg1, %cst_0 :
tensor<64x128xf8E4M3FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked>
tt.return %d : tensor<64x64xf32, #blocked>
Expand All @@ -97,8 +95,6 @@ module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 :

// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
// CHECK-DAG: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}>
// CHECK-80-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
// CHECK-80-DAG: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}>

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
Expand All @@ -112,7 +108,6 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
%1 = triton_gpu.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>
%2 = triton_gpu.convert_layout %cst_0 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #blocked1>
// CHECK: tt.dot {{.*}} -> tensor<16x16xf32, #[[MMA]]>
// CHECK-80: tt.dot {{.*}} -> tensor<16x16xf32, #[[MMA]]>
%3 = tt.dot %0, %1, %2, inputPrecision = tf32 : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<16x16xf32, #blocked1>
%4 = triton_gpu.convert_layout %3 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<16x16xf32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16x16xf32, #blocked2>
Expand All @@ -122,7 +117,6 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
%9 = triton_gpu.convert_layout %cst : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>>
%10 = triton_gpu.convert_layout %cst : tensor<2x16x16xf32, #blocked> -> tensor<2x16x16xf32, #blocked3>
// CHECK: tt.dot {{.*}} -> tensor<2x16x16xf32, #[[MMA1]]>
// CHECK-80: tt.dot {{.*}} -> tensor<2x16x16xf32, #[[MMA1]]>
%11 = tt.dot %8, %9, %10, inputPrecision = tf32 : tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked3}>> * tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<2x16x16xf32, #blocked3>
%12 = triton_gpu.convert_layout %11 : tensor<2x16x16xf32, #blocked3> -> tensor<2x16x16xf32, #blocked>
tt.print ": " {hex = false} : %12 : tensor<2x16x16xf32, #blocked>
Expand Down
2 changes: 1 addition & 1 deletion test/TritonGPU/loop-pipeline-hopper.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=compute-capability=90 -canonicalize | FileCheck --dump-input-context=50 %s
// RUN: triton-opt %s -split-input-file -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s

// 4 warps
// matmul: 128x32 @ 32x128 -> 128x128
Expand Down
2 changes: 1 addition & 1 deletion test/TritonGPU/pipeline-hopper-remove-wait.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -canonicalize -tritongpu-pipeline=compute-capability=90 -canonicalize | FileCheck %s
// RUN: triton-opt %s -split-input-file -canonicalize -tritongpu-pipeline -canonicalize | FileCheck %s

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
Expand Down
7 changes: 3 additions & 4 deletions third_party/nvidia/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,14 @@ def make_ttgir(mod, metadata, opt, capability):
nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_thread_locality(pm)
passes.ttgpuir.add_accelerate_matmul(pm, capability)
passes.ttgpuir.add_accelerate_matmul(pm)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
passes.common.add_cse(pm)
if capability // 10 >= 8:
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.ttgpuir.add_pipeline(pm, opt.num_stages, opt.num_warps, opt.num_ctas, capability)
if capability // 10 <= 8:
passes.ttgpuir.add_prefetch(pm)
passes.ttgpuir.add_pipeline(pm, opt.num_stages)
passes.ttgpuir.add_prefetch(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_reduce_data_duplication(pm)
Expand Down