Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

#include "cpu/include/TritonCPUToLLVM/Passes.h"
#include "cpu/include/TritonToTritonCPU/Passes.h"
#include "nvidia/include/NVGPUToLLVM/Passes.h"
#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
Expand Down Expand Up @@ -60,13 +62,21 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerTritonAMDGPUReorderInstructions();
mlir::registerTritonAMDGPUStreamPipeline();

// CPU passes
mlir::triton::cpu::registerTritonToTritonCPUPasses();
mlir::triton::cpu::registerTritonToTritonCPUPipeline();
mlir::triton::cpu::registerTritonCPUToLLVMPasses();
mlir::triton::cpu::registerTritonCPUToLLVMPipeline();

// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::cpu::TritonCPUDialect,
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect,
mlir::NVVM::NVVMDialect, mlir::triton::nvgpu::NVGPUDialect,
mlir::ROCDL::ROCDLDialect>();
registry
.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::cpu::TritonCPUDialect,
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
mlir::arith::ArithDialect, mlir::scf::SCFDialect,
mlir::memref::MemRefDialect, mlir::vector::VectorDialect,
mlir::tensor::TensorDialect, mlir::gpu::GPUDialect,
mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect,
mlir::triton::nvgpu::NVGPUDialect, mlir::ROCDL::ROCDLDialect>();
}
2 changes: 1 addition & 1 deletion include/triton/Conversion/TritonCPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUToLLVM)
add_public_tablegen_target(TritonCPUConversionPassIncGen)
add_public_tablegen_target(TritonCPUToLLVMConversionPassIncGen)
5 changes: 2 additions & 3 deletions include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ class TritonCPU_Attr<string name, string attrMnemonic, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<dialect, name, !listconcat([TritonCPU_AttrTrait], traits), baseCppClass> {

let description = [{
WIP...
}];
let description = [{TritonCPU attr.}];
let attrName = "triton.cpu." # attrMnemonic;
}

#endif
3 changes: 3 additions & 0 deletions include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@ def TritonCPU_Dialect : Dialect {
let dependentDialects = [
"triton::TritonDialect",
"tensor::TensorDialect",
"mlir::memref::MemRefDialect",
];

let extraClassDeclaration = [{
void registerTypes();
}];

let useDefaultTypePrinterParser = 1;
let useDefaultAttributePrinterParser = 1;
let usePropertiesForAttributes = 1;
}

#endif
67 changes: 67 additions & 0 deletions include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,73 @@ include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "triton/Dialect/Triton/IR/TritonTypes.td"
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"

// Base class for all TritonCPU (TTC) operations. Currently adds no extra
// traits of its own; each op definition passes its traits explicitly.
class TTC_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonCPU_Dialect, mnemonic,
!listconcat(traits, [])> {
}

// ttc.extract_memref: block pointer -> memref over the *entire* base tensor
// (offsets and block shape of the pointer are intentionally ignored).
def TTC_ExtractMemRefOp : TTC_Op<"extract_memref", [NoMemoryEffect]> {
let summary = "Extract base memref from a block pointer";

let description = [{
Extract base memref from a block pointer. It covers whole base tensor memory,
not only the block referenced. Base pointer, shape, and strides are used
in the resulting memref. Offsets and block shape are ignored.

}];

let arguments = (ins TT_TensorPtr:$src);

let results = (outs AnyRankedOrUnrankedMemRef:$result);

// Canonicalizer is declared but its pattern list is currently empty
// (see getCanonicalizationPatterns in Dialect.cpp).
let hasCanonicalizer = 1;

let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
}

// ttc.extract_indices: block pointer -> variadic index values, intended to be
// paired with ttc.extract_memref for vector load/store lowering. The custom
// builder (Dialect.cpp) produces one index result per rank of the pointee
// tensor.
def TTC_ExtractIndicesOp : TTC_Op<"extract_indices", [NoMemoryEffect]> {
let summary = "Extract indices from a block pointer.";

let description = [{
Extract indices that can be used to access the block using its base memref.
Indices are supposed to be used for vector loads/stores with the base
memref extracted from the same block pointer.
}];

let arguments = (ins TT_TensorPtr:$src);

let results = (outs Variadic<Index>:$result);

// Convenience builder that infers the number of index results from the
// rank of the source tensor pointer.
let builders = [
OpBuilder<(ins "Value":$src)>
];

// Canonicalizer is declared but its pattern list is currently empty
// (see getCanonicalizationPatterns in Dialect.cpp).
let hasCanonicalizer = 1;

let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
}

