Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions include/triton/Tools/DiagEmitter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#ifndef TRITON_TOOLS_DIAG_EMITTER_HPP
#define TRITON_TOOLS_DIAG_EMITTER_HPP
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include <optional>

// Emits `message` as an MLIR remark attached to `op`, but only when
// DiagnosticEmitter::getPerfWarningStream(op) yields a stream; otherwise the
// `message` expression is never evaluated.
//
// `message` is pasted textually after `<<`, so callers may pass a whole
// streaming chain (e.g. "vec = " << vec << "\n"). It is deliberately not
// parenthesized for that reason.
//
// The body is wrapped in do { ... } while (0) so the macro behaves as one
// statement: it needs a trailing semicolon at the call site and cannot capture
// a following `else` when used under an unbraced `if`. The local variable has
// a long, specific name so it does not shadow identifiers used in `message`.
#define EMIT_PERF_WARNING(op, message) \
  do { \
    if (auto diagEmitterPerfWarningStream = \
            mlir::triton::DiagnosticEmitter::getPerfWarningStream(op)) { \
      *diagEmitterPerfWarningStream << message; \
    } \
  } while (0)
Comment on lines +8 to +11
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does it need to be a macro?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was trying to make an API that supports `emitPerfWarn(op) << "foo" << "bar"`, but found it difficult to make it print nothing when the switch is off. Creating an empty InFlightDiagnostic triggers an assertion in MLIR when it attempts to attach notes.


namespace mlir::triton {

// Holds a single on/off switch that decides whether performance remarks are
// emitted. All access goes through the static member functions below.
class DiagnosticEmitter {
// singleton pattern
private:
// Pointer to the one emitter object; null until getInstance() creates it.
inline static DiagnosticEmitter *instance{nullptr};
// True when perf warnings should be emitted. Starts out false and is changed
// by getInstance() (from the MLIR_ENABLE_REMARK environment variable) or by
// setShouldEmitPerfWarning().
// NOTE(review): this is a plain bool with no synchronization around it.
bool shouldEmitPerfWarning;
// Construction and destruction are private so that only the class's own
// member functions can create the object.
DiagnosticEmitter() : shouldEmitPerfWarning(false){};
~DiagnosticEmitter() = default;

public:
// Copying is disabled so a second emitter cannot be made from the first.
DiagnosticEmitter(const DiagnosticEmitter &) = delete;
DiagnosticEmitter &operator=(const DiagnosticEmitter &) = delete;

// Returns the process-wide emitter, creating it on the first call.
//
// Creation goes through a function-local static, whose initialization the
// language guarantees to run exactly once even when several threads make the
// first call concurrently. The previous unsynchronized "check pointer, then
// new" sequence could race and construct the object twice.
//
// On creation the emitter is switched on if the MLIR_ENABLE_REMARK
// environment variable is set to a true value (as read by tools::getBoolEnv).
// The object is intentionally never deleted; it lives until process exit.
static DiagnosticEmitter *getInstance() {
  static DiagnosticEmitter *const created = [] {
    auto *emitter = new DiagnosticEmitter();
    if (tools::getBoolEnv("MLIR_ENABLE_REMARK")) {
      emitter->shouldEmitPerfWarning = true;
    }
    // Keep the class-level pointer in sync; it is written only here, inside
    // the once-only initializer.
    instance = emitter;
    return emitter;
  }();
  return created;
}

static void setShouldEmitPerfWarning(bool shouldEmit) {
DiagnosticEmitter::getInstance()->shouldEmitPerfWarning = shouldEmit;
}

// Returns a remark diagnostic attached to `op` when perf warnings are enabled
// on the shared emitter, and std::nullopt otherwise. Callers stream their
// message into the returned diagnostic.
static std::optional<InFlightDiagnostic>
getPerfWarningStream(const OpState &op) {
  const bool enabled = DiagnosticEmitter::getInstance()->shouldEmitPerfWarning;
  if (!enabled)
    return std::nullopt;
  return op->emitRemark();
}
Comment on lines +15 to +48
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not going to be thread-safe. We should stay away from globals or singletons.

};
} // namespace mlir::triton
#endif
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"LLVM_PASS_PLUGIN_PATH",
"MLIR_ENABLE_DIAGNOSTICS",
"MLIR_ENABLE_DUMP",
"MLIR_ENABLE_REMARK",
"MLIR_ENABLE_TIMING",
"TRITON_DEFAULT_FP_FUSION",
"TRITON_DISABLE_LINE_INFO",
Expand Down
6 changes: 4 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Tools/DiagEmitter.hpp"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

Expand All @@ -37,8 +38,9 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
for (int baseVersion : versionsSupported) {
if (supportMMA(op, baseVersion))
return baseVersion;
if (baseVersion == 3)
op.emitRemark() << "Warning: can't use MMA V3 for the dot op";
if (baseVersion == 3) {
EMIT_PERF_WARNING(op, "MMA V3 is not supported for this dot op. ");
}
}
return 0;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "TargetInfo.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"

Expand All @@ -9,6 +10,7 @@
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Tools/DiagEmitter.hpp"

using namespace mlir;
using namespace mlir::triton;
Expand Down Expand Up @@ -204,10 +206,10 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,

if (vec == 1 && numElems > 1) {
int maskValue = !llMask ? -1 : getMaskAlignment(mask);
op->emitRemark() << "Warning: vectorization fails vec = " << vec
<< " origin vec = " << vecOrig
<< " numElems = " << numElems << " mask is " << maskValue
<< "\n";
EMIT_PERF_WARNING(op, "Warning: vectorization fails vec = "
<< vec << " origin vec = " << vecOrig
<< " numElems = " << numElems << " mask is "
<< maskValue << "\n");
}
// Get the LLVM values for pointers
auto ptrElems = unpackLLElements(loc, llPtr, rewriter);
Expand Down Expand Up @@ -431,10 +433,10 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,

if (vec == 1 && elemsPerThread > 1) {
int mask = !llMask ? -1 : getMaskAlignment(op.getMask());
op->emitRemark() << "Warning: vectorization fails vec = " << vec
<< " origin vec = " << vecOrig
<< " elemsPerThread = " << elemsPerThread << " mask is "
<< mask << "\n";
EMIT_PERF_WARNING(op, "Warning: vectorization fails vec = "
<< vec << " origin vec = " << vecOrig
<< " elemsPerThread = " << elemsPerThread
<< " mask is " << mask << "\n");
}

Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
Expand Down Expand Up @@ -573,9 +575,10 @@ struct AtomicCASOpConversion
}

if (vec == 1 && elemsPerThread > 1)
op->emitRemark() << "Warning: vectorization fails vec = " << vec
<< " origin vec = " << vecOrig
<< " elemsPerThread = " << elemsPerThread << "\n";
EMIT_PERF_WARNING(op, "Warning: vectorization fails vec = "
<< vec << " origin vec = " << vecOrig
<< " elemsPerThread = " << elemsPerThread
<< "\n");

Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
auto vecTy = vec_ty(valueElemTy, vec);
Expand Down Expand Up @@ -732,9 +735,10 @@ struct AtomicRMWOpConversion
assert((packed == 1 || vec == 1) && "packed or vec must be 1");

if (vec * packed == 1 && numElems > 1)
op->emitRemark() << "Warning: vectorization fails vec = " << vec
<< " packed = " << packed << " origin vec = " << vecOrig
<< " numElems = " << numElems;
EMIT_PERF_WARNING(op, "Warning: vectorization fails vec = "
<< vec << " packed = " << packed
<< " origin vec = " << vecOrig
<< " numElems = " << numElems);

Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);

Expand Down