Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ if(TRITON_BUILD_PYTHON_MODULE)
if (TRITON_BUILD_PROTON)
add_subdirectory(third_party/proton)
endif()
# We always build proton dialect
list(APPEND TRITON_PLUGIN_NAMES "proton")
add_subdirectory(third_party/proton/dialect)

get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS)
Expand Down Expand Up @@ -311,6 +314,7 @@ if(NOT TRITON_BUILD_PYTHON_MODULE)
foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS})
add_subdirectory(third_party/${CODEGEN_BACKEND})
endforeach()
add_subdirectory(third_party/proton/dialect)
endif()

add_subdirectory(third_party/f2reduce)
Expand Down
18 changes: 10 additions & 8 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
#include "amd/include/TritonAMDGPUTransforms/Passes.h"
#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
Expand Down Expand Up @@ -68,12 +69,13 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();

// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect,
mlir::NVVM::NVVMDialect, mlir::triton::nvgpu::NVGPUDialect,
mlir::triton::amdgpu::TritonAMDGPUDialect,
mlir::ROCDL::ROCDLDialect>();
registry
.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect,
mlir::NVVM::NVVMDialect, mlir::triton::nvgpu::NVGPUDialect,
mlir::triton::amdgpu::TritonAMDGPUDialect,
mlir::triton::proton::ProtonDialect, mlir::ROCDL::ROCDLDialect>();
}
15 changes: 15 additions & 0 deletions test/Proton/ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: triton-opt --split-input-file %s -cse -canonicalize | FileCheck %s

module {
// CHECK-LABEL: proton_record
tt.func @proton_record() {
// CHECK: proton.record() {isStart = true, regionId = 1 : i32}
// CHECK-NEXT: proton.record() {isStart = false, regionId = 1 : i32}
// CHECK-NEXT: tt.return
proton.record() {isStart = true, regionId = 1 : i32}
proton.record() {isStart = false, regionId = 1 : i32}
tt.return
}
} // end module

// -----
7 changes: 7 additions & 0 deletions third_party/proton/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
add_subdirectory(include)
add_subdirectory(lib)
if(TRITON_BUILD_PYTHON_MODULE)
add_triton_plugin(TritonProton ${CMAKE_CURRENT_SOURCE_DIR}/triton_proton.cc LINK_LIBS ProtonIR)
endif()
1 change: 1 addition & 0 deletions third_party/proton/dialect/include/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(Dialect)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(Proton)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(IR)
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})

set(LLVM_TARGET_DEFINITIONS ProtonOps.td)
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=proton)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=proton)
mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
add_mlir_doc(ProtonDialect ProtonDialect dialects/ -gen-dialect-doc)
add_mlir_doc(ProtonOps ProtonOps dialects/ -gen-op-doc)
add_public_tablegen_target(ProtonTableGen)