// ttc.ptr_to_memref: wrap a raw pointer in a memref whose shape/offset/strides
// are fixed by the static result type.
def TTC_PtrToMemRefOp : TTC_Op<"ptr_to_memref", [NoMemoryEffect]> {
let summary = "Build a memref for a pointer.";

let description = [{
Build memref with static shape, offset, strides, and specified base pointer.
}];

let arguments = (ins TT_Ptr:$src);

let results = (outs AnyStaticShapeMemRef:$result);

let hasCanonicalizer = 0;

let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
}

#endif
4 changes: 2 additions & 2 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
add_subdirectory(TritonToTritonCPU)
#add_subdirectory(TritonToTritonCPU)
add_subdirectory(TritonToTritonGPU)
add_subdirectory(TritonCPUToLLVM)
#add_subdirectory(TritonCPUToLLVM)
add_subdirectory(TritonGPUToLLVM)
38 changes: 37 additions & 1 deletion lib/Dialect/TritonCPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@

#include <numeric>

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonCPU/IR/Dialect.cpp.inc"
#include "triton/Dialect/TritonCPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/ADT/TypeSwitch.h"

#include "triton/Dialect/TritonCPU/IR/Dialect.cpp.inc"

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::cpu;

//===----------------------------------------------------------------------===//
Expand All @@ -20,6 +24,35 @@ using namespace mlir::triton::cpu;
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.cpp.inc"

// Empty canonicalization pattern sets. The ops declare
// `hasCanonicalizer = 1` in TritonCPUOps.td, so these definitions must
// exist even though no patterns are registered yet.
void ExtractMemRefOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {}

void ExtractIndicesOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {}

/// Parse an attribute registered to this dialect.
/// NOTE(review): hand-written unreachable stub — the dialect defines no
/// custom attributes yet, so this should never be called. Confirm it does
/// not conflict with `useDefaultAttributePrinterParser = 1` in
/// TritonCPUDialect.td once attributes are added.
::mlir::Attribute
TritonCPUDialect::parseAttribute(::mlir::DialectAsmParser &parser,
::mlir::Type type) const {
llvm_unreachable("parse stub called");
}

/// Print an attribute registered to this dialect.
/// NOTE(review): unreachable stub, counterpart of parseAttribute above.
void TritonCPUDialect::printAttribute(::mlir::Attribute attr,
::mlir::DialectAsmPrinter &os) const {
llvm_unreachable("print stub called");
}

/// Custom builder for ttc.extract_indices: derives the result types (one
/// `index` per dimension) from the rank of the tensor the block pointer
/// points to, then delegates to the generated builder.
void ExtractIndicesOp::build(::mlir::OpBuilder &builder,
                             ::mlir::OperationState &state, Value src) {
  assert(triton::isTensorPointerType(src.getType()) &&
         "Unexpected source type");
  // The assert above guarantees src is a pointer-to-ranked-tensor, so use
  // cast<> (which asserts on mismatch) instead of dereferencing an
  // unchecked dyn_cast<> result.
  auto tensorTy = cast<RankedTensorType>(
      cast<PointerType>(src.getType()).getPointeeType());
  SmallVector<Type> resTypes(tensorTy.getRank(), builder.getIndexType());
  build(builder, state, resTypes, src);
}

void TritonCPUDialect::initialize() {
registerTypes();

Expand All @@ -34,6 +67,9 @@ void TritonCPUDialect::initialize() {
>();
}

#define GET_OP_CLASSES
#include "triton/Dialect/TritonCPU/IR/Ops.cpp.inc"

// verify TritonCPU ops
LogicalResult TritonCPUDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
Expand Down
68 changes: 68 additions & 0 deletions python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
Expand All @@ -18,6 +20,7 @@
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/TargetParser/Host.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include <pybind11/pybind11.h>
Expand Down Expand Up @@ -278,6 +281,71 @@ void init_triton_llvm(py::module &&m) {
},
py::arg("mod"), py::arg("opt"), py::arg("triple") = "");

