diff --git a/CMakeLists.txt b/CMakeLists.txt index 56564c38964c..c5aa40499e5d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -206,6 +206,9 @@ if(TRITON_BUILD_PYTHON_MODULE) if (TRITON_BUILD_PROTON) add_subdirectory(third_party/proton) endif() + # We always build proton dialect + list(APPEND TRITON_PLUGIN_NAMES "proton") + add_subdirectory(third_party/proton/dialect) get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS) @@ -311,6 +314,7 @@ if(NOT TRITON_BUILD_PYTHON_MODULE) foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS}) add_subdirectory(third_party/${CODEGEN_BACKEND}) endforeach() + add_subdirectory(third_party/proton/dialect) endif() add_subdirectory(third_party/f2reduce) diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index c69e46792c3d..71d75b35dbf0 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -2,6 +2,7 @@ #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "amd/include/TritonAMDGPUTransforms/Passes.h" #include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -68,12 +69,13 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) { mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); // TODO: register Triton & TritonGPU passes - registry.insert(); + registry + .insert(); } diff --git a/test/Proton/ops.mlir b/test/Proton/ops.mlir new file mode 100644 index 000000000000..22a17e3f0f58 --- /dev/null +++ b/test/Proton/ops.mlir @@ -0,0 +1,15 @@ +// RUN: triton-opt --split-input-file %s -cse -canonicalize | FileCheck %s + +module { + // CHECK-LABEL: proton_record + tt.func @proton_record() { + // CHECK: proton.record() {isStart = true, regionId = 1 : i32} + // CHECK-NEXT: 
proton.record() {isStart = false, regionId = 1 : i32} + // CHECK-NEXT: tt.return + proton.record() {isStart = true, regionId = 1 : i32} + proton.record() {isStart = false, regionId = 1 : i32} + tt.return + } +} // end module + +// ----- diff --git a/third_party/proton/dialect/CMakeLists.txt b/third_party/proton/dialect/CMakeLists.txt new file mode 100644 index 000000000000..c7b5413a0e15 --- /dev/null +++ b/third_party/proton/dialect/CMakeLists.txt @@ -0,0 +1,7 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +add_subdirectory(include) +add_subdirectory(lib) +if(TRITON_BUILD_PYTHON_MODULE) + add_triton_plugin(TritonProton ${CMAKE_CURRENT_SOURCE_DIR}/triton_proton.cc LINK_LIBS ProtonIR) +endif() diff --git a/third_party/proton/dialect/include/CMakeLists.txt b/third_party/proton/dialect/include/CMakeLists.txt new file mode 100644 index 000000000000..0ca0f41c5af4 --- /dev/null +++ b/third_party/proton/dialect/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) diff --git a/third_party/proton/dialect/include/Dialect/CMakeLists.txt b/third_party/proton/dialect/include/Dialect/CMakeLists.txt new file mode 100644 index 000000000000..f18c30ba1a6d --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Proton) diff --git a/third_party/proton/dialect/include/Dialect/Proton/CMakeLists.txt b/third_party/proton/dialect/include/Dialect/Proton/CMakeLists.txt new file mode 100644 index 000000000000..f33061b2d87c --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/CMakeLists.txt b/third_party/proton/dialect/include/Dialect/Proton/IR/CMakeLists.txt new file mode 100644 index 000000000000..4645b0ebcd5a --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/CMakeLists.txt @@ -0,0 +1,18 @@ 
+set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS ProtonOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=proton) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=proton) +mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(ProtonDialect ProtonDialect dialects/ -gen-dialect-doc) +add_mlir_doc(ProtonOps ProtonOps dialects/ -gen-op-doc) +add_public_tablegen_target(ProtonTableGen) + +set(LLVM_TARGET_DEFINITIONS ProtonAttrDefs.td) +mlir_tablegen(ProtonAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(ProtonAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(ProtonAttrDefsIncGen) diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h b/third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h new file mode 100644 index 000000000000..680a205f08f1 --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h @@ -0,0 +1,23 @@ +#ifndef TRITON_DIALECT_PROTON_IR_DIALECT_H_ +#define TRITON_DIALECT_PROTON_IR_DIALECT_H_ + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/PatternMatch.h" +#include "proton/dialect/include/Dialect/Proton/IR/Dialect.h.inc" +#include "proton/dialect/include/Dialect/Proton/IR/OpsEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.h.inc" + +#define GET_OP_CLASSES +#include "proton/dialect/include/Dialect/Proton/IR/Ops.h.inc" + +namespace mlir { +namespace triton { +namespace proton {} // namespace proton +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_PROTON_IR_DIALECT_H_ diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td 
b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td new file mode 100644 index 000000000000..d469fbb35f6b --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td @@ -0,0 +1,12 @@ +#ifndef PROTON_ATTRDEFS +#define PROTON_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "ProtonDialect.td" + +class Proton_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { +} + +#endif // PROTON_ATTRDEFS diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td new file mode 100644 index 000000000000..245f2e09a2ec --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td @@ -0,0 +1,18 @@ +#ifndef PROTON_DIALECT +#define PROTON_DIALECT + +include "mlir/IR/OpBase.td" + +def Proton_Dialect : Dialect { + let name = "proton"; + let cppNamespace = "::mlir::triton::proton"; + + let description = [{ + Proton Dialect provides core ops for building third-party compiler-based + performance profiling and analysis tools. + }]; + + let dependentDialects = []; +} + +#endif diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td new file mode 100644 index 000000000000..d18a48d5d1a0 --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td @@ -0,0 +1,65 @@ +#ifndef PROTON_OPS +#define PROTON_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "ProtonDialect.td" +include "ProtonAttrDefs.td" + +class TT_Proton_Op traits = []> : + Op { +} + +// Proton profiling metric. 
+def MetricAttr : I32EnumAttr< + "Metric", "", + [ + I32EnumAttrCase<"CYCLE", 0, "cycle">, + ]> { + let cppNamespace = "::mlir::triton::proton"; +} + +// Proton profiling granularity. +def GranularityAttr : I32EnumAttr< + "Granularity", "", + [ + I32EnumAttrCase<"WARPGROUP", 0, "warpgroup">, + I32EnumAttrCase<"WARP", 1, "warp">, + ]> { + let cppNamespace = "::mlir::triton::proton"; +} + +def TT_RecordOp : TT_Proton_Op<"record", [DeclareOpInterfaceMethods]> { + let summary = "Record a GPU hardware event"; + + let description = [{ + The operator records GPU events from performance counters. + Currently only cycle counter is supported. + + Example: + + ```mlir + proton.record() {isStart = true, regionId = 4 : i32} + ... + proton.record() {isStart = false, regionId = 4 : i32} + ... + proton.record() {isStart = true, regionId = 1 : i32, granularity = 1 : i32} + ... + proton.record() {isStart = false, regionId = 1 : i32, granularity = 1 : i32} + ``` + }]; + let arguments = ( + ins BoolAttr: $isStart, + ConfinedAttr:$regionId, + DefaultValuedAttr:$metric, + DefaultValuedAttr:$granularity + ); + let assemblyFormat = " `(` operands `)` attr-dict"; +} + +#endif // PROTON_OPS diff --git a/third_party/proton/dialect/lib/CMakeLists.txt b/third_party/proton/dialect/lib/CMakeLists.txt new file mode 100644 index 000000000000..0ca0f41c5af4 --- /dev/null +++ b/third_party/proton/dialect/lib/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) diff --git a/third_party/proton/dialect/lib/Dialect/CMakeLists.txt b/third_party/proton/dialect/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000000..f18c30ba1a6d --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Proton) diff --git a/third_party/proton/dialect/lib/Dialect/Proton/CMakeLists.txt b/third_party/proton/dialect/lib/Dialect/Proton/CMakeLists.txt new file mode 100644 index 000000000000..f33061b2d87c --- /dev/null +++ 
b/third_party/proton/dialect/lib/Dialect/Proton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt b/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt new file mode 100644 index 000000000000..5eea5cb3cf9e --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt @@ -0,0 +1,13 @@ +add_triton_library(ProtonIR + Dialect.cpp + Ops.cpp + + DEPENDS + ProtonTableGen + ProtonAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRLLVMDialect + TritonIR + TritonGPUIR +) diff --git a/third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp b/third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp new file mode 100644 index 000000000000..60c2852654db --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp @@ -0,0 +1,25 @@ +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" + +// clang-format off +#include "Dialect/Proton/IR/Dialect.h" +#include "Dialect/Proton/IR/Dialect.cpp.inc" +// clang-format on + +using namespace mlir; +using namespace mlir::triton::proton; + +void mlir::triton::proton::ProtonDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "Dialect/Proton/IR/ProtonAttrDefs.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "Dialect/Proton/IR/Ops.cpp.inc" + >(); +} + +#define GET_ATTRDEF_CLASSES +#include "Dialect/Proton/IR/ProtonAttrDefs.cpp.inc" diff --git a/third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp b/third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp new file mode 100644 index 000000000000..1a0799aea127 --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp @@ -0,0 +1,33 @@ +#include "Dialect/Proton/IR/Dialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include 
"mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +#define GET_OP_CLASSES +#include "Dialect/Proton/IR/Ops.cpp.inc" +#include "Dialect/Proton/IR/OpsEnums.cpp.inc" + +namespace mlir { +namespace triton { +namespace proton { + +// -- RecordOp -- +void RecordOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +} // namespace proton +} // namespace triton +} // namespace mlir diff --git a/third_party/proton/dialect/triton_proton.cc b/third_party/proton/dialect/triton_proton.cc new file mode 100644 index 000000000000..8046539794e1 --- /dev/null +++ b/third_party/proton/dialect/triton_proton.cc @@ -0,0 +1,20 @@ +#include "Dialect/Proton/IR/Dialect.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" +#include +#include +#include + +namespace py = pybind11; + +void init_triton_proton(py::module &&m) { + auto passes = m.def_submodule("passes"); + + // load dialects + m.def("load_dialects", [](mlir::MLIRContext &context) { + mlir::DialectRegistry registry; + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); +}