diff --git a/include/triton/Tools/DiagEmitter.hpp b/include/triton/Tools/DiagEmitter.hpp
new file mode 100644
index 000000000000..d05d325d7cb7
--- /dev/null
+++ b/include/triton/Tools/DiagEmitter.hpp
@@ -0,0 +1,63 @@
+#ifndef TRITON_TOOLS_DIAG_EMITTER_HPP
+#define TRITON_TOOLS_DIAG_EMITTER_HPP
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/OpDefinition.h"
+#include "triton/Tools/Sys/GetEnv.hpp"
+#include <optional>
+
+// Streams `message` into a remark attached to `op` when perf warnings are
+// enabled; otherwise `message` is not evaluated. `message` may be a chain of
+// `<<` operands, so it is intentionally not parenthesized. The do/while(0)
+// wrapper makes the expansion a single statement that is safe inside an
+// unbraced if/else; every use must be followed by a semicolon.
+#define EMIT_PERF_WARNING(op, message)                                         \
+  do {                                                                         \
+    if (auto out =                                                             \
+            mlir::triton::DiagnosticEmitter::getPerfWarningStream(op)) {       \
+      *out << message;                                                         \
+    }                                                                          \
+  } while (0)
+
+namespace mlir::triton {
+
+// Process-wide switch that decides whether EMIT_PERF_WARNING emits remarks.
+class DiagnosticEmitter {
+private:
+  bool shouldEmitPerfWarning;
+
+  // Perf warnings start disabled unless getBoolEnv("MLIR_ENABLE_REMARK") holds.
+  DiagnosticEmitter() : shouldEmitPerfWarning(false) {
+    if (tools::getBoolEnv("MLIR_ENABLE_REMARK")) {
+      shouldEmitPerfWarning = true;
+    }
+  }
+  ~DiagnosticEmitter() = default;
+
+public:
+  DiagnosticEmitter(const DiagnosticEmitter &) = delete;
+  DiagnosticEmitter &operator=(const DiagnosticEmitter &) = delete;
+
+  // Returns the single instance. A function-local static is initialized
+  // exactly once, even under concurrent first calls, and needs no raw `new`.
+  static DiagnosticEmitter *getInstance() {
+    static DiagnosticEmitter instance;
+    return &instance;
+  }
+
+  // Overrides the setting that was derived from the environment.
+  static void setShouldEmitPerfWarning(bool shouldEmit) {
+    getInstance()->shouldEmitPerfWarning = shouldEmit;
+  }
+
+  // Returns a remark diagnostic for `op` when perf warnings are enabled and
+  // std::nullopt otherwise.
+  static std::optional<InFlightDiagnostic>
+  getPerfWarningStream(const OpState &op) {
+    if (!getInstance()->shouldEmitPerfWarning) {
+      return std::nullopt;
+    }
+    return op->emitRemark();
+  }
+};
+} // namespace mlir::triton
+#endif
diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp
index e5132b6d36e5..8e00feec446d 100644
--- a/include/triton/Tools/Sys/GetEnv.hpp
+++ b/include/triton/Tools/Sys/GetEnv.hpp
@@ -23,6 +23,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = {
     "LLVM_PASS_PLUGIN_PATH",
     "MLIR_ENABLE_DIAGNOSTICS",
     "MLIR_ENABLE_DUMP",
+    "MLIR_ENABLE_REMARK",
     "MLIR_ENABLE_TIMING",
     "TRITON_DEFAULT_FP_FUSION",
"TRITON_DISABLE_LINE_INFO", diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 4dccc85da34f..187bd8bc0c85 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -12,6 +12,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/DiagEmitter.hpp" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -37,8 +38,9 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) { for (int baseVersion : versionsSupported) { if (supportMMA(op, baseVersion)) return baseVersion; - if (baseVersion == 3) - op.emitRemark() << "Warning: can't use MMA V3 for the dot op"; + if (baseVersion == 3) { + EMIT_PERF_WARNING(op, "MMA V3 is not supported for this dot op. "); + } } return 0; } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index a439b89270a9..2dd23095331e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1,4 +1,5 @@ #include "TargetInfo.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" @@ -9,6 +10,7 @@ #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/DiagEmitter.hpp" using namespace mlir; using namespace mlir::triton; @@ -204,10 +206,10 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, if (vec == 1 && numElems > 1) { int maskValue = !llMask ? 
-1 : getMaskAlignment(mask); - op->emitRemark() << "Warning: vectorization fails vec = " << vec - << " origin vec = " << vecOrig - << " numElems = " << numElems << " mask is " << maskValue - << "\n"; + EMIT_PERF_WARNING(op, "Warning: vectorization fails vec = " + << vec << " origin vec = " << vecOrig + << " numElems = " << numElems << " mask is " + << maskValue << "\n"); } // Get the LLVM values for pointers auto ptrElems = unpackLLElements(loc, llPtr, rewriter); @@ -431,10 +433,10 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, if (vec == 1 && elemsPerThread > 1) { int mask = !llMask ? -1 : getMaskAlignment(op.getMask()); - op->emitRemark() << "Warning: vectorization fails vec = " << vec - << " origin vec = " << vecOrig - << " elemsPerThread = " << elemsPerThread << " mask is " - << mask << "\n"; + EMIT_PERF_WARNING(op, "Warning: vectorization fails vec = " + << vec << " origin vec = " << vecOrig + << " elemsPerThread = " << elemsPerThread + << " mask is " << mask << "\n"); } Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); @@ -573,9 +575,10 @@ struct AtomicCASOpConversion } if (vec == 1 && elemsPerThread > 1) - op->emitRemark() << "Warning: vectorization fails vec = " << vec - << " origin vec = " << vecOrig - << " elemsPerThread = " << elemsPerThread << "\n"; + EMIT_PERF_WARNING(op, "Warning: vectorization fails vec = " + << vec << " origin vec = " << vecOrig + << " elemsPerThread = " << elemsPerThread + << "\n"); Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); auto vecTy = vec_ty(valueElemTy, vec); @@ -732,9 +735,10 @@ struct AtomicRMWOpConversion assert((packed == 1 || vec == 1) && "packed or vec must be 1"); if (vec * packed == 1 && numElems > 1) - op->emitRemark() << "Warning: vectorization fails vec = " << vec - << " packed = " << packed << " origin vec = " << vecOrig - << " numElems = " << numElems; + EMIT_PERF_WARNING(op, "Warning: vectorization fails vec = " + << vec << " packed = " << packed + << " 
origin vec = " << vecOrig + << " numElems = " << numElems); Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);