set(LLVM_TARGET_DEFINITIONS ProtonAttrDefs.td)
mlir_tablegen(ProtonAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(ProtonAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(ProtonAttrDefsIncGen)
23 changes: 23 additions & 0 deletions third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef TRITON_DIALECT_PROTON_IR_DIALECT_H_
#define TRITON_DIALECT_PROTON_IR_DIALECT_H_

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/PatternMatch.h"
#include "proton/dialect/include/Dialect/Proton/IR/Dialect.h.inc"
#include "proton/dialect/include/Dialect/Proton/IR/OpsEnums.h.inc"

#define GET_ATTRDEF_CLASSES
#include "proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.h.inc"

#define GET_OP_CLASSES
#include "proton/dialect/include/Dialect/Proton/IR/Ops.h.inc"

namespace mlir {
namespace triton {
namespace proton {} // namespace proton
} // namespace triton
} // namespace mlir

#endif // TRITON_DIALECT_PROTON_IR_DIALECT_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef PROTON_ATTRDEFS
#define PROTON_ATTRDEFS

include "mlir/IR/AttrTypeBase.td"
include "ProtonDialect.td"

class Proton_Attr<string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<Proton_Dialect, name, traits, baseCppClass> {
}

#endif // PROTON_ATTRDEFS
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef PROTON_DIALECT
#define PROTON_DIALECT

include "mlir/IR/OpBase.td"

def Proton_Dialect : Dialect {
let name = "proton";
let cppNamespace = "::mlir::triton::proton";

let description = [{
Proton Dialect provides core ops for building third-party compiler-based
performance profiling and analysis tools.
}];

let dependentDialects = [];
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#ifndef PROTON_OPS
#define PROTON_OPS

include "mlir/IR/OpBase.td"
include "mlir/IR/EnumAttr.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "ProtonDialect.td"
include "ProtonAttrDefs.td"

class TT_Proton_Op<string mnemonic, list<Trait> traits = []> :
Op<Proton_Dialect, mnemonic, !listconcat(traits, [])> {
}

// Proton profiling metric.
def MetricAttr : I32EnumAttr<
"Metric", "",
[
I32EnumAttrCase<"CYCLE", 0, "cycle">,
]> {
let cppNamespace = "::mlir::triton::proton";
}

// Proton profiling granularity.
def GranularityAttr : I32EnumAttr<
"Granularity", "",
[
I32EnumAttrCase<"WARPGROUP", 0, "warpgroup">,
I32EnumAttrCase<"WARP", 1, "warp">,
]> {
let cppNamespace = "::mlir::triton::proton";
}

def TT_RecordOp : TT_Proton_Op<"record", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Record a GPU hardware event";

let description = [{
The operator records GPU events from performance counters.
Currently only cycle counter is supported.

Example:

```mlir
proton.record() {isStart = true, regionId = 4 : i32}
...
proton.record() {isStart = false, regionId = 4 : i32}
...
proton.record() {isStart = true, regionId = 1 : i32, granularity = 1 : i32}
...
proton.record() {isStart = false, regionId = 1 : i32, granularity = 1 : i32}
```
}];
let arguments = (
ins BoolAttr: $isStart,
ConfinedAttr<I32Attr, [IntNonNegative]>:$regionId,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could region id be a dynamic value in the future?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the dynamic value thing, I thought about it and I think should have an "version" as an argument, like record(version:i32){...attr...}. This would be useful for loop iterations that we have multiple instances of the recordOP and we can distinguish them. To make the region_id dynamic, do you have any use cases?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't, just curious

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm inclined to add them (including the "version" arg) when we actually try to build the support around it, and keep it simple for now. What do you think?

DefaultValuedAttr<MetricAttr, "Metric::CYCLE">:$metric,
DefaultValuedAttr<GranularityAttr, "Granularity::WARPGROUP">:$granularity
);
let assemblyFormat = " `(` operands `)` attr-dict";
}

#endif // PROTON_OPS
1 change: 1 addition & 0 deletions third_party/proton/dialect/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(Dialect)
1 change: 1 addition & 0 deletions third_party/proton/dialect/lib/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(Proton)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(IR)
13 changes: 13 additions & 0 deletions third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
add_triton_library(ProtonIR
Dialect.cpp
Ops.cpp

DEPENDS
ProtonTableGen
ProtonAttrDefsIncGen

LINK_LIBS PUBLIC
MLIRLLVMDialect
TritonIR
TritonGPUIR
)
25 changes: 25 additions & 0 deletions third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"

// clang-format off
#include "Dialect/Proton/IR/Dialect.h"
#include "Dialect/Proton/IR/Dialect.cpp.inc"
// clang-format on

using namespace mlir;
using namespace mlir::triton::proton;

void mlir::triton::proton::ProtonDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST
#include "Dialect/Proton/IR/ProtonAttrDefs.cpp.inc"
>();

addOperations<
#define GET_OP_LIST
#include "Dialect/Proton/IR/Ops.cpp.inc"
>();
}

#define GET_ATTRDEF_CLASSES
#include "Dialect/Proton/IR/ProtonAttrDefs.cpp.inc"
33 changes: 33 additions & 0 deletions third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include "Dialect/Proton/IR/Dialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"

#define GET_OP_CLASSES
#include "Dialect/Proton/IR/Ops.cpp.inc"
#include "Dialect/Proton/IR/OpsEnums.cpp.inc"

namespace mlir {
namespace triton {
namespace proton {

// -- RecordOp --
void RecordOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.emplace_back(MemoryEffects::Write::get(),
SideEffects::DefaultResource::get());
effects.emplace_back(MemoryEffects::Read::get(),
SideEffects::DefaultResource::get());
}

} // namespace proton
} // namespace triton
} // namespace mlir
20 changes: 20 additions & 0 deletions third_party/proton/dialect/triton_proton.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "Dialect/Proton/IR/Dialect.h"
#include "mlir/Pass/PassManager.h"
#include "passes.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>

namespace py = pybind11;

void init_triton_proton(py::module &&m) {
auto passes = m.def_submodule("passes");

// load dialects
m.def("load_dialects", [](mlir::MLIRContext &context) {
mlir::DialectRegistry registry;
registry.insert<mlir::triton::proton::ProtonDialect>();
context.appendDialectRegistry(registry);
context.loadAllAvailableDialects();
});
}
Loading