Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
- `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass.
- `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass.
- `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma).
- `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks.

# Changelog

Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
for (int baseVersion : versionsSupported) {
if (supportMMA(op, baseVersion))
return baseVersion;
if (baseVersion == 3)
op.emitRemark() << "Warning: can't use MMA V3 for the dot op";
}
return 0;
}
Expand Down
13 changes: 12 additions & 1 deletion python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
Expand All @@ -27,6 +28,7 @@
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/Support/SourceMgr.h"

namespace {

Expand Down Expand Up @@ -201,7 +203,16 @@ void init_triton_ir(py::module &&m) {
.value("IEEE", InputPrecision::IEEE)
.export_values();

py::class_<MLIRContext>(m, "context", py::module_local()).def(py::init<>());
py::class_<MLIRContext>(m, "context", py::module_local())
.def(py::init<>())
.def("printOpOnDiagnostic",
[](MLIRContext &self, bool v) { self.printOpOnDiagnostic(v); })
.def("printStackTraceOnDiagnostic", [](MLIRContext &self, bool v) {
self.printStackTraceOnDiagnostic(v);
});
py::class_<SourceMgrDiagnosticHandler>(m, "source_mgr_diag",
py::module_local())
.def(py::init<llvm::SourceMgr &, MLIRContext *>());

m.def("load_dialects", [](MLIRContext &context) {
DialectRegistry registry;
Expand Down
3 changes: 3 additions & 0 deletions python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/Passes/StandardInstrumentations.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/Signals.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"
Expand Down Expand Up @@ -150,6 +151,8 @@ void init_triton_llvm(py::module &&m) {

py::class_<llvm::LLVMContext>(m, "context", py::module_local())
.def(py::init<>());
py::class_<llvm::SourceMgr>(m, "source_mgr", py::module_local())
.def(py::init<>());

py::class_<llvm::Module::FunctionListType>(m, "function_list")
.def(
Expand Down
47 changes: 47 additions & 0 deletions python/test/unit/test_perf_warning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import triton
import triton.language as tl
import os
import pytest
import torch


def is_perf_warning_enabled():
return os.environ.get('MLIR_ENABLE_REMARK', '0') == '1'


def is_cuda():
return triton.runtime.driver.active.get_current_target().backend == "cuda"


def test_mma_remark(capfd):
if is_cuda():
capability = torch.cuda.get_device_capability()
if capability[0] < 9:
pytest.skip("Requires sm >= 90 to run")

os.environ['MLIR_ENABLE_REMARK'] = '1'

@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn):
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0),
block_shape=(32, 128), order=(1, 0))
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0),
block_shape=(128, 32), order=(0, 1))
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0),
block_shape=(32, 32), order=(1, 0))
a = tl.load(a_block_ptr)
b = tl.load(b_block_ptr)
c = tl.dot(a, b)
tl.store(c_block_ptr, c)

triton.compile(
triton.compiler.ASTSource(
fn=matmul_kernel, signature={
0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32', 5: 'i32', 6: 'i32', 7: 'i32', 8: 'i32', 9:
'i32', 10: 'i32', 11: 'i32'
}, constants={}))
captured = capfd.readouterr()

assert "remark: Warning: can't use MMA V3 for the dot op" in captured.err, "expect MMA V3 remark"
assert "note: see current operation:" in captured.err
os.environ['MLIR_ENABLE_REMARK'] = '0'
5 changes: 5 additions & 0 deletions third_party/nvidia/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ def make_ttgir(mod, metadata, opt, capability):
cluster_info.clusterDimX = opt.cluster_dims[0]
cluster_info.clusterDimY = opt.cluster_dims[1]
cluster_info.clusterDimZ = opt.cluster_dims[2]
# Set up Diagnostic
if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1":
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you document in README?

srcMgr = llvm.source_mgr()
diag = ir.source_mgr_diag(srcMgr, mod.context)
mod.context.printOpOnDiagnostic(True)
# TTIR -> TTGIR
pm = ir.pass_manager(mod.context)
pm.enable_debug()
Expand Down