Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,14 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
- `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass.
- `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass.
- `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma).
- `MLIR_ENABLE_DIAGNOSTICS` enables dumping the stack trace and the related IR operation of diagnostics (e.g., errors and warnings).
- `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks.
- `MLIR_ENABLE_DIAGNOSTICS=<comma-separated>` controls diagnostic emission in MLIR.
Options are: `warnings`, `remarks`, `stacktraces`, `operations`.
Use comma-separated values to customize output. For example,
`MLIR_ENABLE_DIAGNOSTICS=remarks,operations` enables remarks and IR operations,
while `MLIR_ENABLE_DIAGNOSTICS=warnings,stacktraces` enables warnings with
stacktraces. By default, only errors are shown. Setting `warnings` includes
errors and warnings; `remarks` includes errors, warnings, and remarks.
- `MLIR_ENABLE_REMARK` is deprecated. Please use `MLIR_ENABLE_DIAGNOSTICS=remarks`.
- `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx/amdgcn.
- `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx/amdgcn when `TRITON_KERNEL_DUMP` is set to 1.
- `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx/amdgcn at the beginning of each compilation stage.
Expand Down
134 changes: 91 additions & 43 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,42 @@ class TritonOpBuilder {
bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO");
};

// Source-manager diagnostic handler that filters MLIR diagnostics by severity
// before emitting them to llvm::errs(). Running the pass manager while an
// instance of this handler is alive lets emitted MLIR diagnostics directly
// reference Python source code.
//
// Filtering rules, given the `minSeverity` passed to the constructor:
//   - Errors are always emitted.
//   - Warnings are dropped when `minSeverity` is Error.
//   - Remarks are dropped when `minSeverity` is Error or Warning.
//   - Notes are never emitted here; the handler reports failure for them.
// Dropped warnings/remarks are reported as handled (success).
struct TritonSourceMgrDiagnosticHandler : public SourceMgrDiagnosticHandler {
  TritonSourceMgrDiagnosticHandler(MLIRContext *ctx,
                                   DiagnosticSeverity minSeverity)
      : SourceMgrDiagnosticHandler(sourceMgr, ctx, llvm::errs()) {
    setHandler([this, minSeverity](Diagnostic &diag) {
      const DiagnosticSeverity level = diag.getSeverity();
      const bool errorsOnly = minSeverity == DiagnosticSeverity::Error;
      const bool noRemarks =
          errorsOnly || minSeverity == DiagnosticSeverity::Warning;

      // Notes are handled somewhere else (per the original author's comment),
      // so report them as unhandled.
      if (level == DiagnosticSeverity::Note)
        return failure();

      if (level == DiagnosticSeverity::Warning) {
        // Swallow the warning when only errors were requested.
        if (errorsOnly)
          return success();
      } else if (level == DiagnosticSeverity::Remark) {
        // Swallow the remark unless remarks were requested.
        if (noRemarks)
          return success();
      } else if (level != DiagnosticSeverity::Error) {
        llvm_unreachable("Unknown diagnostic severity");
      }

      emitDiagnostic(diag);
      return success();
    });
  }

  // Owned by the handler so it lives as long as the handler does. Note that
  // the base class receives a reference to this member before the member
  // itself is constructed (bases are initialized before members).
  llvm::SourceMgr sourceMgr;
};

std::string locationToString(Location loc) {
std::string str;
llvm::raw_string_ostream os(str);
Expand All @@ -148,6 +184,23 @@ std::string locationToString(Location loc) {
return str;
}

// Splits `input` on ',' and returns one null-terminated C string per field.
//
// StringRefs are not always null-terminated, so every field is copied into
// `storage` as an owning std::string and the returned pointers refer to those
// copies. The returned pointers are therefore only valid while `storage` is
// alive and not modified. Fields are appended after whatever `storage`
// already contains; only the newly appended fields are returned.
llvm::SmallVector<const char *, 3>
parseCommaSeparatedValues(const std::string &input,
                          llvm::SmallVector<std::string, 3> &storage) {
  llvm::SmallVector<StringRef, 3> split;
  StringRef(input.c_str()).split(split, ',');

  // Copy every field into `storage` BEFORE taking any pointer. Taking
  // c_str() while still appending is unsafe: once `storage` grows past its
  // current capacity it reallocates and moves its strings, which invalidates
  // pointers into strings that keep their characters inline (short strings).
  const size_t firstNew = storage.size();
  storage.reserve(firstNew + split.size());
  for (StringRef field : split)
    storage.push_back(field.str());

  // `storage` is stable from here on, so the C-string pointers stay valid.
  llvm::SmallVector<const char *, 3> result;
  result.reserve(split.size());
  for (size_t i = firstNew, e = storage.size(); i != e; ++i)
    result.push_back(storage[i].c_str());
  return result;
}

void outputWarning(Location loc, const std::string &msg) {
std::string locStr = locationToString(loc);

Expand Down Expand Up @@ -1691,27 +1744,15 @@ void init_triton_ir(py::module &&m) {
.def("enable_debug",
[](PassManager &self) {
auto *context = self.getContext();
bool haveDiagnostics =
::triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS");
bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP");
std::string funcToDump;
if (!haveDump) {
funcToDump = triton::tools::getStrEnv("MLIR_ENABLE_DUMP");
if (!funcToDump.empty())
haveDump = true;
}
if (haveDiagnostics || haveDump) {
context->disableMultithreading();
}
if (haveDiagnostics) {
context->printOpOnDiagnostic(true);
context->printStackTraceOnDiagnostic(true);
context->getDiagEngine().registerHandler([](Diagnostic &diag) {
llvm::outs() << diag << "\n";
return success();
});
}
if (haveDump) {
context->disableMultithreading();
auto printingFlags = OpPrintingFlags();
printingFlags.elideLargeElementsAttrs(16);
printingFlags.enableDebugInfo();
Expand Down Expand Up @@ -1741,6 +1782,8 @@ void init_triton_ir(py::module &&m) {
// TODO: maybe dump module to file and print error for better
// diagnostics

auto *context = mod.getContext();

auto reproducerPath =
triton::tools::getStrEnv("TRITON_REPRODUCER_PATH");
if (!reproducerPath.empty()) {
Expand All @@ -1752,7 +1795,7 @@ void init_triton_ir(py::module &&m) {
makeReproducer(anchorName, passes, op, reproducerPath);
// But if the pass manager crashes, attempt to generate a local
// reproducer instead.
mod.getContext()->disableMultithreading();
context->disableMultithreading();
self.enableCrashReproducerGeneration(reproducerPath,
/*genLocalReproducer=*/true);
}
Expand All @@ -1763,20 +1806,9 @@ void init_triton_ir(py::module &&m) {

if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY");
!debugOnly.empty()) {
llvm::SmallVector<StringRef, 3> split;
llvm::SmallVector<std::string, 3> storage;
llvm::SmallVector<const char *, 3> debugTypes;

StringRef(debugOnly.c_str()).split(split, ',');
llvm::transform(split, std::back_inserter(debugTypes),
[&storage](StringRef str) {
// StringRefs are not always null-terminated.
// The purpose for this storage pattern is to
// produce a collection of C-strings that are.
storage.push_back(str.str());
return storage.back().c_str();
});

llvm::SmallVector<const char *, 3> debugTypes =
parseCommaSeparatedValues(debugOnly, storage);
::llvm::DebugFlag = true;
using namespace llvm;
setCurrentDebugTypes(debugTypes.data(), debugTypes.size());
Expand All @@ -1787,25 +1819,41 @@ void init_triton_ir(py::module &&m) {
self.enableTiming();
}

// Run the pass manager under a source manager diagnostic handler, which
// enables emitted MLIR diagnostics to directly reference Python source
// code. This diagnostic handler will only filter for errors.
struct SourceMgrErrorDiagnosticHandler
: public SourceMgrDiagnosticHandler {
SourceMgrErrorDiagnosticHandler(MLIRContext *ctx)
: SourceMgrDiagnosticHandler(sourceMgr, ctx, llvm::errs()) {
setHandler([this](Diagnostic &diag) {
if (diag.getSeverity() != DiagnosticSeverity::Error)
return failure();
emitDiagnostic(diag);
return success();
});
// setting up diagnostics
bool showOperations = false, showStacktraces = false,
showRemarks = false, showWarnings = false;

if (auto enableDiagnostics =
triton::tools::getStrEnv("MLIR_ENABLE_DIAGNOSTICS");
!enableDiagnostics.empty()) {
llvm::SmallVector<std::string, 3> storage;
parseCommaSeparatedValues(enableDiagnostics, storage);
for (auto &str : storage) {
if (str == "warnings") {
showWarnings = true;
} else if (str == "remarks") {
showRemarks = true;
} else if (str == "stacktraces") {
showStacktraces = true;
} else if (str == "operations") {
showOperations = true;
}
// we show errors by default, so no need to set it
}
}
Comment thread
sfzhu93 marked this conversation as resolved.

llvm::SourceMgr sourceMgr;
};
SourceMgrErrorDiagnosticHandler diagHandler(mod.getContext());
DiagnosticSeverity minSeverity = showWarnings
? DiagnosticSeverity::Warning
: DiagnosticSeverity::Error;
minSeverity = showRemarks ? DiagnosticSeverity::Remark : minSeverity;

TritonSourceMgrDiagnosticHandler diagHandler(context, minSeverity);

context->printOpOnDiagnostic(showOperations);
context->printStackTraceOnDiagnostic(showStacktraces);
if (showStacktraces) {
context->disableMultithreading();
}
if (failed(self.run(mod.getOperation())))
throw std::runtime_error("PassManager::run failed");
});
Expand Down
102 changes: 61 additions & 41 deletions python/test/unit/test_perf_warning.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,12 @@


@contextmanager
def enable_remark_context():
def enable_diagnostics_context(value):
try:
os.environ["MLIR_ENABLE_REMARK"] = "1"
os.environ["MLIR_ENABLE_DIAGNOSTICS"] = value
yield
finally:
os.environ["MLIR_ENABLE_REMARK"] = "0"


def is_perf_warning_enabled():
return os.environ.get("MLIR_ENABLE_REMARK", "0") == "1"
os.environ["MLIR_ENABLE_DIAGNOSTICS"] = ""


def is_cuda():
Expand Down Expand Up @@ -74,29 +70,39 @@ def matmul_kernel(
c = tl.dot(a, b)
tl.store(c_block_ptr, c)

with enable_remark_context():
triton.compile(
triton.compiler.ASTSource(
fn=matmul_kernel,
signature={
"a_ptr": "*fp32",
"b_ptr": "*fp32",
"c_ptr": "*fp32",
"M": "i32",
"N": "i32",
"K": "i32",
"stride_am": "i32",
"stride_ak": "i32",
"stride_bk": "i32",
"stride_bn": "i32",
"stride_cm": "i32",
"stride_cn": "i32",
},
constexprs={},
))
signature = {
"a_ptr": "*fp32",
"b_ptr": "*fp32",
"c_ptr": "*fp32",
"M": "i32",
"N": "i32",
"K": "i32",
"stride_am": "i32",
"stride_ak": "i32",
"stride_bk": "i32",
"stride_bn": "i32",
"stride_cm": "i32",
"stride_cn": "i32",
}
with enable_diagnostics_context('remarks'):
triton.compile(triton.compiler.ASTSource(
fn=matmul_kernel,
signature=signature,
constexprs={},
))
captured = capfd.readouterr()

assert ("remark: Warning: can't use MMA V3 for the dot op" in captured.err), "expect MMA V3 remark"
assert ("can't use MMA V3 for the dot op" in captured.err), "expect MMA V3 remark"
assert "note: see current operation:" not in captured.err

with enable_diagnostics_context('remarks,operations,stacktraces'):
triton.compile(triton.compiler.ASTSource(
fn=matmul_kernel,
signature=signature,
constexprs={},
))
captured = capfd.readouterr()
assert "note: diagnostic emitted with trace:" in captured.err
assert "note: see current operation:" in captured.err


Expand Down Expand Up @@ -126,25 +132,39 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr)
tl.store(out_ptr0 + (x4), tmp22, None)

XBLOCK = 1024
with enable_remark_context():

astsource_args = {
"fn": ldst_vec,
"signature": {
"in_ptr0": "*i64",
"in_ptr1": "*i64",
"in_ptr2": "*fp16",
"in_ptr3": "*fp32",
"out_ptr0": "*fp16",
"XBLOCK": "constexpr",
},
"constexprs": {"XBLOCK": XBLOCK},
}

with enable_diagnostics_context('remarks'):
triton.compile(
triton.compiler.ASTSource(
fn=ldst_vec,
signature={
"in_ptr0": "*i64",
"in_ptr1": "*i64",
"in_ptr2": "*fp16",
"in_ptr3": "*fp32",
"out_ptr0": "*fp16",
"XBLOCK": "constexpr",
},
constexprs={"XBLOCK": XBLOCK},
),
triton.compiler.ASTSource(**astsource_args),
options={"num_warps": 1},
)

_, err = capfd.readouterr()
assert ("remark: Warning: vectorization fails" in err), "expect vectorization failure remark"
assert "note: see current operation:" not in err

with enable_diagnostics_context('remarks,operations,stacktraces'):
triton.compile(
triton.compiler.ASTSource(**astsource_args),
options={"num_warps": 1},
)

_, err = capfd.readouterr()
assert "note: see current operation:" in err
assert "note: diagnostic emitted with trace:" in err


def test_remark_swp_op_before_operands(capfd, fresh_triton_cache):
Expand Down
12 changes: 1 addition & 11 deletions third_party/nvidia/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,6 @@ def make_ttgir(mod, metadata, opt, capability):
cluster_info.clusterDimX = opt.cluster_dims[0]
cluster_info.clusterDimY = opt.cluster_dims[1]
cluster_info.clusterDimZ = opt.cluster_dims[2]
# Set up Diagnostic
if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1":
srcMgr = llvm.source_mgr()
_ = ir.source_mgr_diag(srcMgr, mod.context)
mod.context.printOpOnDiagnostic(True)
# TTIR -> TTGIR
pm = ir.pass_manager(mod.context)
pm.enable_debug()
passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
Expand Down Expand Up @@ -299,11 +293,7 @@ def make_llir(self, src, metadata, options, capability):
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
# Set up Diagnostic
if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1":
srcMgr = llvm.source_mgr()
_ = ir.source_mgr_diag(srcMgr, mod.context)
mod.context.printOpOnDiagnostic(True)

nvidia.passes.ttnvgpuir.add_lower_mma(pm)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.convert.add_scf_to_cf(pm)
Expand Down