Skip to content

Commit

Permalink
Only print out diagnostic messages if an environment variable is set
Browse files Browse the repository at this point in the history
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.
  • Loading branch information
gflegar committed Feb 19, 2024
1 parent e16a58f commit b4e2d1f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 20 deletions.
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ inline const std::set<std::string> ENV_VARS = {
"MLIR_ENABLE_DUMP",
"TRITON_DISABLE_LINE_INFO",
"TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE",
"MLIR_ENABLE_DIAGNOSTICS",
};

namespace tools {
Expand Down
47 changes: 27 additions & 20 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1423,26 +1423,33 @@ void init_triton_ir(py::module &&m) {
.def("enable_debug",
[](PassManager &self) {
auto *context = self.getContext();
context->printOpOnDiagnostic(true);
context->printStackTraceOnDiagnostic(true);
context->disableMultithreading();
context->getDiagEngine().registerHandler([](Diagnostic &diag) {
llvm::outs() << diag << "\n";
return success();
});

if (!triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"))
return;
auto printingFlags = OpPrintingFlags();
printingFlags.elideLargeElementsAttrs(16);
printingFlags.enableDebugInfo();
auto print_always = [](Pass *, Operation *) { return true; };
self.enableIRPrinting(
/*shouldPrintBeforePass=*/print_always,
/*shouldPrintAfterPass=*/print_always,
/*printModuleScope=*/true,
/*printAfterOnlyOnChange=*/false,
/*printAfterOnlyOnFailure*/ true, llvm::dbgs(), printingFlags);
bool have_diagnostics =
::triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS");
bool have_dump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP");
if (have_diagnostics || have_dump) {
context->disableMultithreading();
}
if (have_diagnostics) {
context->printOpOnDiagnostic(true);
context->printStackTraceOnDiagnostic(true);
context->getDiagEngine().registerHandler([](Diagnostic &diag) {
llvm::outs() << diag << "\n";
return success();
});
}
if (have_dump) {
auto printingFlags = OpPrintingFlags();
printingFlags.elideLargeElementsAttrs(16);
printingFlags.enableDebugInfo();
auto print_always = [](Pass *, Operation *) { return true; };
self.enableIRPrinting(
/*shouldPrintBeforePass=*/print_always,
/*shouldPrintAfterPass=*/print_always,
/*printModuleScope=*/true,
/*printAfterOnlyOnChange=*/false,
/*printAfterOnlyOnFailure*/ true, llvm::dbgs(),
printingFlags);
}
})
.def("run", [](PassManager &self, ModuleOp &mod) {
// TODO: maybe dump module to file and print error for better
Expand Down

0 comments on commit b4e2d1f

Please sign in to comment.