diff --git a/README.md b/README.md
index f916461d8c79..8258f7bedd94 100644
--- a/README.md
+++ b/README.md
@@ -200,6 +200,7 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
 - `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass.
 - `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass.
 - `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma).
+- `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks.
 
 # Changelog
 
diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index ab7e96945d7c..39c043695bc6 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -33,6 +33,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
   for (int baseVersion : versionsSupported) {
     if (supportMMA(op, baseVersion))
       return baseVersion;
+    if (baseVersion == 3)
+      op.emitRemark() << "Warning: can't use MMA V3 for the dot op";
   }
   return 0;
 }
diff --git a/python/src/ir.cc b/python/src/ir.cc
index 129daccd1bba..70d52f36058d 100644
--- a/python/src/ir.cc
+++ b/python/src/ir.cc
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Parser/Parser.h"
@@ -27,6 +28,7 @@
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Tools/Sys/GetEnv.hpp"
+#include "llvm/Support/SourceMgr.h"
 
 namespace {
 
@@ -201,7 +203,16 @@ void init_triton_ir(py::module &&m) {
       .value("IEEE", InputPrecision::IEEE)
       .export_values();
 
-  py::class_<MLIRContext>(m, "context", py::module_local()).def(py::init<>());
+  py::class_<MLIRContext>(m, "context", py::module_local())
+      .def(py::init<>())
+      .def("printOpOnDiagnostic",
+           [](MLIRContext &self, bool v) { self.printOpOnDiagnostic(v); })
+      .def("printStackTraceOnDiagnostic", [](MLIRContext &self, bool v) {
+        self.printStackTraceOnDiagnostic(v);
+      });
+  py::class_<mlir::SourceMgrDiagnosticHandler>(m, "source_mgr_diag",
+                                               py::module_local())
+      .def(py::init<llvm::SourceMgr &, MLIRContext *>());
 
   m.def("load_dialects", [](MLIRContext &context) {
     DialectRegistry registry;
diff --git a/python/src/llvm.cc b/python/src/llvm.cc
index f4c023f232e7..ef9ff80df6c7 100644
--- a/python/src/llvm.cc
+++ b/python/src/llvm.cc
@@ -18,6 +18,7 @@
 #include "llvm/Passes/StandardInstrumentations.h"
 #include "llvm/Support/CodeGen.h"
 #include "llvm/Support/Signals.h"
+#include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/TargetSelect.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Transforms/IPO/AlwaysInliner.h"
@@ -150,6 +151,8 @@ void init_triton_llvm(py::module &&m) {
 
   py::class_<llvm::LLVMContext>(m, "context", py::module_local())
       .def(py::init<>());
+  py::class_<llvm::SourceMgr>(m, "source_mgr", py::module_local())
+      .def(py::init<>());
 
   py::class_<llvm::Module::FunctionListType>(m, "function_list")
       .def(
diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py
new file mode 100644
index 000000000000..871bc6ba294b
--- /dev/null
+++ b/python/test/unit/test_perf_warning.py
@@ -0,0 +1,47 @@
+import triton
+import triton.language as tl
+import os
+import pytest
+import torch
+
+
+def is_perf_warning_enabled():
+    return os.environ.get('MLIR_ENABLE_REMARK', '0') == '1'
+
+
+def is_cuda():
+    return triton.runtime.driver.active.get_current_target().backend == "cuda"
+
+
+def test_mma_remark(capfd):
+    if is_cuda():
+        capability = torch.cuda.get_device_capability()
+        if capability[0] < 9:
+            pytest.skip("Requires sm >= 90 to run")
+
+    os.environ['MLIR_ENABLE_REMARK'] = '1'
+
+    @triton.jit
+    def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn):
+        a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0),
+                                        block_shape=(32, 128), order=(1, 0))
+        b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0),
+                                        block_shape=(128, 32), order=(0, 1))
+        c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0),
+                                        block_shape=(32, 32), order=(1, 0))
+        a = tl.load(a_block_ptr)
+        b = tl.load(b_block_ptr)
+        c = tl.dot(a, b)
+        tl.store(c_block_ptr, c)
+
+    triton.compile(
+        triton.compiler.ASTSource(
+            fn=matmul_kernel, signature={
+                0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32', 5: 'i32', 6: 'i32', 7: 'i32', 8: 'i32', 9:
+                'i32', 10: 'i32', 11: 'i32'
+            }, constants={}))
+    captured = capfd.readouterr()
+
+    assert "remark: Warning: can't use MMA V3 for the dot op" in captured.err, "expect MMA V3 remark"
+    assert "note: see current operation:" in captured.err
+    os.environ['MLIR_ENABLE_REMARK'] = '0'
diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index 6ee7f2281e04..a4de87815647 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -158,6 +158,11 @@ def make_ttgir(mod, metadata, opt, capability):
         cluster_info.clusterDimX = opt.cluster_dims[0]
         cluster_info.clusterDimY = opt.cluster_dims[1]
         cluster_info.clusterDimZ = opt.cluster_dims[2]
+        # Set up Diagnostic
+        if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1":
+            srcMgr = llvm.source_mgr()
+            diag = ir.source_mgr_diag(srcMgr, mod.context)
+            mod.context.printOpOnDiagnostic(True)
         # TTIR -> TTGIR
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()