// Retarget a module to the host machine: native triple, host CPU, and the
// matching data layout.
m.def("set_host_target", [](llvm::Module *mod) {
  mod->setTargetTriple(llvm::sys::getDefaultTargetTriple());
  std::string error;
  auto target =
      llvm::TargetRegistry::lookupTarget(mod->getTargetTriple(), error);
  // lookupTarget returns nullptr on failure (e.g. the host backend is not
  // linked in); report a diagnostic instead of dereferencing null.
  if (!target)
    llvm::report_fatal_error(
        llvm::Twine("host target lookup failed: ") + error);
  std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
      mod->getTargetTriple(), llvm::sys::getHostCPUName(), "", {},
      llvm::Reloc::PIC_)};
  mod->setDataLayout(machine->createDataLayout());
});

// Compile a textual LLVM IR string to assembly for the host machine
// (default triple + host CPU). Returns the assembly as a Python string.
m.def(
"translate_to_host_asm",
[](std::string llvmIR, bool enable_fp_fusion) -> py::object {
std::string res;
{
// when allow_threads goes out of scope, gil will be released
py::gil_scoped_release allow_threads;
// create LLVM module from C++
llvm::LLVMContext context;
std::unique_ptr<llvm::MemoryBuffer> buffer =
llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str());
llvm::SMDiagnostic error;
std::unique_ptr<llvm::Module> module =
llvm::parseIR(buffer->getMemBufferRef(), error, context);
if (!module) {
llvm::report_fatal_error(
"failed to parse IR: " + error.getMessage() +
"lineno: " + std::to_string(error.getLineNo()));
}
// Delegate to the shared IR->ASM helper; empty string = default
// feature set, final `false` flag mirrors translate_to_asm's usage.
res =
translateLLVMIRToASM(*module, llvm::sys::getDefaultTargetTriple(),
llvm::sys::getHostCPUName().str(), "", {},
enable_fp_fusion, false);
}
return py::str(res);
},
ret::take_ownership);

// Serialize a textual LLVM IR string to LLVM bitcode, returned as Python
// bytes.
m.def(
"translate_to_bc",
[](const std::string llvmIR) -> py::object {
py::gil_scoped_release allow_threads;
// create LLVM module
llvm::LLVMContext context;
std::unique_ptr<llvm::MemoryBuffer> buffer =
llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str());
llvm::SMDiagnostic error;
std::unique_ptr<llvm::Module> module =
llvm::parseIR(buffer->getMemBufferRef(), error, context);
if (!module) {
llvm::report_fatal_error(
"failed to parse IR: " + error.getMessage() +
"lineno: " + std::to_string(error.getLineNo()));
}
// Write bitcode to a buffer.
// NOTE(review): writes module + string table only; no writeSymtab()
// call, so the output has no symbol table — confirm downstream
// consumers don't need one (llvm::WriteBitcodeToFile would emit it).
llvm::SmallVector<char, 0> buf;
llvm::BitcodeWriter writer(buf);
writer.writeModule(*module);
writer.writeStrtab();
std::string bitcode(buf.begin(), buf.end());
return py::bytes(bitcode);
},
ret::take_ownership);

m.def(
"translate_to_asm",
[](std::string llvmIR, std::string triple, std::string proc,
Expand Down
5 changes: 3 additions & 2 deletions python/src/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ void init_triton_passes_ttir(py::module &&m) {
ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir",
createConvertTritonToTritonGPUPass, const std::string &,
int, int, int);
ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir",
createConvertTritonToTritonCPUPass);
// ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir",
// createConvertTritonToTritonCPUPass);
}

void init_triton_passes_ttgpuir(py::module &&m) {
Expand Down Expand Up @@ -75,6 +75,7 @@ void init_triton_passes_convert(py::module &&m) {
ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass);
ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass);
ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass);
ADD_PASS_WRAPPER_0("add_math_to_llvmir", createConvertMathToLLVMPass);
}

void init_triton_passes_llvmir(py::module &&m) {
Expand Down
1 change: 1 addition & 0 deletions python/test/unit/language/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@

def pytest_configure(config):
    """Register the custom pytest markers this test suite relies on."""
    custom_markers = (
        "interpreter: indicate whether interpreter supports the test",
        "cpu: indicate whether test is supported on cpu",
    )
    for marker in custom_markers:
        config.addinivalue_line("markers", marker)
Loading