diff --git a/README.md b/README.md index 0d5d08dde657..0f8da14ecd07 100644 --- a/README.md +++ b/README.md @@ -232,8 +232,14 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi - `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass. - `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass. - `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma). -- `MLIR_ENABLE_DIAGNOSTICS` enables dumping the stack trace and the related IR operation of diagnostics (e.g., errors and warnings). -- `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks. +- `MLIR_ENABLE_DIAGNOSTICS=` controls diagnostic emission in MLIR. + Options are: `warnings`, `remarks`, `stacktraces`, `operations`. + Use comma-separated values to customize output. For example, + `MLIR_ENABLE_DIAGNOSTICS=remarks,operations` enables remarks and IR operations, + while `MLIR_ENABLE_DIAGNOSTICS=warnings,stacktraces` enables warnings with + stack traces. By default, only errors are shown. Setting `warnings` includes + errors and warnings; `remarks` includes errors, warnings, and remarks. +- `MLIR_ENABLE_REMARK` is deprecated. Please use `MLIR_ENABLE_DIAGNOSTICS=remarks`. - `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx/amdgcn. - `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx/amdgcn when `TRITON_KERNEL_DUMP` is set to 1. - `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx/amdgcn at the beginning of each compilation stage. 
diff --git a/python/src/ir.cc b/python/src/ir.cc index 53451b706ae1..b5411dd4281d 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -140,6 +140,42 @@ class TritonOpBuilder { bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); }; +// Run the pass manager under a source manager diagnostic handler, which +// enables emitted MLIR diagnostics to directly reference Python source +// code. This diagnostic handler supports filtering diagnostic info by +// severity levels. +struct TritonSourceMgrDiagnosticHandler : public SourceMgrDiagnosticHandler { + TritonSourceMgrDiagnosticHandler(MLIRContext *ctx, + DiagnosticSeverity minSeverity) + : SourceMgrDiagnosticHandler(sourceMgr, ctx, llvm::errs()) { + setHandler([this, minSeverity](Diagnostic &diag) { + auto severity = diag.getSeverity(); + switch (severity) { + case DiagnosticSeverity::Error: + break; + case DiagnosticSeverity::Warning: + if (minSeverity == DiagnosticSeverity::Error) + return success(); + break; + case DiagnosticSeverity::Remark: + if (minSeverity == DiagnosticSeverity::Error || + minSeverity == DiagnosticSeverity::Warning) + return success(); + break; + case DiagnosticSeverity::Note: + // notes are handled somewhere else. + return failure(); + default: + llvm_unreachable("Unknown diagnostic severity"); + } + emitDiagnostic(diag); + return success(); + }); + } + + llvm::SourceMgr sourceMgr; +}; + std::string locationToString(Location loc) { std::string str; llvm::raw_string_ostream os(str); @@ -148,6 +184,23 @@ std::string locationToString(Location loc) { return str; } +// Function to parse a comma-separated string into a vector of C-style strings +llvm::SmallVector +parseCommaSeparatedValues(const std::string &input, + llvm::SmallVector &storage) { + llvm::SmallVector split; + llvm::SmallVector result; + StringRef(input.c_str()).split(split, ','); + llvm::transform(split, std::back_inserter(result), [&storage](StringRef str) { + // StringRefs are not always null-terminated. 
+ // The purpose for this storage pattern is to + // produce a collection of C-strings that are. + storage.push_back(str.str()); + return storage.back().c_str(); + }); + return result; +} + void outputWarning(Location loc, const std::string &msg) { std::string locStr = locationToString(loc); @@ -1691,8 +1744,6 @@ void init_triton_ir(py::module &&m) { .def("enable_debug", [](PassManager &self) { auto *context = self.getContext(); - bool haveDiagnostics = - ::triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS"); bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"); std::string funcToDump; if (!haveDump) { @@ -1700,18 +1751,8 @@ void init_triton_ir(py::module &&m) { if (!funcToDump.empty()) haveDump = true; } - if (haveDiagnostics || haveDump) { - context->disableMultithreading(); - } - if (haveDiagnostics) { - context->printOpOnDiagnostic(true); - context->printStackTraceOnDiagnostic(true); - context->getDiagEngine().registerHandler([](Diagnostic &diag) { - llvm::outs() << diag << "\n"; - return success(); - }); - } if (haveDump) { + context->disableMultithreading(); auto printingFlags = OpPrintingFlags(); printingFlags.elideLargeElementsAttrs(16); printingFlags.enableDebugInfo(); @@ -1741,6 +1782,8 @@ void init_triton_ir(py::module &&m) { // TODO: maybe dump module to file and print error for better // diagnostics + auto *context = mod.getContext(); + auto reproducerPath = triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); if (!reproducerPath.empty()) { @@ -1752,7 +1795,7 @@ void init_triton_ir(py::module &&m) { makeReproducer(anchorName, passes, op, reproducerPath); // But if the pass manager crashes, attempt to generate a local // reproducer instead. 
- mod.getContext()->disableMultithreading(); + context->disableMultithreading(); self.enableCrashReproducerGeneration(reproducerPath, /*genLocalReproducer=*/true); } @@ -1763,20 +1806,9 @@ void init_triton_ir(py::module &&m) { if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY"); !debugOnly.empty()) { - llvm::SmallVector split; llvm::SmallVector storage; - llvm::SmallVector debugTypes; - - StringRef(debugOnly.c_str()).split(split, ','); - llvm::transform(split, std::back_inserter(debugTypes), - [&storage](StringRef str) { - // StringRefs are not always null-terminated. - // The purpose for this storage pattern is to - // produce a collection of C-strings that are. - storage.push_back(str.str()); - return storage.back().c_str(); - }); - + llvm::SmallVector debugTypes = + parseCommaSeparatedValues(debugOnly, storage); ::llvm::DebugFlag = true; using namespace llvm; setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); @@ -1787,25 +1819,41 @@ void init_triton_ir(py::module &&m) { self.enableTiming(); } - // Run the pass manager under a source manager diagnostic handler, which - // enables emitted MLIR diagnostics to directly reference Python source - // code. This diagnostic handler will only filter for errors. 
- struct SourceMgrErrorDiagnosticHandler - : public SourceMgrDiagnosticHandler { - SourceMgrErrorDiagnosticHandler(MLIRContext *ctx) - : SourceMgrDiagnosticHandler(sourceMgr, ctx, llvm::errs()) { - setHandler([this](Diagnostic &diag) { - if (diag.getSeverity() != DiagnosticSeverity::Error) - return failure(); - emitDiagnostic(diag); - return success(); - }); + // setting up diagnostics + bool showOperations = false, showStacktraces = false, + showRemarks = false, showWarnings = false; + + if (auto enableDiagnostics = + triton::tools::getStrEnv("MLIR_ENABLE_DIAGNOSTICS"); + !enableDiagnostics.empty()) { + llvm::SmallVector storage; + parseCommaSeparatedValues(enableDiagnostics, storage); + for (auto &str : storage) { + if (str == "warnings") { + showWarnings = true; + } else if (str == "remarks") { + showRemarks = true; + } else if (str == "stacktraces") { + showStacktraces = true; + } else if (str == "operations") { + showOperations = true; + } + // we show errors by default, so no need to set it } + } - llvm::SourceMgr sourceMgr; - }; - SourceMgrErrorDiagnosticHandler diagHandler(mod.getContext()); + DiagnosticSeverity minSeverity = showWarnings + ? DiagnosticSeverity::Warning + : DiagnosticSeverity::Error; + minSeverity = showRemarks ? 
DiagnosticSeverity::Remark : minSeverity; + TritonSourceMgrDiagnosticHandler diagHandler(context, minSeverity); + + context->printOpOnDiagnostic(showOperations); + context->printStackTraceOnDiagnostic(showStacktraces); + if (showStacktraces) { + context->disableMultithreading(); + } if (failed(self.run(mod.getOperation()))) throw std::runtime_error("PassManager::run failed"); }); diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py index bdf45b021008..86bebdd71af7 100644 --- a/python/test/unit/test_perf_warning.py +++ b/python/test/unit/test_perf_warning.py @@ -8,16 +8,12 @@ @contextmanager -def enable_remark_context(): +def enable_diagnostics_context(value): try: - os.environ["MLIR_ENABLE_REMARK"] = "1" + os.environ["MLIR_ENABLE_DIAGNOSTICS"] = value yield finally: - os.environ["MLIR_ENABLE_REMARK"] = "0" - - -def is_perf_warning_enabled(): - return os.environ.get("MLIR_ENABLE_REMARK", "0") == "1" + os.environ["MLIR_ENABLE_DIAGNOSTICS"] = "" def is_cuda(): @@ -74,29 +70,39 @@ def matmul_kernel( c = tl.dot(a, b) tl.store(c_block_ptr, c) - with enable_remark_context(): - triton.compile( - triton.compiler.ASTSource( - fn=matmul_kernel, - signature={ - "a_ptr": "*fp32", - "b_ptr": "*fp32", - "c_ptr": "*fp32", - "M": "i32", - "N": "i32", - "K": "i32", - "stride_am": "i32", - "stride_ak": "i32", - "stride_bk": "i32", - "stride_bn": "i32", - "stride_cm": "i32", - "stride_cn": "i32", - }, - constexprs={}, - )) + signature = { + "a_ptr": "*fp32", + "b_ptr": "*fp32", + "c_ptr": "*fp32", + "M": "i32", + "N": "i32", + "K": "i32", + "stride_am": "i32", + "stride_ak": "i32", + "stride_bk": "i32", + "stride_bn": "i32", + "stride_cm": "i32", + "stride_cn": "i32", + } + with enable_diagnostics_context('remarks'): + triton.compile(triton.compiler.ASTSource( + fn=matmul_kernel, + signature=signature, + constexprs={}, + )) captured = capfd.readouterr() - assert ("remark: Warning: can't use MMA V3 for the dot op" in captured.err), "expect MMA V3 
remark" + assert ("can't use MMA V3 for the dot op" in captured.err), "expect MMA V3 remark" + assert "note: see current operation:" not in captured.err + + with enable_diagnostics_context('remarks,operations,stacktraces'): + triton.compile(triton.compiler.ASTSource( + fn=matmul_kernel, + signature=signature, + constexprs={}, + )) + captured = capfd.readouterr() + assert "note: diagnostic emitted with trace:" in captured.err assert "note: see current operation:" in captured.err @@ -126,25 +132,39 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr) tl.store(out_ptr0 + (x4), tmp22, None) XBLOCK = 1024 - with enable_remark_context(): + + astsource_args = { + "fn": ldst_vec, + "signature": { + "in_ptr0": "*i64", + "in_ptr1": "*i64", + "in_ptr2": "*fp16", + "in_ptr3": "*fp32", + "out_ptr0": "*fp16", + "XBLOCK": "constexpr", + }, + "constexprs": {"XBLOCK": XBLOCK}, + } + + with enable_diagnostics_context('remarks'): triton.compile( - triton.compiler.ASTSource( - fn=ldst_vec, - signature={ - "in_ptr0": "*i64", - "in_ptr1": "*i64", - "in_ptr2": "*fp16", - "in_ptr3": "*fp32", - "out_ptr0": "*fp16", - "XBLOCK": "constexpr", - }, - constexprs={"XBLOCK": XBLOCK}, - ), + triton.compiler.ASTSource(**astsource_args), options={"num_warps": 1}, ) _, err = capfd.readouterr() assert ("remark: Warning: vectorization fails" in err), "expect vectorization failure remark" + assert "note: see current operation:" not in err + + with enable_diagnostics_context('remarks,operations,stacktraces'): + triton.compile( + triton.compiler.ASTSource(**astsource_args), + options={"num_warps": 1}, + ) + + _, err = capfd.readouterr() + assert "note: see current operation:" in err + assert "note: diagnostic emitted with trace:" in err def test_remark_swp_op_before_operands(capfd, fresh_triton_cache): diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 1fcd7dc5b3a9..7563b7515b29 100644 --- 
a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -238,12 +238,6 @@ def make_ttgir(mod, metadata, opt, capability): cluster_info.clusterDimX = opt.cluster_dims[0] cluster_info.clusterDimY = opt.cluster_dims[1] cluster_info.clusterDimZ = opt.cluster_dims[2] - # Set up Diagnostic - if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1": - srcMgr = llvm.source_mgr() - _ = ir.source_mgr_diag(srcMgr, mod.context) - mod.context.printOpOnDiagnostic(True) - # TTIR -> TTGIR pm = ir.pass_manager(mod.context) pm.enable_debug() passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas) @@ -299,11 +293,7 @@ def make_llir(self, src, metadata, options, capability): # TritonGPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() - # Set up Diagnostic - if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1": - srcMgr = llvm.source_mgr() - _ = ir.source_mgr_diag(srcMgr, mod.context) - mod.context.printOpOnDiagnostic(True) + nvidia.passes.ttnvgpuir.add_lower_mma(pm) passes.ttgpuir.add_combine_tensor_select_and_if(pm) passes.convert.add_scf_to_cf(pm)