From 098abcc0482b924976a37680c29454ae5047459b Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Thu, 2 May 2024 10:52:16 -0700 Subject: [PATCH 01/13] Support basic lowering through vector dialect in CPU backend. Signed-off-by: Ilya Enkovich --- bin/RegisterTritonDialects.h | 26 +- .../Conversion/TritonCPUToLLVM/CMakeLists.txt | 2 +- .../Dialect/TritonCPU/IR/TritonCPUAttrDefs.td | 5 +- .../Dialect/TritonCPU/IR/TritonCPUDialect.td | 3 + .../Dialect/TritonCPU/IR/TritonCPUOps.td | 51 +++ lib/Conversion/CMakeLists.txt | 4 +- lib/Dialect/TritonCPU/IR/Dialect.cpp | 37 +- python/src/llvm.cc | 67 ++++ python/src/passes.cc | 5 +- python/test/unit/language/conftest.py | 1 + python/test/unit/language/test_core.py | 56 +++ python/triton/runtime/jit.py | 7 - third_party/cpu/CMakeLists.txt | 5 + third_party/cpu/backend/compiler.py | 73 ++-- third_party/cpu/backend/driver.cpp | 224 ++++++++++++ third_party/cpu/backend/driver.py | 326 ++++++++++++++++-- third_party/cpu/include/CMakeLists.txt | 2 + .../include/TritonCPUToLLVM/CMakeLists.txt | 3 + .../cpu/include/TritonCPUToLLVM/Passes.h | 36 ++ .../cpu/include/TritonCPUToLLVM/Passes.td | 46 +++ .../include/TritonToTritonCPU/CMakeLists.txt | 3 + .../cpu/include/TritonToTritonCPU/Passes.h | 38 ++ .../cpu/include/TritonToTritonCPU/Passes.td | 77 +++++ third_party/cpu/lib/CMakeLists.txt | 2 + .../cpu/lib/TritonCPUToLLVM/CMakeLists.txt | 13 + .../cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp | 278 +++++++++++++++ .../TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp | 98 ++++++ .../lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp | 277 +++++++++++++++ .../cpu/lib/TritonCPUToLLVM/Pipeline.cpp | 25 ++ .../cpu/lib/TritonCPUToLLVM/TypeConverter.cpp | 43 +++ .../cpu/lib/TritonCPUToLLVM/TypeConverter.h | 22 ++ .../cpu/lib/TritonToTritonCPU/CMakeLists.txt | 16 + .../ConvertControlFlowOps.cpp | 121 +++++++ .../lib/TritonToTritonCPU/ConvertDotOp.cpp | 102 ++++++ .../ConvertElementwiseOps.cpp | 235 +++++++++++++ .../TritonToTritonCPU/ConvertMemoryOps.cpp | 277 +++++++++++++++ .../lib/TritonToTritonCPU/ConvertPtrOps.cpp | 191 ++++++++++ .../lib/TritonToTritonCPU/OpTypeConversion.h | 37 ++ .../cpu/lib/TritonToTritonCPU/Pipeline.cpp | 27 ++ .../lib/TritonToTritonCPU/TypeConverter.cpp | 51 +++ .../cpu/lib/TritonToTritonCPU/TypeConverter.h | 19 + third_party/cpu/triton_cpu.cc | 45 ++- 42 files changed, 2891 insertions(+), 85 deletions(-) create mode 100644 third_party/cpu/backend/driver.cpp create mode 100644 third_party/cpu/include/CMakeLists.txt create mode 100644 third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt create mode 100644 third_party/cpu/include/TritonCPUToLLVM/Passes.h create mode 100644 third_party/cpu/include/TritonCPUToLLVM/Passes.td create mode 100644 third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt create mode 100644 third_party/cpu/include/TritonToTritonCPU/Passes.h create mode 100644 third_party/cpu/include/TritonToTritonCPU/Passes.td create mode 100644 third_party/cpu/lib/CMakeLists.txt create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h create mode 100644 third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt create 
mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/OpTypeConversion.h create mode 100644 third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 0541791b6d43..dc0018eb8918 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -14,6 +14,8 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "cpu/include/TritonCPUToLLVM/Passes.h" +#include "cpu/include/TritonToTritonCPU/Passes.h" #include "nvidia/include/NVGPUToLLVM/Passes.h" #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" @@ -60,13 +62,21 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) { mlir::registerTritonAMDGPUReorderInstructions(); mlir::registerTritonAMDGPUStreamPipeline(); + // CPU passes + mlir::triton::cpu::registerTritonToTritonCPUPasses(); + mlir::triton::cpu::registerTritonToTritonCPUPipeline(); + mlir::triton::cpu::registerTritonCPUToLLVMPasses(); + mlir::triton::cpu::registerTritonCPUToLLVMPipeline(); + // TODO: register Triton & TritonGPU passes - registry.insert(); + registry + .insert(); } diff --git a/include/triton/Conversion/TritonCPUToLLVM/CMakeLists.txt b/include/triton/Conversion/TritonCPUToLLVM/CMakeLists.txt index 64b36523d35d..0936dff12d91 100644 --- a/include/triton/Conversion/TritonCPUToLLVM/CMakeLists.txt +++ b/include/triton/Conversion/TritonCPUToLLVM/CMakeLists.txt @@ -1,3 +1,3 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUToLLVM) -add_public_tablegen_target(TritonCPUConversionPassIncGen) +add_public_tablegen_target(TritonCPUToLLVMConversionPassIncGen) diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td index df933dd49511..57f6c7c9bd71 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td @@ -17,9 +17,8 @@ class TritonCPU_Attr traits = [], string baseCppClass = "::mlir::Attribute"> : AttrDef { - let description = [{ - WIP... - }]; + let description = [{TritonCPU attr.}]; + let attrName = "triton.cpu." 
# attrMnemonic; } #endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td index 9ccac13f0b58..260db2743046 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td @@ -17,6 +17,7 @@ def TritonCPU_Dialect : Dialect { let dependentDialects = [ "triton::TritonDialect", "tensor::TensorDialect", + "mlir::memref::MemRefDialect", ]; let extraClassDeclaration = [{ @@ -24,6 +25,8 @@ def TritonCPU_Dialect : Dialect { }]; let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; } #endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td index 16d9e433e899..bb7417ebd03e 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -7,6 +7,57 @@ include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td" include "mlir/Dialect/Arith/IR/ArithBase.td" include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/ViewLikeInterface.td" + +class TTC_Op traits = []> : + Op { +} + +def TTC_ExtractMemRefOp : TTC_Op<"extract_memref", [NoMemoryEffect]> { + let summary = "Extract base memref from a block pointer"; + + let description = [{ + Extract the base memref from a block pointer. It covers the whole base tensor + memory, not just the referenced block. The base pointer, shape, and strides are + used in the resulting memref; offsets and block shape are ignored. + + }]; + + let arguments = (ins TT_TensorPtr:$src); + + let results = (outs AnyRankedOrUnrankedMemRef:$result); + + let hasCanonicalizer = 1; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TTC_ExtractIndicesOp : TTC_Op<"extract_indices", [NoMemoryEffect]> { + let summary = "Extract indices from a block pointer."; + + let description = [{ + Extract indices that can be used to access the block through its base memref. + The indices are meant to be used for vector loads/stores together with the + base memref extracted from the same block pointer. 
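+ + Example (an illustrative sketch only; the `triton_cpu.` dialect prefix and the + exact pointer/memref types are assumed here and may differ from the actual + assembly produced by these ops): + + %base = triton_cpu.extract_memref %blk : !tt.ptr<tensor<16x16xf32>> -> memref<?x?xf32, strided<[?, ?]>> + %x, %y = triton_cpu.extract_indices %blk : !tt.ptr<tensor<16x16xf32>> -> index, index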
+ }]; + + let arguments = (ins TT_TensorPtr:$src); + + let results = (outs Variadic:$result); + + let builders = [ + OpBuilder<(ins "Value":$src)> + ]; + + let hasCanonicalizer = 1; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} #endif diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 5c3aa2c1a827..83db4ae41607 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,4 +1,4 @@ -add_subdirectory(TritonToTritonCPU) +#add_subdirectory(TritonToTritonCPU) add_subdirectory(TritonToTritonGPU) -add_subdirectory(TritonCPUToLLVM) +#add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonGPUToLLVM) diff --git a/lib/Dialect/TritonCPU/IR/Dialect.cpp b/lib/Dialect/TritonCPU/IR/Dialect.cpp index e28a65358dca..e5eb53caf686 100644 --- a/lib/Dialect/TritonCPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonCPU/IR/Dialect.cpp @@ -2,16 +2,19 @@ #include +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.cpp.inc" #include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/TypeSwitch.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.cpp.inc" + using namespace mlir; +using namespace mlir::triton; using namespace mlir::triton::cpu; //===----------------------------------------------------------------------===// @@ -20,6 +23,35 @@ using namespace mlir::triton::cpu; #define GET_ATTRDEF_CLASSES #include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.cpp.inc" +void ExtractMemRefOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) {} + +void ExtractIndicesOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) {} + +/// Parse an attribute registered to this dialect. +::mlir::Attribute +TritonCPUDialect::parseAttribute(::mlir::DialectAsmParser &parser, + ::mlir::Type type) const { + llvm_unreachable("parse stub called"); +} + +/// Print an attribute registered to this dialect. 
+void TritonCPUDialect::printAttribute(::mlir::Attribute attr, + ::mlir::DialectAsmPrinter &os) const { + llvm_unreachable("print stub called"); +} + +void ExtractIndicesOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state, Value src) { + assert(triton::isTensorPointerType(src.getType()) && + "Unexpected source type"); + auto tensorTy = dyn_cast( + dyn_cast(src.getType()).getPointeeType()); + SmallVector resTypes(tensorTy.getRank(), builder.getIndexType()); + build(builder, state, resTypes, src); +} + void TritonCPUDialect::initialize() { registerTypes(); @@ -34,6 +66,9 @@ void TritonCPUDialect::initialize() { >(); } +#define GET_OP_CLASSES +#include "triton/Dialect/TritonCPU/IR/Ops.cpp.inc" + // verify TritonCPU ops LogicalResult TritonCPUDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { diff --git a/python/src/llvm.cc b/python/src/llvm.cc index 1b061d599772..ae5798ed1fb8 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -3,6 +3,8 @@ #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/SmallVector.h" +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/CodeGen/CommandFlags.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" @@ -18,6 +20,7 @@ #include "llvm/Support/CodeGen.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" +#include "llvm/TargetParser/Host.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/InstCombine/InstCombine.h" #include @@ -278,6 +281,70 @@ void init_triton_llvm(py::module &&m) { }, py::arg("mod"), py::arg("opt"), py::arg("triple") = ""); + m.def("set_host_target", [](llvm::Module *mod) { + mod->setTargetTriple(llvm::sys::getDefaultTargetTriple()); + std::string error; + auto target = + llvm::TargetRegistry::lookupTarget(mod->getTargetTriple(), error); + std::unique_ptr machine{target->createTargetMachine( + mod->getTargetTriple(), llvm::sys::getHostCPUName(), "", {}, + llvm::Reloc::PIC_)}; + mod->setDataLayout(machine->createDataLayout()); + }); + + m.def( + "translate_to_host_asm", + [](std::string llvmIR) -> py::object { + std::string res; + { + // when allow_threads goes out of scope, gil will be released + py::gil_scoped_release allow_threads; + // create LLVM module from C++ + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + res = translateLLVMIRToASM( + *module, llvm::sys::getDefaultTargetTriple(), + llvm::sys::getHostCPUName().str(), "", {}, false, false); + } + return py::str(res); + }, + ret::take_ownership); + + m.def( + "translate_to_bc", + [](const std::string llvmIR) -> py::object { + py::gil_scoped_release allow_threads; + // create LLVM module + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + // Write bitcode to a buffer. 
+ llvm::SmallVector buf; + llvm::BitcodeWriter writer(buf); + writer.writeModule(*module); + writer.writeStrtab(); + std::string bitcode(buf.begin(), buf.end()); + return py::bytes(bitcode); + }, + ret::take_ownership); + m.def( "translate_to_asm", [](std::string llvmIR, std::string triple, std::string proc, diff --git a/python/src/passes.cc b/python/src/passes.cc index c112816e02fb..df7d9faa9052 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -43,8 +43,8 @@ void init_triton_passes_ttir(py::module &&m) { ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir", createConvertTritonToTritonGPUPass, const std::string &, int, int, int); - ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir", - createConvertTritonToTritonCPUPass); + // ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir", + // createConvertTritonToTritonCPUPass); } void init_triton_passes_ttgpuir(py::module &&m) { @@ -75,6 +75,7 @@ void init_triton_passes_convert(py::module &&m) { ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); + ADD_PASS_WRAPPER_0("add_math_to_llvmir", createConvertMathToLLVMPass); } void init_triton_passes_llvmir(py::module &&m) { diff --git a/python/test/unit/language/conftest.py b/python/test/unit/language/conftest.py index 091f9ea41e7f..44615b8b883b 100644 --- a/python/test/unit/language/conftest.py +++ b/python/test/unit/language/conftest.py @@ -3,3 +3,4 @@ def pytest_configure(config): config.addinivalue_line("markers", "interpreter: indicate whether interpreter supports the test") + config.addinivalue_line("markers", "cpu: indicate whether test is supported on cpu") diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 2fa30c83a00e..b9bab8b3e21f 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -234,6 +234,7 @@ def filter_layouts(layouts): return layouts +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"]) def test_empty_kernel(dtype_x, device): @@ -409,6 +410,7 @@ def test_dtype_codegen(): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) @@ -449,6 +451,7 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): test_broadcast=(op != "%")) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]]) def test_addptr(dtype, order, device): @@ -475,6 +478,7 @@ def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr): np.testing.assert_allclose(y, to_numpy(y_tri)) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y", [ # (dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes @@ -489,6 +493,7 @@ def test_floordiv(dtype_x, dtype_y, num_ctas, device): _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) +@pytest.mark.cpu def test_unsigned_name_mangling(device): # Test that uint32 and int32 are mangled differently by the compiler SIZE = 128 @@ -525,6 +530,7 @@ def kernel(O1, O2, X, Y, SIZE: tl.constexpr): # test bitwise ops # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) @@ -549,6 +555,7 @@ def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): 
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) @@ -574,6 +581,7 @@ def test_shift_op(dtype_x, dtype_y, op, num_ctas, device): ops = ['==', '!=', '>', '<', '>=', '<='] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize( "dtype_x, dtype_y, op, mode_x, mode_y", @@ -598,6 +606,7 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): # --------------- # test broadcast # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", dtypes_with_bfloat16) def test_broadcast(dtype, device): @@ -632,6 +641,7 @@ def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.con # ---------- +@pytest.mark.cpu @pytest.mark.interpreter def test_slice(device): @@ -663,6 +673,7 @@ def slice_kernel(XBLOCK: tl.constexpr): # ------------------ +@pytest.mark.cpu @pytest.mark.interpreter def test_invalid_slice(device): dst = torch.empty(128, device=device) @@ -678,6 +689,7 @@ def _kernel(dst): # ---------------- # test expand_dims # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_expand_dims(device): @@ -726,6 +738,7 @@ def expand_dims_kernel(dummy, N: tl.constexpr): expand_dims_kernel[(1, )](dummy_tensor, N) +@pytest.mark.cpu @pytest.mark.interpreter def test_expand_dims_error_cases(device): @@ -789,6 +802,7 @@ def duplicate_dim2(dummy, N: tl.constexpr): # ---------------------------- # test invalid program id axis # ---------------------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_invalid_pid_axis(device): dst = torch.empty(128, device=device) @@ -805,6 +819,7 @@ def _kernel(dst): # --------------- # test where # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -857,6 +872,7 @@ def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl. 
assert (z == to_numpy(z_tri)).all() +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_where_broadcast(num_ctas, device): @@ -901,6 +917,7 @@ def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, expr", [(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + [(dtype_x, ' ~x') @@ -915,6 +932,7 @@ def test_unary_op(dtype_x, expr, num_ctas, device): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, expr, x", [(dtype_x, expr, x) @@ -925,6 +943,7 @@ def test_math_op(dtype_x, expr, x, device): _test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) def test_math_erf_op(dtype, device): @@ -946,6 +965,7 @@ def kernel(Z, X, SIZE: tl.constexpr): torch.testing.assert_close(z_tri, z_ref) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) def test_math_fma_op(dtype, device): @@ -971,6 +991,7 @@ def kernel(Z, X, Y, W, SIZE: tl.constexpr): torch.testing.assert_close(z_tri, z_ref) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("expr", ["tl.math.fdiv(x, y)", "tl.math.div_rn(x, y)"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -983,6 +1004,7 @@ def test_math_divide_op(expr, num_ctas, device): # ------------- # test precise math # ------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("expr_prec, expr_ref", [('tl.math.sqrt_rn(x)', 'tl.math.sqrt(x.to(tl.float64)).to(tl.float32)'), @@ -1023,6 +1045,7 @@ def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) def test_abs(dtype_x, device): @@ -1062,6 +1085,7 @@ def abs_kernel(X, Z, SIZE: tl.constexpr): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_shapes_as_params(device): @@ -1128,6 +1152,7 @@ def make_ptr_str(name, shape): return f"{name} + {' + '.join(offsets)}" +@pytest.mark.cpu # TODO: handle `%4 = triton_gpu.convert_layout %3 : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`` @pytest.mark.parametrize("expr, dtype_str", [(f'x[{s}]', d) for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] @@ -1197,6 +1222,7 @@ def tuples_fn(a, b): a * b +@pytest.mark.cpu @pytest.mark.interpreter def test_tuples(device): @@ -1289,6 +1315,7 @@ def noinline_multi_values_fn(x, y, Z): tl.store(Z, z) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"]) def test_noinline(mode, device): @@ -1535,6 +1562,7 @@ def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", [(dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes] + [ @@ -1664,6 +1692,7 @@ def kernel(X, Y, Z, N: tl.constexpr): assert z.unique().size(0) == z.size(0) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", list(torch_dtypes)) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -1688,6 +1717,7 @@ def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): 
assert torch.all(output == ref) +@pytest.mark.cpu def test_load_store_same_ptr(device): @triton.jit() @@ -1916,6 +1946,7 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None): return output +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("in_dtype", [torch.float16, torch.bfloat16]) def test_convert_float16_to_float32(in_dtype, device): @@ -3427,6 +3458,7 @@ def kernel(Z, X, Y, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BM: tl.co assert re.search(r"triton_gpu.async_wait %.* {num = 2 : i32}", h.asm["ttgir"]) is not None +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16']) @pytest.mark.parametrize("shape", [(), (1, ), (128, )]) @@ -3466,6 +3498,7 @@ def kernel_dynamic(out, val, dtype: tl.constexpr): assert torch.all(out_dynamic == 2) +@pytest.mark.cpu @pytest.mark.parametrize("literal, dtype_str", [(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"), ('float("inf")', "f32"), ('float("-inf")', "f32"), ('float("nan")', "f32"), ('float("-nan")', "f32"), (0., "f32"), (5, "i32"), (2**40, "i64")]) @@ -3490,6 +3523,7 @@ def pass_const(a, b, choose_b): return a +@pytest.mark.cpu @pytest.mark.parametrize("choose_const", [True, False]) @pytest.mark.parametrize("constexpr", [True, False]) @pytest.mark.parametrize("mode", ["direct", "call", "ternary", "if"]) @@ -3569,6 +3603,7 @@ def _kernel(out): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("start", [0, 1, 7, 16]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -3843,6 +3878,7 @@ def _impl(value=10): return value +@pytest.mark.cpu @pytest.mark.interpreter def test_default(device): value = 5 @@ -3868,6 +3904,7 @@ def _kernel(ret0, ret1, value=3): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_noop(device): @@ -3895,6 +3932,7 @@ def kernel(x): kernel[(1, )](x) +@pytest.mark.cpu @pytest.mark.parametrize("value, value_type", [(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), (2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')]) @@ -3918,6 +3956,7 @@ def kernel(VALUE, X): # -------------------- +@pytest.mark.cpu @pytest.mark.parametrize("value, overflow", [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]) def test_value_specialization_overflow(value: int, overflow: bool, device) -> None: @@ -3939,6 +3978,7 @@ def kernel(VALUE, X): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>', '<<', '>>', '&', '^', '|']) @pytest.mark.parametrize("is_lhs_constexpr", [False, True]) @@ -3976,6 +4016,7 @@ def kernel(Z, X, Y): np.testing.assert_allclose(z, to_numpy(z_tri), rtol=1e-3) +@pytest.mark.cpu @pytest.mark.interpreter def test_constexpr_shape(device): @@ -3989,6 +4030,7 @@ def kernel(X): np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) +@pytest.mark.cpu @pytest.mark.interpreter def test_constexpr_scalar_shape(device): @@ -4006,6 +4048,7 @@ def kernel(X, s): reshape_list = [((64, ), (8, 8)), ((2, 32), (16, 4)), ((512, ), (2, 2, 2, 2, 2, 2, 2, 2, 2)), ((64, 32), (16, 8, 16))] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("formats", reshape_list) def test_reshape(formats, device): @@ -4033,6 +4076,7 @@ def generate_kernel(shape_x, shape_z): np.testing.assert_equal(z, to_numpy(z_tri)) +@pytest.mark.cpu def test_reshape_err(device): @triton.jit @@ -4107,6 +4151,7 @@ def 
vecmul_kernel(ptr, n_elements, rep, type: tl.constexpr): tl.store(ptr + offsets, vec, mask=mask) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("type", ["inline", "noinline"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -4138,6 +4183,7 @@ def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): # ------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("if_type", [ "if", "if_and_dynamic", "if_exp_static", "if_exp_dynamic", "if_exp_dynamic_constexpr", "if_exp_dynamic_void", @@ -4198,6 +4244,7 @@ def _kernel(dst): _kernel[(1, )](dst=dst, num_warps=4) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("func_str", ['sqrt', 'rsqrt', 'exp', 'exp2', 'log', 'log2', 'sin', 'cos']) def test_unary_math(func_str, device): @@ -4439,6 +4486,7 @@ def kernel(A, B, C, D, BLOCK: tl.constexpr): # ----------------------- +@pytest.mark.cpu @pytest.mark.parametrize("lo, hi, iv", [(2**35, 2**35 + 20, 1), (2**35, 2**35 + 20, 2), (2**35, 2**35 + 20, 3), (15, -16, -1), (15, -16, -2), (15, -16, -3), (-18, -22, -1), (22, 18, -1)]) def test_for_iv(lo, hi, iv, device): @@ -4458,6 +4506,7 @@ def kernel(Out, lo, hi, iv: tl.constexpr): assert out[0] == sum(range(lo, hi, iv)) +@pytest.mark.cpu @pytest.mark.interpreter def test_if_else(device): @@ -4483,6 +4532,7 @@ def kernel(Cond, TrueVal, FalseVal, Out): assert to_numpy(out)[0] == false_val[0] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("mode", ["dynamic", "static"]) def test_if_return(mode, device): @@ -4542,6 +4592,7 @@ def add_fn_static_cond(x, cond: tl.constexpr): return x + 1 +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize( "call_type", @@ -4611,6 +4662,7 @@ def kernel(Out, call_type: tl.constexpr): assert to_numpy(out)[0] == 1 +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("_cond1", [True, False]) @pytest.mark.parametrize("_cond2", [True, False]) @@ -4653,6 +4705,7 @@ def kernel(Cond1, Cond2, Cond3, Val1, Val2, Val3, Out): assert out[0] == targets[(_cond1, _cond2, _cond3)] +@pytest.mark.cpu @pytest.mark.interpreter def test_while(device): @@ -4681,6 +4734,7 @@ def kernel(InitI, Bound, CutOff, OutI, OutInitI, OutJ): assert out_j[0] == bound[0] +@pytest.mark.cpu @pytest.mark.interpreter def test_nested_while(device): @@ -4979,6 +5033,7 @@ def do_test(src_layout, dst_layout): do_test(mma_pair[1], mma_pair[0]) +@pytest.mark.cpu @pytest.mark.interpreter def test_load_scalar_with_mask(device): @@ -5267,6 +5322,7 @@ def kernel(x_ptr, limit_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): # ----------------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_static_range(device): diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 82e72f6395ed..a12b1d235b7c 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -645,13 +645,6 @@ def run(self, *args, grid, warmup, **kwargs): sigvals = sig_and_spec[:len(sigkeys)] signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} - # The CPU launcher will provide the grid ids directly to the kernel. - # Note that this design is interim and subject to change. 
- if target.backend == 'cpu': - signature["__grid0"] = 'i32' - signature["__grid1"] = 'i32' - signature["__grid2"] = 'i32' - configs = (self._get_config(*bound_vals), ) constants = { p.name: v diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index 683889547b0a..d8be71ad6c11 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -1,3 +1,8 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +add_subdirectory(include) +add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM) + target_link_libraries(TritonCPU PUBLIC MLIRMathToLibm) endif() diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 3c293cdf468f..357b5f448fe9 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -4,7 +4,7 @@ import re from dataclasses import dataclass -from typing import Any +from typing import Any, Tuple from triton._C.libtriton import cpu, ir, llvm, passes from triton.backends.compiler import BaseBackend, GPUTarget @@ -20,6 +20,8 @@ class CPUOptions: cluster_dims: tuple = (1, 1, 1) extern_libs: dict = None debug: bool = False + allowed_dot_input_precisions: Tuple[str] = ("ieee",) + allow_fp8e4nv: bool = False # TODO: We may introduce CPU-specific options like # of cores. @@ -40,7 +42,7 @@ def supports_target(target: GPUTarget): def __init__(self, target: tuple) -> None: super().__init__(target) - self.binary_ext = "exe" + self.binary_ext = "bc" def parse_options(self, opts) -> Any: args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} @@ -62,7 +64,6 @@ def make_ttir(mod, metadata, opt): pm = ir.pass_manager(mod.context) pm.enable_debug() passes.common.add_inliner(pm) - passes.ttir.add_rewrite_tensor_pointer(pm) passes.ttir.add_combine(pm) passes.common.add_canonicalizer(pm) passes.ttir.add_reorder_broadcast(pm) @@ -77,33 +78,34 @@ def make_ttcir(mod, metadata, opt): # TTIR -> TTCIR pm = ir.pass_manager(mod.context) pm.enable_debug() - passes.ttir.add_convert_to_ttcpuir(pm) - - # - # TODO: - # - + cpu.passes.ttcpuir.add_triton_to_triton_cpu_pipeline(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) + passes.common.add_canonicalizer(pm) pm.run(mod) + metadata["cluster_dims"] = (opt.cluster_dims[0], opt.cluster_dims[1], opt.cluster_dims[2]) return mod @staticmethod def make_llir(src, metadata, options): + # warp-specialization mutates num_warps + num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") + if num_warp_groups is not None: + metadata["num_warps"] *= num_warp_groups + metadata["threads_per_warp"] = 1 mod = src # TritonCPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() passes.convert.add_scf_to_cf(pm) passes.convert.add_index_to_llvmir(pm) - - cpu.passes.ttcpuir.add_to_llvmir(pm) - passes.common.add_canonicalizer(pm) - passes.common.add_cse(pm) - - passes.convert.add_scf_to_cf(pm) - passes.convert.add_cf_to_llvmir(pm) + cpu.passes.ttcpuir.add_triton_cpu_to_llvmir_pipeline(pm) + passes.convert.add_math_to_llvmir(pm) + cpu.passes.ttcpuir.add_math_to_libm(pm) + cpu.passes.ttcpuir.add_vector_to_llvmir(pm) + cpu.passes.ttcpuir.add_memref_to_llvmir(pm) passes.convert.add_arith_to_llvmir(pm) + cpu.passes.ttcpuir.add_func_to_llvmir(pm) passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) @@ -111,45 +113,40 
@@ def make_llir(src, metadata, options): passes.llvmir.add_di_scope(pm) pm.run(mod) + # Find kernel fn + kernel_names = cpu.find_kernel_names(mod) + assert len(kernel_names) == 1, f"expected exactly 1 kernel in a module, got {kernel_names}" + # LLVM-IR (MLIR) -> LLVM-IR (LLVM) llvm.init_targets() context = llvm.context() llvm_mod = llvm.to_module(mod, context) - - # TODO: - if not llvm_mod: - metadata["shared"] = 0 - return src - - if options.extern_libs: - paths = [path for (name, path) in options.extern_libs] - llvm.link_extern_libs(llvm_mod, paths) + llvm.set_host_target(llvm_mod) + #if options.extern_libs: + # paths = [path for (name, path) in options.extern_libs] + # llvm.link_extern_libs(llvm_mod, paths) llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) - - # CPU doesn't have SMEM, but just to make it work for now. + # Get some metadata metadata["shared"] = 0 - - # Cleanup + metadata["name"] = kernel_names[0] ret = str(llvm_mod) del llvm_mod del context return ret @staticmethod - def make_exe(src, metadata, options): - # Just a quick hack while developing the backend. - names = re.findall(r"\s+define void @([a-zA-Z_][a-zA-Z0-9_]*)\(", str(src)) - assert len(names) == 1 - metadata["name"] = names[0] - - # TODO: Call llc to create an executable. - return src + def make_bc(src, metadata, options): + if os.environ.get("TRITON_CPU_ASM_DUMP", "0") == "1": + print("********** Module ASM **********") + print(llvm.translate_to_host_asm(src)) + ret = llvm.translate_to_bc(src) + return ret def add_stages(self, stages, options): stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) stages["ttcir"] = lambda src, metadata: self.make_ttcir(src, metadata, options) stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) - stages["exe"] = lambda src, metadata: self.make_exe(src, metadata, options) + stages["bc"] = lambda src, metadata: self.make_bc(src, metadata, options) @functools.lru_cache() def hash(self): diff --git a/third_party/cpu/backend/driver.cpp b/third_party/cpu/backend/driver.cpp new file mode 100644 index 000000000000..babff3dfdebe --- /dev/null +++ b/third_party/cpu/backend/driver.cpp @@ -0,0 +1,224 @@ +//===- driver.cpp ---------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/TargetSelect.h" + +#include +#include +#include +#include +#include +#include +#include + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#include + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + + return Py_BuildValue("{s:i}", "max_shared_mem", 0); +} + +bool getBoolEnv(const std::string &env) { + const char *s = std::getenv(env.c_str()); + std::string str(s ? s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + return (str == "on" || str == "true" || str == "1"); +} + +llvm::orc::ThreadSafeContext &getThreadSafeContext() { + static llvm::orc::ThreadSafeContext tsc; + static std::once_flag init_flag; + std::call_once(init_flag, []() { + auto context = std::make_unique(); + tsc = llvm::orc::ThreadSafeContext(std::move(context)); + }); + return tsc; +} + +std::string llvmErrToString(const llvm::Error &err) { + std::string res; + llvm::raw_string_ostream os(res); + os << err; + return res; +}; + +struct CompiledKernel { + std::unique_ptr execution_session; + std::unique_ptr data_layout; + std::unique_ptr mangle; + std::unique_ptr object_layer; + std::unique_ptr compiler_layer; + llvm::orc::JITDylib *dylib = nullptr; + + CompiledKernel() = default; + CompiledKernel(CompiledKernel &&) = default; + + ~CompiledKernel() { + if (execution_session) + llvm::cantFail(execution_session->endSession()); + } +}; + +std::vector> compiled_kernels; + +static PyObject *loadBitcode(PyObject *self, PyObject *args) { + const char *name; + int shared; + PyObject *py_bytes; + int devId; + + if (!PyArg_ParseTuple(args, "sSii", &name, &py_bytes, &shared, &devId)) { + std::cerr << "loadBitcode arg parse failed" << std::endl; + return NULL; + } + + std::string kernel_name = name; + size_t binary_size = PyBytes_Size(py_bytes); + const char *binary_ptr = PyBytes_AsString(py_bytes); + + llvm::LLVMContext context; + auto buf = llvm::MemoryBuffer::getMemBuffer( + llvm::StringRef(binary_ptr, binary_size)); + auto mod = llvm::parseBitcodeFile(*buf, context); + if (!mod) { + std::cerr << "Failed to parse LLVM bitcode module" << std::endl; + return NULL; + } + + if (getBoolEnv("MLIR_ENABLE_DUMP")) { + llvm::errs() << "********** Loaded Module (kernel_name=" << name + << ") **********\n" + << **mod << "\n"; + } + + auto init_err = llvm::InitializeNativeTarget(); + if (init_err) { + std::cerr << "Failed to initialize native target." 
<< std::endl; + return NULL; + } + + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + + auto self_epc = + llvm::cantFail(llvm::orc::SelfExecutorProcessControl::Create()); + + auto detect_host_res = llvm::orc::JITTargetMachineBuilder::detectHost(); + if (!detect_host_res) { + std::cerr << "Failed to initialize JITTargetMachineBuilder: " + << llvmErrToString(detect_host_res.takeError()); + return NULL; + } + llvm::orc::JITTargetMachineBuilder tmb = std::move(*detect_host_res); + + auto data_layout_res = tmb.getDefaultDataLayoutForTarget(); + if (!data_layout_res) { + std::cerr << "Failed to initialize data layout: " + << llvmErrToString(data_layout_res.takeError()); + return NULL; + } + + CompiledKernel kernel; + kernel.execution_session = + std::make_unique(std::move(self_epc)); + kernel.data_layout = + std::make_unique(std::move(*data_layout_res)); + kernel.mangle = std::make_unique( + *kernel.execution_session, *kernel.data_layout); + kernel.object_layer = std::make_unique( + *kernel.execution_session, + []() { return std::make_unique(); }); + kernel.compiler_layer = std::make_unique( + *kernel.execution_session, *kernel.object_layer, + std::make_unique(std::move(tmb))); + + auto dylib_res = kernel.execution_session->createJITDylib("
"); + if (!dylib_res) { + std::cerr << "Failed to create initialize JITDylib: " + << llvmErrToString(dylib_res.takeError()); + return NULL; + } + + kernel.dylib = &(*dylib_res); + kernel.dylib->addGenerator(llvm::cantFail( + llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( + kernel.data_layout->getGlobalPrefix()))); + + // Compile module. + (**mod).setDataLayout(*kernel.data_layout); + llvm::orc::ThreadSafeModule tsm(std::move(*mod), getThreadSafeContext()); + auto err = kernel.compiler_layer->add(*kernel.dylib, std::move(tsm)); + if (err) { + std::cerr << "Cannot add LLVM module: " << llvmErrToString(err); + return NULL; + } + + // Find kernel function pointer. + auto lookup_res = + kernel.execution_session->lookup({kernel.dylib}, (*kernel.mangle)(name)); + if (!lookup_res) { + std::cerr << "Failed to find function " << std::string(name) + << "\nError: " << llvmErrToString(lookup_res.takeError()); + return NULL; + } + uint64_t fn_ptr = lookup_res->getAddress().getValue(); + + compiled_kernels.push_back( + std::make_unique(std::move(kernel))); + auto *kernel_ptr = compiled_kernels.back().get(); + + return Py_BuildValue("(KKii)", reinterpret_cast(kernel_ptr), + reinterpret_cast(fn_ptr), 0, 0); +} + +static PyObject *initContext(PyObject *self, PyObject *args) { + return Py_BuildValue("(K)", (uint64_t)0); +} + +static PyObject *initDevices(PyObject *self, PyObject *args) { + return Py_BuildValue("(i)", 1); +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBitcode, METH_VARARGS, + "Load provided SPV into ZE driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given device"}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cpu_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_cpu_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + PyModule_AddFunctions(m, ModuleMethods); + return m; +} diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 3f3816a99b9f..743684d2640f 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -1,5 +1,100 @@ +import os +import hashlib +import tempfile +from pathlib import Path +from triton.runtime.build import _build +from triton.runtime.cache import get_cache_manager +from triton.backends.driver import DriverBase from triton.backends.compiler import GPUTarget -from triton.backends.driver import CPUDriverBase + +dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local") +llvm_root = os.getenv("LLVM_PATH", default="~/.triton/llvm") +llvm_root = os.path.expanduser(llvm_root) +llvm_dirs = os.listdir(llvm_root) +if len(llvm_dirs) == 1: + llvm_root = os.path.join(llvm_root, llvm_dirs[0]) +include_dir = [ + os.path.join(dirname, "include"), + os.path.join(llvm_root, "include"), +] +library_dir = [os.path.join(dirname, "lib"), os.path.join(llvm_root, "lib")] +libraries = [ + "LLVMOrcJIT", + "LLVMPasses", + "LLVMX86CodeGen", + "LLVMX86AsmParser", + "LLVMX86Desc", + "LLVMX86Info", + "LLVMGlobalISel", + "LLVMSelectionDAG", + "LLVMHipStdPar", + "LLVMCoroutines", + "LLVMipo", + "LLVMFrontendOpenMP", + "LLVMInstrumentation", + "LLVMAsmPrinter", + "LLVMCodeGen", + "LLVMObjCARCOpts", + "LLVMLinker", + "LLVMVectorize", + "LLVMScalarOpts", + "LLVMInstCombine", + "LLVMFrontendOffloading", + "LLVMExecutionEngine", + "LLVMAggressiveInstCombine", + "LLVMTransformUtils", + "LLVMTarget", + 
"LLVMRuntimeDyld", + "LLVMJITLink", + "LLVMIRPrinter", + "LLVMBitWriter", + "LLVMAnalysis", + "LLVMProfileData", + "LLVMSymbolize", + "LLVMDebugInfoDWARF", + "LLVMObject", + "LLVMTextAPI", + "LLVMMCParser", + "LLVMMCDisassembler", + "LLVMMC", + "LLVMIRReader", + "LLVMCFGuard", + "LLVMBitReader", + "LLVMAsmParser", + "LLVMCore", + "LLVMBinaryFormat", + "LLVMOrcTargetProcess", + "LLVMTargetParser", + "LLVMRemarks", + "LLVMOrcShared", + "LLVMOption", + "LLVMDebugInfoCodeView", + "LLVMCodeGenTypes", + "LLVMBitstreamReader", + "LLVMSupport", + "LLVMDemangle", + "stdc++", +] + + +def compile_module_from_src(src, name): + key = hashlib.md5(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{name}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.cpp") + with open(src_path, "w") as f: + f.write(src) + so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}.so", binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + # ------------------------ # Utils @@ -15,22 +110,12 @@ def __new__(cls): def __init__(self): pass + dirname = os.path.dirname(os.path.realpath(__file__)) + mod = compile_module_from_src(Path(os.path.join(dirname, "driver.cpp")).read_text(), "cpu_utils") + self.load_binary = mod.load_binary - @staticmethod - def get_device_properties(device): - # This is just dummy for now. We will need to implement driver.c. - return { - "max_shared_mem": 0, - "multiprocessor_count": 0, - "sm_clock_rate": 0, - "mem_clock_rate": 0, - "mem_bus_width": 0, - } - - @staticmethod - def load_binary(name, kernel_asm, shared, device): - # This is just dummy for now. We will need to implement driver.c. - return (None, kernel_asm, 0, 0) + def get_device_properties(self, *args): + return {"max_shared_mem": 0} # ------------------------ @@ -38,27 +123,228 @@ def load_binary(name, kernel_asm, shared, device): # ------------------------ +def ty_to_cpp(ty): + if ty[0] == '*': + return "void*" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + def make_launcher(constants, signature, ids): - pass + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors. 
+ arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + arg_types = (', '.join(f"{ty_to_cpp(ty)}" for i, ty in signature.items()) + ", " if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t" + + def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return ty_to_cpp(ty) + + def format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty] + + args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + format = "iiiOKOOOO" + args_format + args_list = ', '.join(f"arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + arg_ptrs_list = ', '.join(f"&arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + + # generate glue code + src = f""" +#include +#include +#include +#include + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#include +#include + +using kernel_ptr_t = void(*)({arg_types}); + +typedef struct _DevicePtrInfo {{ + void* dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = (void*) PyLong_AsLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = (void*) PyLong_AsLongLong(ret); + if(!ptr_info.dev_ptr) {{ + return ptr_info; + }} + Py_DECREF(ret); // Thanks ChatGPT! 
+ return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; + return ptr_info; +}} + +static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + // TODO: add OMP pragmas to run in parallel + for (uint32_t z = 0; z < gridZ; ++z) {{ + for (uint32_t y = 0; y < gridY; ++y) {{ + for (uint32_t x = 0; x < gridX; ++x) {{ + (*kernel_ptr)({args_list + ', ' if len(arg_decls) > 0 else ''} x, y, z); + }} + }} + }} +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + + + int gridX, gridY, gridZ; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + PyObject *py_obj_stream; + void* pKrnl; + + {' '.join([f"{_extracted_type(ty)} arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &py_obj_stream, &pKrnl, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook {', ' + arg_ptrs_list if len(signature) > 0 else ''})) {{ + return NULL; + }} + + void *pStream = PyLong_AsVoidPtr(py_obj_stream); + kernel_ptr_t kernel_ptr = reinterpret_cast(pKrnl); + + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + run_omp_kernels(gridX, gridY, gridZ, kernel_ptr {',' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}); + + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + if (PyErr_Occurred()) {{ + return NULL; + }} + + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_cpu_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_cpu_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + return src class CPULauncher(object): def __init__(self, src, metadata): - # TODO: - self.launch = lambda *args, **kwargs: None + ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} + constants = src.constants if hasattr(src, "constants") else dict() + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} + src = make_launcher(constants, signature, ids) + mod = compile_module_from_src(src, "__triton_cpu_launcher") + self.launch = mod.launch def __call__(self, *args, **kwargs): self.launch(*args, **kwargs) -class CPUDriver(CPUDriverBase): +class CPUDriver(DriverBase): def 
__init__(self): self.utils = CPUUtils() self.launcher_cls = CPULauncher super().__init__() + def get_current_device(self): + return 0 + + def get_current_stream(self, device): + return 0 + def get_current_target(self): # Capability and warp size are zeros for CPU. # TODO: GPUTarget naming isn't obviously good. diff --git a/third_party/cpu/include/CMakeLists.txt b/third_party/cpu/include/CMakeLists.txt new file mode 100644 index 000000000000..fc9a19e52b0d --- /dev/null +++ b/third_party/cpu/include/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(TritonCPUToLLVM) +add_subdirectory(TritonToTritonCPU) diff --git a/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000000..64b36523d35d --- /dev/null +++ b/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUToLLVM) +add_public_tablegen_target(TritonCPUConversionPassIncGen) diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h new file mode 100644 index 000000000000..74f74b00870c --- /dev/null +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h @@ -0,0 +1,36 @@ +#ifndef TRITONCPU_CONVERSION_TRITONCPUTOLLVM_PASSES_H +#define TRITONCPU_CONVERSION_TRITONCPUTOLLVM_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { +namespace cpu { + +#define GEN_PASS_DECL +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" + +std::unique_ptr> createFuncOpToLLVMPass(); +std::unique_ptr> createMemoryOpToLLVMPass(); +std::unique_ptr> createGetProgramIdOpToLLVMPass(); + +void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm); +void registerTritonCPUToLLVMPipeline(); + +#define GEN_PASS_REGISTRATION +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" + +} // namespace cpu +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.td b/third_party/cpu/include/TritonCPUToLLVM/Passes.td new file mode 100644 index 000000000000..c75b58b572f1 --- /dev/null +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.td @@ -0,0 +1,46 @@ +#ifndef TRITONCPU_CONVERSION_PASSES +#define TRITONCPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def FuncOpToLLVM : Pass<"triton-cpu-func-op-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert FuncOp to LLVM for CPU."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createFuncOpToLLVMPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::scf::SCFDialect", + "mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def MemoryOpToLLVM : Pass<"triton-cpu-memory-op-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert Triton memory operations to LLVM for CPU."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createMemoryOpToLLVMPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::scf::SCFDialect", + "mlir::memref::MemRefDialect", + "mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def GetProgramIdOpToLLVM : Pass<"triton-cpu-get-program-id-op-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert Triton GetProgramId to LLVM 
for CPU."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createGetProgramIdOpToLLVMPass()"; + + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect"]; +} + +#endif diff --git a/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt new file mode 100644 index 000000000000..56e231273ed6 --- /dev/null +++ b/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonCPU) +add_public_tablegen_target(TritonToTritonCPUPassIncGen) diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h new file mode 100644 index 000000000000..745799039691 --- /dev/null +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -0,0 +1,38 @@ +#ifndef TRITONTOTRITONCPU_CONVERSION_PASSES_H +#define TRITONTOTRITONCPU_CONVERSION_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { +namespace cpu { + +#define GEN_PASS_DECL +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" + +std::unique_ptr> createConvertElementwiseOps(); +std::unique_ptr> createConvertMemoryOps(); +std::unique_ptr> createConvertPtrOps(); +std::unique_ptr> createConvertDotOp(); +std::unique_ptr> createConvertControlFlowOps(); + +void tritonToTritonCPUPipelineBuilder(OpPassManager &pm); +void registerTritonToTritonCPUPipeline(); + +#define GEN_PASS_REGISTRATION +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" + +} // namespace cpu +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td new file mode 100644 index 000000000000..5f52f3a2e31d --- /dev/null +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -0,0 +1,77 @@ +#ifndef TRITONTOTRITONCPU_CONVERSION_PASSES +#define TRITONTOTRITONCPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertMemoryOps : Pass<"triton-cpu-convert-memory-ops", "mlir::ModuleOp"> { + let summary = "Convert Triton memory ops."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertMemoryOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertElementwiseOps : Pass<"triton-cpu-convert-elementwise-ops", "mlir::ModuleOp"> { + let summary = "Convert elementwise ops."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertElementwiseOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertPtrOps : Pass<"triton-cpu-convert-ptr-ops", "mlir::ModuleOp"> { + let summary = "Convert Triton ops related to pointer arithmetics."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertPtrOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + 
+def ConvertDotOp : Pass<"triton-cpu-convert-dot-op", "mlir::ModuleOp"> { + let summary = "Convert Triton DotOp."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertDotOp()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertControlFlowOps : Pass<"triton-cpu-convert-control-flow-op", "mlir::ModuleOp"> { + let summary = "Convert Triton DotOp."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertControlFlowOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +#endif diff --git a/third_party/cpu/lib/CMakeLists.txt b/third_party/cpu/lib/CMakeLists.txt new file mode 100644 index 000000000000..fc9a19e52b0d --- /dev/null +++ b/third_party/cpu/lib/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(TritonCPUToLLVM) +add_subdirectory(TritonToTritonCPU) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000000..884c9352ef1b --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt @@ -0,0 +1,13 @@ +add_triton_library(TritonCPUToLLVM + FuncOpToLLVM.cpp + GetProgramIdOpToLLVM.cpp + MemoryOpToLLVM.cpp + Pipeline.cpp + TypeConverter.cpp + + DEPENDS + TritonCPUToLLVMConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRVectorToLLVMPass +) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp new file mode 100644 index 000000000000..4c5257fcff4c --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp @@ -0,0 +1,278 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_FUNCOPTOLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +namespace mlir { +FailureOr +convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter); +} + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +struct FuncOpConversion : public ConvertOpToLLVMPattern { + FuncOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + /// 
Only retain those attributes that are not constructed by + /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument + /// attributes. + static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, + SmallVectorImpl &result) { + + for (const auto &attr : op->getAttrs()) { + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == op.getFunctionTypeAttrName() || + attr.getName() == "std.varargs" || + (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) + continue; + result.push_back(attr); + } + } + + triton::FuncOp amendProgramIdArgs(triton::FuncOp funcOp, + ConversionPatternRewriter &rewriter) const { + // Push back a variable that indicates the current stack pointer of shared + // memory to the function arguments. + auto loc = funcOp.getLoc(); + auto ctx = funcOp->getContext(); + // 1. Modify the function type to add new arguments. + auto funcTy = funcOp.getFunctionType(); + auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); + amendedInputTy.push_back(i32_ty); + amendedInputTy.push_back(i32_ty); + amendedInputTy.push_back(i32_ty); + auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, + funcTy.getResults()); + // 2. Modify the argument attributes to add new arguments. + SmallVector amendedAttrs; + filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); + SmallVector amendedArgAttrs; + if (funcOp.getAllArgAttrs()) + amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedAttrs.push_back(rewriter.getNamedAttr( + funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); + // 3. Add a new arguments to the region + auto amendedFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); + auto ®ion = funcOp.getBody(); + region.addArgument(i32_ty, loc); + region.addArgument(i32_ty, loc); + region.addArgument(i32_ty, loc); + rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), + amendedFuncOp.end()); + return amendedFuncOp; + } + + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Prevent LLVM's inliner to inline this function + auto modifiedFuncOp = funcOp; + if (LLVM::isKernel(funcOp)) + modifiedFuncOp = amendProgramIdArgs(modifiedFuncOp, rewriter); + + LLVM::LLVMFuncOp newFuncOp = *mlir::convertFuncOpToLLVMFuncOp( + modifiedFuncOp, rewriter, *getTypeConverter()); + if (!newFuncOp) + return failure(); + + // required by AxisInfoAnalysis + if (LLVM::isKernel(funcOp)) + rewriter.eraseOp(modifiedFuncOp); + rewriter.eraseOp(funcOp); + return success(); + } +}; + +struct ReturnOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LLVM::ReturnOp newOp; + if (adaptor.getOperands().size() < 2) { + // Single or no return value. + newOp = + rewriter.create(op.getLoc(), adaptor.getOperands()); + } else { + // Pack the results into a struct. 
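+      // Multiple return values cannot be returned directly from an LLVM
+      // function, so they are packed into an LLVM struct: each operand is
+      // written into its slot via insertvalue and the struct is returned.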
+ auto funcOp = op->getParentOfType(); + auto packedResultsTy = this->getTypeConverter()->packFunctionResults( + funcOp.getResultTypes()); + Value packedResults = + rewriter.create(op.getLoc(), packedResultsTy); + auto loc = op.getLoc(); + for (auto it : llvm::enumerate(adaptor.getOperands())) { + packedResults = + insert_val(packedResultsTy, packedResults, it.value(), it.index()); + } + newOp = rewriter.create(op.getLoc(), packedResults); + } + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +// CallOpInterfaceLowering is adapted from +// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485 +struct CallOpConversion : public ConvertOpToLLVMPattern { + CallOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto promotedOperands = promoteOperands(callOp, adaptor, rewriter); + auto newCallOp = + convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter); + if (!newCallOp) + return failure(); + auto results = getCallOpResults(callOp, newCallOp, rewriter); + rewriter.replaceOp(callOp, results); + return success(); + } + +private: + SmallVector + promoteOperands(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = callOp.getLoc(); + auto caller = callOp->getParentOfType(); + auto promotedOperands = this->getTypeConverter()->promoteOperands( + callOp.getLoc(), /*opOperands=*/callOp->getOperands(), + adaptor.getOperands(), rewriter); + return promotedOperands; + } + + LLVM::CallOp + convertCallOpToLLVMCallOp(triton::CallOp callOp, + ArrayRef promotedOperands, + ConversionPatternRewriter &rewriter) const { + // Pack the result types into a struct. + Type packedResult = nullptr; + unsigned numResults = callOp.getNumResults(); + auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); + + if (numResults != 0) { + if (!(packedResult = + this->getTypeConverter()->packFunctionResults(resultTypes))) + return nullptr; + } + auto newCallOp = rewriter.create( + callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), + promotedOperands, callOp->getAttrs()); + return newCallOp; + } + + SmallVector + getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp, + ConversionPatternRewriter &rewriter) const { + auto numResults = callOp.getNumResults(); + SmallVector results; + if (numResults < 2) { + // If < 2 results, packing did not do anything and we can just return. + results.append(newCallOp.result_begin(), newCallOp.result_end()); + } else { + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. 
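+      // The packed struct is the single result of the new call; unpack it with
+      // one extractvalue per original result so users of the old tt.call keep
+      // seeing a plain list of values.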
+ results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + results.push_back(rewriter.create( + callOp.getLoc(), newCallOp->getResult(0), i)); + } + } + return results; + } +}; + +struct FuncOpToLLVM : public triton::impl::FuncOpToLLVMBase { + using FuncOpToLLVMBase::FuncOpToLLVMBase; + + FuncOpToLLVM() : FuncOpToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + // Lower tt.func + RewritePatternSet funcPatterns(context); + funcPatterns.add(typeConverter, + /*benefit=*/1); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + funcPatterns); + if (failed( + applyPartialConversion(mod, convTarget, std::move(funcPatterns)))) + return signalPassFailure(); + + // Lower tt.call, tt.return + int benefit = 10; + RewritePatternSet patterns(context); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createFuncOpToLLVMPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp new file mode 100644 index 000000000000..4c593f1ff7aa --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp @@ -0,0 +1,98 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_GETPROGRAMIDOPTOLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +// TODO: use enums to access struct fields. 
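+// On CPU there is no runtime providing program IDs. FuncOpToLLVM appends three
+// i32 arguments (the x, y, z grid coordinates) to every kernel signature, and
+// tt.get_program_id is lowered here to a read of the matching trailing
+// argument of the enclosing LLVM function.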
+struct GetProgramIdOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = op->getParentOfType(); + assert(funcOp && "expected LLVM::FuncOp as a parent of GetProgramIdOp"); + auto args = funcOp.getArguments(); + // Last three args are x, y, z program ids. + auto argIdx = args.size() - 3 + op.getAxisAsInt(); + assert(argIdx < args.size() && "out-of-bounds arg index"); + assert(args[argIdx].getType().isInteger(32) && "unexpected arg type"); + rewriter.replaceOp(op, args[argIdx]); + return success(); + } +}; + +struct GetProgramIdOpToLLVM + : public triton::impl::GetProgramIdOpToLLVMBase { + using GetProgramIdOpToLLVMBase::GetProgramIdOpToLLVMBase; + + GetProgramIdOpToLLVM() : GetProgramIdOpToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createGetProgramIdOpToLLVMPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp new file mode 100644 index 000000000000..594495c4ab9d --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp @@ -0,0 +1,277 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_MEMORYOPTOLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +// TODO: use enums to access struct fields. 
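+// A block (tensor) pointer is lowered to an LLVM struct holding the base
+// pointer plus three i64 arrays. As a reference sketch (rank N and the field
+// names are illustrative only):
+//
+//   struct TensorPtr {
+//     void    *base;        // field 0
+//     int64_t  offsets[N];  // field 1
+//     int64_t  shape[N];    // field 2
+//     int64_t  strides[N];  // field 3
+//   };
+//
+// ExtractMemRefOp repacks fields 0, 2 and 3 into a ranked memref descriptor
+// (aligned pointer, sizes, strides) with a zero offset, while ExtractIndicesOp
+// derives the i-th access index as offsets[i] * strides[i].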
+struct ExtractMemRefOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ExtractMemRefOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value tensorPtrStruct = rewriter.getRemappedValue(op.getSrc()); + auto memRefTy = cast(op.getType()); + auto rank = memRefTy.getRank(); + auto memRefStructTy = getTypeConverter()->convertType(op.getType()); + auto memRefStructFields = + cast(memRefStructTy).getBody(); + auto i64Ty = IntegerType::get(getContext(), 64); + + auto copyValue = [&](Value to, int64_t idxFrom, int64_t idxTo) { + auto valueTy = memRefStructFields[idxTo]; + Value val = rewriter.create( + loc, valueTy, tensorPtrStruct, idxFrom); + return rewriter.create(loc, memRefStructTy, to, val, + idxTo); + }; + + Value res = undef(memRefStructTy); + // Copy base. + res = copyValue(res, 0, 1); + // Use 0 offset. + res = rewriter.create(loc, memRefStructTy, res, + i64_val(0), 2); + // Copy shape. + res = copyValue(res, 2, 3); + // Copy strides. + res = copyValue(res, 3, 4); + + rewriter.replaceOp(op, res); + + return success(); + } +}; + +struct ExtractIndicesOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ExtractIndicesOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + Value tensorPtrStruct = rewriter.getRemappedValue(op.getSrc()); + auto rank = op.getNumResults(); + auto i64Ty = IntegerType::get(getContext(), 64); + SmallVector indices; + + for (int64_t i = 0; i < rank; i++) { + Value offs = rewriter.create( + loc, i64Ty, tensorPtrStruct, SmallVector{1, i}); + Value stride = rewriter.create( + loc, i64Ty, tensorPtrStruct, SmallVector{3, i}); + indices.push_back(rewriter.create(loc, offs, stride)); + } + + rewriter.replaceOp(op, indices); + + return success(); + } +}; + +struct MakeTensorPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto structTy = getTypeConverter()->convertType(op.getType()); + auto i64Ty = IntegerType::get(getContext(), 64); + + auto insertArray = [&](Value structVal, auto values, int64_t idx, + Type zextTo = nullptr) { + for (int64_t i = 0; i < static_cast(values.size()); ++i) { + Value val = values[i]; + if (zextTo) + val = rewriter.create(loc, zextTo, val); + structVal = rewriter.create( + loc, structTy, structVal, val, SmallVector{idx, i}); + } + return structVal; + }; + + Value res = undef(structTy); + // 0 - base pointer. + auto base = rewriter.getRemappedValue(op.getBase()); + res = rewriter.create(loc, structTy, res, base, 0); + // 1 - array for offsets. Promote values to i64. + res = insertArray(res, op.getOffsets(), 1, i64Ty); + // 2 - array for shape. + res = insertArray(res, op.getShape(), 2); + // 3 - array for strides. 
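+    //     tt.make_tensor_ptr provides offsets as i32, so they were widened to
+    //     i64 above (the zextTo parameter); shape and strides are already i64
+    //     and can be copied as-is.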
+ res = insertArray(res, op.getStrides(), 3); + + rewriter.replaceOp(op, res); + + return success(); + } +}; + +struct AdvanceOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AdvanceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto i64Ty = IntegerType::get(getContext(), 64); + Value res = rewriter.getRemappedValue(op.getPtr()); + Type structTy = res.getType(); + auto offsets = op.getOffsets(); + + for (int64_t i = 0; i < offsets.size(); ++i) { + auto oldOffset = rewriter.create( + loc, i64Ty, res, SmallVector{1, i}); + auto step = rewriter.create(loc, i64Ty, offsets[i]); + auto newOffset = rewriter.create(loc, oldOffset, step); + res = rewriter.create(loc, structTy, res, newOffset, + SmallVector{1, i}); + } + + rewriter.replaceOp(op, res); + + return success(); + } +}; + +struct LoadOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type ptrTy = LLVM::LLVMPointerType::get(getContext()); + Value ptr = rewriter.getRemappedValue(op.getPtr()); + Type resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, ptr, 0, + op.getIsVolatile()); + return success(); + } +}; + +struct StoreOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value ptr = rewriter.getRemappedValue(op.getPtr()); + Value val = rewriter.getRemappedValue(op.getValue()); + rewriter.replaceOpWithNewOp(op, val, ptr); + return success(); + } +}; + +struct PtrToIntOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PtrToIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = rewriter.getRemappedValue(op.getSrc()); + Type resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, src); + return success(); + } +}; + +struct IntToPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::IntToPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = rewriter.getRemappedValue(op.getSrc()); + Type resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, src); + return success(); + } +}; + +struct MemoryOpToLLVM + : public triton::impl::MemoryOpToLLVMBase { + using MemoryOpToLLVMBase::MemoryOpToLLVMBase; + + MemoryOpToLLVM() : MemoryOpToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + 
patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createMemoryOpToLLVMPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp b/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp new file mode 100644 index 000000000000..914f56e668f8 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp @@ -0,0 +1,25 @@ +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Conversion/Passes.h" +#include "mlir/Pass/PassManager.h" + +namespace mlir { +namespace triton { +namespace cpu { + +void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm) { + pm.addPass(mlir::triton::cpu::createFuncOpToLLVMPass()); + pm.addPass(mlir::triton::cpu::createGetProgramIdOpToLLVMPass()); + pm.addPass(mlir::triton::cpu::createMemoryOpToLLVMPass()); + // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); +} + +void registerTritonCPUToLLVMPipeline() { + PassPipelineRegistration<>("triton-cpu-to-llvmir", + "TritonCPU to LLVM conversion pipeline.", + tritonCPUToLLVMPipelineBuilder); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp new file mode 100644 index 000000000000..144cb57b1115 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp @@ -0,0 +1,43 @@ +#include "TypeConverter.h" + +using namespace mlir; +using namespace mlir::triton; + +TritonCPUToLLVMTypeConverter::TritonCPUToLLVMTypeConverter( + MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis) + : LLVMTypeConverter(ctx, option, analysis) { + addConversion([&](triton::PointerType type) -> std::optional { + return convertTritonPointerType(type); + }); + addConversion([this](RankedTensorType tensorTy) -> std::optional { + if (isa(tensorTy.getElementType())) + return VectorType::get(tensorTy.getShape(), + IntegerType::get(tensorTy.getContext(), 64)); + return std::nullopt; + }); +} + +Type TritonCPUToLLVMTypeConverter::convertTritonPointerType( + triton::PointerType type) { + auto ctx = type.getContext(); + auto pointeeType = type.getPointeeType(); + if (isa(pointeeType)) { + // struct { + // ptr base_ptr; + // array offsets; + // array shape; + // array strides; + // } + auto tensorTy = cast(pointeeType); + auto rank = tensorTy.getShape().size(); + auto i64Ty = IntegerType::get(ctx, 64); + SmallVector types; + types.push_back(LLVM::LLVMPointerType::get(ctx)); + types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); + types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); + types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); + return LLVM::LLVMStructType::getLiteral(ctx, types); + } + return LLVM::LLVMPointerType::get(ctx); +} diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h new file mode 100644 index 000000000000..35d74a9ec430 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h @@ -0,0 +1,22 @@ +#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_TYPECONVERTER_H +#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_TYPECONVERTER_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include 
"triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/Triton/IR/Types.h" + +using namespace mlir; +using namespace mlir::triton; + +class TritonCPUToLLVMTypeConverter : public LLVMTypeConverter { +public: + using TypeConverter::convertType; + + TritonCPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis = nullptr); + + Type convertTritonPointerType(triton::PointerType type); +}; + +#endif diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt new file mode 100644 index 000000000000..997fb748878a --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -0,0 +1,16 @@ +add_triton_library(TritonToTritonCPU + ConvertControlFlowOps.cpp + ConvertDotOp.cpp + ConvertElementwiseOps.cpp + ConvertMemoryOps.cpp + ConvertPtrOps.cpp + Pipeline.cpp + TypeConverter.cpp + + DEPENDS + TritonToTritonCPUPassIncGen + + LINK_LIBS PUBLIC + TritonCPUIR + MLIRVectorDialect +) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp new file mode 100644 index 000000000000..9cf6e31810d7 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp @@ -0,0 +1,121 @@ +#include "OpTypeConversion.h" +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTCONTROLFLOWOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class ControlFlowOpConversionTarget : public ConversionTarget { +public: + explicit ControlFlowOpConversionTarget(MLIRContext &ctx, + TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + } +}; + +struct ForOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lowerBound = rewriter.getRemappedValue(op.getLowerBound()); + Value upperBound = rewriter.getRemappedValue(op.getUpperBound()); + Value step = rewriter.getRemappedValue(op.getStep()); + SmallVector initArgs; + if (failed(rewriter.getRemappedValues(op.getInitArgs(), initArgs))) + return failure(); + // Create new for op with remapped values. + auto newOp = rewriter.create(op.getLoc(), lowerBound, + upperBound, step, initArgs); + // Move the old op block and convert its sigature. 
+ Block *oldBlock = op.getBody(); + Block *newBlock = newOp.getBody(); + rewriter.moveBlockBefore(oldBlock, newOp.getBody()); + rewriter.eraseBlock(newBlock); + if (failed(rewriter.convertRegionTypes(oldBlock->getParent(), + *getTypeConverter()))) + return failure(); + rewriter.replaceOp(op, newOp); + + return success(); + } +}; + +struct ConvertControlFlowOps + : public triton::impl::ConvertControlFlowOpsBase { + using ConvertControlFlowOpsBase::ConvertControlFlowOpsBase; + + ConvertControlFlowOps() : ConvertControlFlowOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + ControlFlowOpConversionTarget convTarget(*context, typeConverter); + convTarget.addDynamicallyLegalOp( + [&](Operation *op) -> std::optional { + return typeConverter.isLegal(op); + }); + { + RewritePatternSet patterns(context); + patterns.add>(typeConverter, context); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } + + convTarget.addDynamicallyLegalOp( + [&](Operation *op) -> std::optional { + return typeConverter.isLegal(op); + }); + { + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertControlFlowOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp new file mode 100644 index 000000000000..b6fbb1893202 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp @@ -0,0 +1,102 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTDOTOP +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class PtrConversionTarget : public ConversionTarget { +public: + explicit PtrConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addIllegalOp(); + } +}; + +struct DotOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Value a = rewriter.getRemappedValue(op.getA()); + Value b = rewriter.getRemappedValue(op.getB()); + Value c = 
rewriter.getRemappedValue(op.getC()); + auto aMap = AffineMap::getMultiDimMapWithTargets(3, {0, 2}, ctx); + auto bMap = AffineMap::getMultiDimMapWithTargets(3, {2, 1}, ctx); + auto cMap = AffineMap::getMultiDimMapWithTargets(3, {0, 1}, ctx); + auto iteratorTypes = rewriter.getArrayAttr( + {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, vector::IteratorType::reduction)}); + rewriter.replaceOpWithNewOp( + op, a, b, c, rewriter.getAffineMapArrayAttr({aMap, bMap, cMap}), + iteratorTypes); + return success(); + } +}; + +struct ConvertDotOp : public triton::impl::ConvertDotOpBase { + using ConvertDotOpBase::ConvertDotOpBase; + + ConvertDotOp() : ConvertDotOpBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + PtrConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertDotOp() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp new file mode 100644 index 000000000000..f40ab9839084 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -0,0 +1,235 @@ +#include "OpTypeConversion.h" +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTELEMENTWISEOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class ElementwiseOpConversionTarget : public ConversionTarget { +public: + explicit ElementwiseOpConversionTarget(MLIRContext &ctx, + TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addDynamicallyLegalDialect( + [&](Operation *op) -> std::optional { + return converter.isLegal(op); + }); + addDynamicallyLegalDialect( + [&](Operation *op) -> std::optional { + return converter.isLegal(op); + }); + + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + } +}; + +struct ConstantOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(isa(op.getType())); + auto 
resTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + assert(resTy); + if (auto denseAttr = dyn_cast(op.getValueAttr())) { + rewriter.replaceOpWithNewOp(op, resTy, + denseAttr.reshape(resTy)); + } else { + llvm_unreachable("Unexpected constant attribute"); + } + return success(); + } +}; + +struct ReshapeOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(isa(op.getType())); + auto loc = op.getLoc(); + auto src = rewriter.getRemappedValue(op.getSrc()); + auto srcShape = dyn_cast(src.getType()).getShape(); + auto resTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto dstShape = resTy.getShape(); + auto elemTy = resTy.getElementType(); + + // There are restrictions on how shape can be modified by ShapeCastOp + // when rank is changed. For now, we simply detect it and handle through + // a cast to 1D vector. Better solution may be required later. + if (canCastShape(srcShape, dstShape)) { + rewriter.replaceOpWithNewOp( + op, VectorType::get(dstShape, elemTy), src); + } else { + SmallVector tmpShape({resTy.getNumElements()}); + auto tmp = rewriter.create( + loc, VectorType::get(tmpShape, elemTy), src); + rewriter.replaceOpWithNewOp( + op, VectorType::get(dstShape, elemTy), tmp); + } + return success(); + } + +private: + bool canCastShape(ArrayRef src, ArrayRef dst) const { + if (src.size() == dst.size()) + return true; + if (src.size() > dst.size()) + return canCastShape(dst, src); + + size_t srcIdx = 0; + size_t dstIdx = 0; + while (srcIdx < src.size() && dstIdx < dst.size()) { + if (src[srcIdx] == 1) { + ++srcIdx; + } else { + // Source dim size should be a product of continuous dest dim sizes. + int64_t srcSize = src[srcIdx++]; + int64_t dstSize = dst[dstIdx++]; + while (dstSize < srcSize && dstIdx < dst.size()) + dstSize *= dst[dstIdx++]; + if (dstSize != srcSize) + return false; + } + } + + // Skip trailing 1s. 
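+    // For example, 16 -> 2x8 can be handled by a single shape cast, whereas
+    // 2x12 -> 4x3x2 fails this check (the leading source dim 2 is not a
+    // product of leading destination dims) and is routed through the rank-1
+    // intermediate above.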
+ while (srcIdx < src.size() && src[srcIdx] == 1) + ++srcIdx; + while (dstIdx < dst.size() && dst[dstIdx] == 1) + ++dstIdx; + + return srcIdx == src.size() && dstIdx == dst.size(); + } +}; + +struct ConvertElementwiseOps + : public triton::impl::ConvertElementwiseOpsBase { + using ConvertElementwiseOpsBase::ConvertElementwiseOpsBase; + + ConvertElementwiseOps() : ConvertElementwiseOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + ElementwiseOpConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + + patterns.add(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertElementwiseOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp new file mode 100644 index 000000000000..1679ecc7af90 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -0,0 +1,277 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include 
"mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTMEMORYOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +struct LoadOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = loadOp.getLoc(); + auto mask = loadOp.getMask(); + auto ptr = loadOp.getPtr(); + auto boundaryChecks = loadOp.getBoundaryCheck(); + + if (!triton::isTensorPointerType(ptr.getType())) { + return lowerToScalarLoads(loadOp, rewriter); + } + + // TODO: support masks. + if (mask) { + llvm_unreachable("unsupported load op"); + } + + auto memRef = rewriter.getRemappedValue(ptr); + auto rank = dyn_cast(memRef.getType()).getRank(); + auto resTy = dyn_cast( + getTypeConverter()->convertType(loadOp.getResult().getType())); + auto indices = rewriter.create(loc, ptr).getResults(); + SmallVector inBounds(rank, true); + for (auto dim : boundaryChecks) { + inBounds[dim] = false; + } + auto vecRead = rewriter.create(loc, resTy, memRef, + indices, inBounds); + rewriter.replaceOp(loadOp, vecRead); + return success(); + } + + LogicalResult lowerToScalarLoads(triton::LoadOp loadOp, + ConversionPatternRewriter &rewriter) const { + // Scalar loads and boundary checks are not expected. + assert(loadOp.getBoundaryCheck().empty()); + assert(isa(loadOp.getType())); + + auto loc = loadOp.getLoc(); + auto vecTy = + dyn_cast(getTypeConverter()->convertType(loadOp.getType())); + auto ptrs = rewriter.getRemappedValue(loadOp.getPtr()); + auto mask = loadOp.getMask() ? rewriter.getRemappedValue(loadOp.getMask()) + : nullptr; + auto ptrTy = + dyn_cast(loadOp.getPtr().getType()).getElementType(); + auto cache = loadOp.getCache(); + auto evict = loadOp.getEvict(); + auto isVolatile = loadOp.getIsVolatile(); + + Value defaultVal = loadOp.getOther(); + if (!defaultVal) + defaultVal = rewriter.create( + loc, rewriter.getZeroAttr(vecTy.getElementType())); + Value dst = rewriter.create(loc, vecTy, defaultVal); + + int64_t numElems = vecTy.getNumElements(); + auto strides = computeStrides(vecTy.getShape()); + for (auto idx = 0; idx < numElems; ++idx) { + auto indices = delinearize(idx, strides); + Block *headerBlock = rewriter.getBlock(); + Block *condBlock = nullptr; + Value origDst = dst; + // Create a conditional block for load if there is a mask. 
+ if (mask) { + condBlock = + rewriter.splitBlock(headerBlock, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToStart(condBlock); + } + + Value ptr = rewriter.create(loc, ptrs, indices); + ptr = rewriter.create(loc, ptrTy, ptr); + Value val = + rewriter.create(loc, ptr, cache, evict, isVolatile); + dst = rewriter.create(loc, val, dst, indices); + + // Add predicate and branches. + if (mask) { + Block *footerBlock = + rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); + Value resDst = dst; + dst = footerBlock->addArgument(dst.getType(), dst.getLoc()); + rewriter.setInsertionPointToEnd(headerBlock); + auto predicate = rewriter.create(loc, mask, indices); + rewriter.create(loc, predicate, condBlock, + footerBlock, origDst); + rewriter.setInsertionPointToEnd(condBlock); + rewriter.create(loc, footerBlock, resDst); + rewriter.setInsertionPointToStart(footerBlock); + } + } + + rewriter.replaceOp(loadOp, dst); + + return success(); + } +}; + +struct StoreOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = storeOp.getLoc(); + auto mask = storeOp.getMask(); + auto ptr = storeOp.getPtr(); + auto boundaryChecks = storeOp.getBoundaryCheck(); + + if (!triton::isTensorPointerType(ptr.getType())) { + return lowerToScalarStores(storeOp, rewriter); + } + + // TODO: support masks. + if (mask) { + llvm_unreachable("unsupported store op"); + } + + auto value = rewriter.getRemappedValue(storeOp.getValue()); + auto memRef = rewriter.getRemappedValue(ptr); + auto rank = dyn_cast(memRef.getType()).getRank(); + auto indices = rewriter.create(loc, ptr).getResults(); + SmallVector inBounds(rank, true); + for (auto dim : boundaryChecks) { + inBounds[dim] = false; + } + auto vecWrite = rewriter.create(loc, value, memRef, + indices, inBounds); + rewriter.replaceOp(storeOp, vecWrite); + return success(); + } + + LogicalResult lowerToScalarStores(triton::StoreOp storeOp, + ConversionPatternRewriter &rewriter) const { + // Scalar stores and boundary checks are not expected. + assert(storeOp.getBoundaryCheck().empty()); + assert(isa(storeOp.getValue().getType())); + + auto loc = storeOp.getLoc(); + auto ptrs = rewriter.getRemappedValue(storeOp.getPtr()); + auto mask = storeOp.getMask() ? rewriter.getRemappedValue(storeOp.getMask()) + : nullptr; + auto vals = rewriter.getRemappedValue(storeOp.getValue()); + auto tensorTy = dyn_cast(storeOp.getPtr().getType()); + auto ptrTy = tensorTy.getElementType(); + auto cache = storeOp.getCache(); + auto evict = storeOp.getEvict(); + + int64_t numElems = tensorTy.getNumElements(); + auto strides = computeStrides(tensorTy.getShape()); + for (auto idx = 0; idx < numElems; ++idx) { + auto indices = delinearize(idx, strides); + Block *headerBlock = rewriter.getBlock(); + Block *condBlock = nullptr; + // Create a conditional block for store if there is a mask. + if (mask) { + condBlock = + rewriter.splitBlock(headerBlock, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToStart(condBlock); + } + + Value ptr = rewriter.create(loc, ptrs, indices); + ptr = rewriter.create(loc, ptrTy, ptr); + Value val = rewriter.create(loc, vals, indices); + rewriter.create(loc, ptr, val, cache, evict); + + // Add predicate and branches. 
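+      // Same structure as the masked scalar load above, but nothing has to be
+      // carried across blocks: the header simply branches around the block
+      // containing the scalar store when the mask element is false.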
+ if (mask) { + Block *footerBlock = + rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToEnd(headerBlock); + auto predicate = rewriter.create(loc, mask, indices); + rewriter.create(loc, predicate, condBlock, + footerBlock); + rewriter.setInsertionPointToEnd(condBlock); + rewriter.create(loc, footerBlock); + rewriter.setInsertionPointToStart(footerBlock); + } + } + + rewriter.eraseOp(storeOp); + + return success(); + } +}; + +class MemoryOpConversionTarget : public ConversionTarget { +public: + explicit MemoryOpConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + // Allow only scalar loads and stores. + addDynamicallyLegalOp([](triton::LoadOp loadOp) { + return loadOp.getType().isIntOrIndexOrFloat(); + }); + addDynamicallyLegalOp([](triton::StoreOp storeOp) { + return storeOp.getValue().getType().isIntOrIndexOrFloat(); + }); + } +}; + +struct ConvertMemoryOps + : public triton::impl::ConvertMemoryOpsBase { + using ConvertMemoryOpsBase::ConvertMemoryOpsBase; + + ConvertMemoryOps() : ConvertMemoryOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + MemoryOpConversionTarget convTarget(*context); + TritonToTritonCPUTypeConverter pointerConverter; + RewritePatternSet patterns(context); + patterns.add(pointerConverter, context); + patterns.add(pointerConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertMemoryOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp new file mode 100644 index 000000000000..ade8b858bbfb --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp @@ -0,0 +1,191 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTPTROPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +unsigned getElemBitWidth(Type type) { + if (auto tensorTy = dyn_cast(type)) + return tensorTy.getElementType().getIntOrFloatBitWidth(); + if (auto vectorTy = dyn_cast(type)) + return vectorTy.getElementType().getIntOrFloatBitWidth(); + return type.getIntOrFloatBitWidth(); +} + +class PtrConversionTarget : public ConversionTarget { +public: + explicit PtrConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) 
{ + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + // Allow only scalar pointer conversion. + addDynamicallyLegalOp( + [](triton::PtrToIntOp op) { return op.getType().isInteger(); }); + addDynamicallyLegalOp([](triton::IntToPtrOp op) { + return op.getSrc().getType().isInteger(); + }); + } +}; + +struct MakeRangeOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + int32_t start = static_cast(op.getStart()); + int32_t end = static_cast(op.getEnd()); + assert(end >= start); + + llvm::SmallVector values; + values.reserve(end - start); + for (int32_t v = start; v < end; ++v) { + values.push_back(v); + } + + Type resTy = getTypeConverter()->convertType(op.getType()); + auto newOp = rewriter.create( + op.getLoc(), resTy, rewriter.getI32VectorAttr(values)); + + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +struct SplatOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value val = op.getSrc(); + Type dstValType = getTypeConverter()->convertType(val.getType()); + // Cast pointer + if (isa(val.getType())) + val = rewriter + .create( + loc, getTypeConverter()->convertType(val.getType()), val) + .getResult(); + Type resType = getTypeConverter()->convertType(op.getType()); + auto cast = rewriter.create(loc, resType, val); + + rewriter.replaceOp(op, cast); + return success(); + } +}; + +struct AddPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value ptr = rewriter.getRemappedValue(op.getPtr()); + Value offset = rewriter.getRemappedValue(op.getOffset()); + unsigned offsetBitWidth = getElemBitWidth(offset.getType()); + unsigned elemBitWidth = getPointeeBitWidth(op.getPtr().getType()); + // Compute scale. i1 elements take 1 byte. 
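+    // Non-block pointers are modeled as plain i64 values at this level, so
+    // tt.addptr becomes integer arithmetic: the offset is scaled by the byte
+    // size of the pointee (4 for f32, 1 for i1/i8) and added to the base.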
+ Value scale = rewriter.create( + loc, (elemBitWidth + 7) / 8, offsetBitWidth); + if (isa(offset.getType())) + scale = rewriter.create(loc, offset.getType(), scale); + offset = rewriter.create(loc, offset, scale); + offset = rewriter.create(loc, ptr.getType(), offset); + rewriter.replaceOpWithNewOp(op, ptr.getType(), ptr, offset); + return success(); + } +}; + +struct PtrToIntOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PtrToIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value val = rewriter.getRemappedValue(op.getSrc()); + auto resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, val); + return success(); + } +}; + +struct IntToPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::IntToPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value val = rewriter.getRemappedValue(op.getSrc()); + auto resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, val); + return success(); + } +}; + +struct ConvertPtrOps : public triton::impl::ConvertPtrOpsBase { + using ConvertPtrOpsBase::ConvertPtrOpsBase; + + ConvertPtrOps() : ConvertPtrOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + PtrConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertPtrOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/OpTypeConversion.h b/third_party/cpu/lib/TritonToTritonCPU/OpTypeConversion.h new file mode 100644 index 000000000000..aaac6a27d5e6 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/OpTypeConversion.h @@ -0,0 +1,37 @@ +#include "mlir/IR/OperationSupport.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +// Generic pattern to rewrite operation by converting types +// for operation operands and results using provided type +// converter. +template +struct OpTypeConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::getTypeConverter; + using typename OpConversionPattern::OpAdaptor; + + LogicalResult + matchAndRewrite(OpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + OperationState newState(op.getLoc(), ResOpT::getOperationName()); + // Convert operands. + for (auto operand : op->getOperands()) { + Value newOperand = rewriter.getRemappedValue(operand); + newState.operands.push_back(newOperand); + } + // Convert result types. 
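+    // The new ResOpT op reuses the original attributes and location; operands
+    // are remapped and result types run through the converter (tensors become
+    // vectors), which is why this single template can cover most elementwise
+    // Triton and arith ops.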
+ if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), + newState.types))) { + return failure(); + } + newState.attributes = op->getAttrs(); + + auto newOp = rewriter.create(newState); + rewriter.replaceOp(op, newOp); + + return success(); + } +}; diff --git a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp new file mode 100644 index 000000000000..d954142d9172 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp @@ -0,0 +1,27 @@ +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Conversion/Passes.h" +#include "mlir/Pass/PassManager.h" + +namespace mlir { +namespace triton { +namespace cpu { + +void tritonToTritonCPUPipelineBuilder(OpPassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertMemoryOps()); + pm.addPass(mlir::triton::cpu::createConvertPtrOps()); + pm.addPass(mlir::triton::cpu::createConvertElementwiseOps()); + pm.addPass(mlir::triton::cpu::createConvertDotOp()); + pm.addPass(mlir::triton::cpu::createConvertControlFlowOps()); + // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); +} + +void registerTritonToTritonCPUPipeline() { + PassPipelineRegistration<>("triton-to-triton-cpu", + "Triton to TritonCPU conversion pipeline.", + tritonToTritonCPUPipelineBuilder); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp new file mode 100644 index 000000000000..07b2da0468ba --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp @@ -0,0 +1,51 @@ +#include "TypeConverter.h" + +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +TritonToTritonCPUTypeConverter::TritonToTritonCPUTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion([](triton::PointerType ptrTy) -> Type { + if (triton::isTensorPointerType(ptrTy)) { + // Tensor pointer is translated into a memref + auto tensorTy = dyn_cast(ptrTy.getPointeeType()); + auto elemTy = tensorTy.getElementType(); + // TODO: use dynamic strides + SmallVector shape(tensorTy.getRank(), ShapedType::kDynamic); + return MemRefType::get(shape, elemTy); + } + return IntegerType::get(ptrTy.getContext(), 64); + }); + addConversion([this](RankedTensorType tensorTy) -> Type { + Type elemTy = convertType(tensorTy.getElementType()); + return VectorType::get(tensorTy.getShape(), elemTy); + }); + + // Converted ops produce vectors instead of tensors. Provide conversion + // here for users. Also, convert pointers when required. + addSourceMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) -> std::optional { + if (isa(type)) + return builder.create(loc, type, inputs); + return builder.create(loc, type, inputs) + .getResult(0); + }); + + // Converted loads and stores consume memrefs instead of pointers, use extract + // op to get them. Also, provide conversion for vector users and pointer + // casts. 
+ addTargetMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) -> std::optional { + if (type.isInteger() && isa(inputs.front().getType())) + return builder.create(loc, type, inputs); + if (isa(type)) + return builder.create(loc, type, inputs) + .getResult(0); + if (isa(type)) + return builder.create(loc, type, inputs); + llvm_unreachable("Unexpected target materizalization"); + }); +} diff --git a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h new file mode 100644 index 000000000000..cb89f0886c60 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h @@ -0,0 +1,19 @@ +#ifndef TRITON_CONVERSION_TRITON_TO_TRITONCPU_TYPECONVERTER_H +#define TRITON_CONVERSION_TRITON_TO_TRITONCPU_TYPECONVERTER_H + +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/Triton/IR/Types.h" + +using namespace mlir; +using namespace mlir::triton; + +class TritonToTritonCPUTypeConverter : public TypeConverter { +public: + using TypeConverter::convertType; + + TritonToTritonCPUTypeConverter(); + + Type convertTritonPointerType(triton::PointerType type); +}; + +#endif diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 302951d04d59..efc949d6f4a1 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -1,9 +1,20 @@ +#include "TritonCPUToLLVM/Passes.h" +#include "TritonToTritonCPU/Passes.h" + +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "triton/Conversion/TritonCPUToLLVM/Passes.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "llvm/IR/Constants.h" #include "llvm/Support/TargetSelect.h" + #include #include #include @@ -14,8 +25,26 @@ namespace py = pybind11; void init_triton_cpu_passes_ttcpuir(py::module &&m) { using namespace mlir::triton; - m.def("add_to_llvmir", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createConvertTritonCPUToLLVMPass()); + // m.def("add_to_llvmir", [](mlir::PassManager &pm) { + // pm.addPass(mlir::triton::createConvertTritonCPUToLLVMPass()); + // }); + m.def("add_triton_to_triton_cpu_pipeline", [](mlir::PassManager &pm) { + mlir::triton::cpu::tritonToTritonCPUPipelineBuilder(pm); + }); + m.def("add_triton_cpu_to_llvmir_pipeline", [](mlir::PassManager &pm) { + mlir::triton::cpu::tritonCPUToLLVMPipelineBuilder(pm); + }); + m.def("add_vector_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::createConvertVectorToLLVMPass()); + }); + m.def("add_memref_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); + }); + m.def("add_math_to_libm", [](mlir::PassManager &pm) { + pm.addPass(mlir::createConvertMathToLibmPass()); + }); + m.def("add_func_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::createConvertFuncToLLVMPass()); }); } @@ -25,8 +54,18 @@ void init_triton_cpu(py::module &&m) { m.def("load_dialects", [](mlir::MLIRContext &context) { mlir::DialectRegistry registry; - registry.insert(); + registry.insert(); context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); }); + + m.def("find_kernel_names", [](mlir::ModuleOp &mod) { + std::vector res; + 
mod.walk([&](mlir::FunctionOpInterface funcOp) { + if (funcOp.getVisibility() == mlir::SymbolTable::Visibility::Public) + res.push_back(funcOp.getName().str()); + }); + return res; + }); } From c332393fdbfc5e4d32a2dad762634870dcaafd39 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 20 May 2024 13:15:06 -0700 Subject: [PATCH 02/13] Use axis info in memory op lowering. Signed-off-by: Ilya Enkovich --- .../Dialect/TritonCPU/IR/TritonCPUOps.td | 16 ++ .../lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp | 20 +++ .../TritonToTritonCPU/ConvertMemoryOps.cpp | 162 +++++++++++++++++- 3 files changed, 192 insertions(+), 6 deletions(-) diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td index bb7417ebd03e..712826d02f91 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -60,4 +60,20 @@ def TTC_ExtractIndicesOp : TTC_Op<"extract_indices", [NoMemoryEffect]> { let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; } +def TTC_PtrToMemRefOp : TTC_Op<"ptr_to_memref", [NoMemoryEffect]> { + let summary = "Build a memref for a pointer."; + + let description = [{ + Build memref with static shape, offset, strides, and specified base pointer. + }]; + + let arguments = (ins TT_Ptr:$src); + + let results = (outs AnyStaticShapeMemRef:$result); + + let hasCanonicalizer = 0; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + #endif diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp index 594495c4ab9d..7bd602dc81a7 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp @@ -111,6 +111,25 @@ struct ExtractIndicesOpConversion } }; +struct PtrToMemRefOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(PtrToMemRefOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value ptr = rewriter.getRemappedValue(op.getSrc()); + auto memRefStructTy = getTypeConverter()->convertType(op.getType()); + + Value res = undef(memRefStructTy); + res = + rewriter.create(loc, memRefStructTy, res, ptr, 1); + rewriter.replaceOp(op, res); + + return success(); + } +}; + struct MakeTensorPtrOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -256,6 +275,7 @@ struct MemoryOpToLLVM patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp index 1679ecc7af90..394289063f14 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -32,8 +32,32 @@ using namespace mlir::triton::cpu; namespace { -struct LoadOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +template +struct MemoryOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + MemoryOpConversion(ModuleAxisInfoAnalysis &axisInfoAnalysis, + TypeConverter &typeConverter, MLIRContext *context) + : 
OpConversionPattern(typeConverter, context), + axisAnalysis(axisInfoAnalysis) {} + + Value extractScalarPointer(Location loc, Value ptrs, + ArrayRef indices, + ConversionPatternRewriter &rewriter) const { + // TODO: Analyze data flow and build scalar pointer computation code. + Value ptr = rewriter.create( + loc, rewriter.getRemappedValue(ptrs), indices); + auto ptrTy = dyn_cast(ptrs.getType()).getElementType(); + ptr = rewriter.create(loc, ptrTy, ptr); + return ptr; + } + +protected: + ModuleAxisInfoAnalysis &axisAnalysis; +}; + +struct LoadOpConversion : public MemoryOpConversion { + using MemoryOpConversion::MemoryOpConversion; LogicalResult matchAndRewrite(triton::LoadOp loadOp, OpAdaptor adaptor, @@ -44,6 +68,10 @@ struct LoadOpConversion : public OpConversionPattern { auto boundaryChecks = loadOp.getBoundaryCheck(); if (!triton::isTensorPointerType(ptr.getType())) { + auto axisInfo = axisAnalysis.getAxisInfo(ptr); + if (axisInfo) { + return lowerUsingAxisInfo(axisInfo, loadOp, rewriter); + } return lowerToScalarLoads(loadOp, rewriter); } @@ -67,6 +95,70 @@ struct LoadOpConversion : public OpConversionPattern { return success(); } + LogicalResult lowerUsingAxisInfo(AxisInfo *axisInfo, triton::LoadOp loadOp, + ConversionPatternRewriter &rewriter) const { + // This is experimental code that covers only a simple case of axis info + // usage to demonstrate the transformation of loads through a tensor of + // pointers into vector loads. + // TODO: Support more cases. + // TODO: Make separate pass to produce block pointer loads? + auto loc = loadOp.getLoc(); + auto vecTy = + dyn_cast(getTypeConverter()->convertType(loadOp.getType())); + auto shape = vecTy.getShape(); + auto contiguity = axisInfo->getContiguity(); + if (shape.back() > 1 && shape.back() == contiguity.back()) { + auto strides = computeStrides(shape); + int64_t numElems = vecTy.getNumElements(); + Type subVecTy = VectorType::get(shape.back(), vecTy.getElementType()); + Type memRefTy = MemRefType::get(shape.back(), vecTy.getElementType()); + Value mask = loadOp.getMask() + ? rewriter.getRemappedValue(loadOp.getMask()) + : nullptr; + Value zeroIdx = rewriter.create(loc, 0); + Value defaultVal = loadOp.getOther(); + if (!defaultVal) + defaultVal = rewriter.create( + loc, rewriter.getZeroAttr(vecTy.getElementType())); + Value res = rewriter.create(loc, vecTy, defaultVal); + for (int64_t idx = 0; idx < numElems; idx += shape.back()) { + auto indices = delinearize(idx, strides); + auto ptr = + extractScalarPointer(loc, loadOp.getPtr(), indices, rewriter); + Value memRef = + rewriter.create(loc, memRefTy, ptr); + Value vec; + if (mask) { + Value subMask = mask; + if (shape.size() > 1) { + SmallVector subIndices = indices; + subIndices.pop_back(); + subMask = rewriter.create(loc, mask, subIndices); + } + Value passThru = + rewriter.create(loc, subVecTy, defaultVal); + vec = rewriter.create( + loc, subVecTy, memRef, zeroIdx, subMask, passThru); + } else { + vec = rewriter.create(loc, subVecTy, memRef, zeroIdx); + } + + if (shape.size() > 1) { + SmallVector subIndices = indices; + subIndices.pop_back(); + res = rewriter.create(loc, vec, res, subIndices); + } else { + res = vec; + } + } + + rewriter.replaceOp(loadOp, res); + return success(); + } + + return lowerToScalarLoads(loadOp, rewriter); + } + LogicalResult lowerToScalarLoads(triton::LoadOp loadOp, ConversionPatternRewriter &rewriter) const { // Scalar loads and boundary checks are not expected.
@@ -133,8 +225,8 @@ struct LoadOpConversion : public OpConversionPattern { } }; -struct StoreOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct StoreOpConversion : public MemoryOpConversion { + using MemoryOpConversion::MemoryOpConversion; LogicalResult matchAndRewrite(triton::StoreOp storeOp, OpAdaptor adaptor, @@ -145,6 +237,10 @@ struct LoadOpConversion : public OpConversionPattern { auto boundaryChecks = storeOp.getBoundaryCheck(); if (!triton::isTensorPointerType(ptr.getType())) { + auto axisInfo = axisAnalysis.getAxisInfo(ptr); + if (axisInfo) { + return lowerUsingAxisInfo(axisInfo, storeOp, rewriter); + } return lowerToScalarStores(storeOp, rewriter); } @@ -167,6 +263,57 @@ struct LoadOpConversion : public OpConversionPattern { return success(); } + LogicalResult lowerUsingAxisInfo(AxisInfo *axisInfo, triton::StoreOp storeOp, + ConversionPatternRewriter &rewriter) const { + // This is experimental code that covers only a simple case of axis info + // usage to demonstrate the transformation of stores through a tensor of + // pointers into vector stores. + // TODO: Support more cases. + // TODO: Make separate pass to produce block pointer stores instead? + auto loc = storeOp.getLoc(); + auto vals = rewriter.getRemappedValue(storeOp.getValue()); + auto vecTy = dyn_cast(vals.getType()); + auto shape = vecTy.getShape(); + auto contiguity = axisInfo->getContiguity(); + if (shape.back() > 1 && shape.back() == contiguity.back()) { + auto strides = computeStrides(shape); + int64_t numElems = vecTy.getNumElements(); + Type memRefTy = MemRefType::get(shape.back(), vecTy.getElementType()); + Value mask = storeOp.getMask() + ? rewriter.getRemappedValue(storeOp.getMask()) + : nullptr; + Value zeroIdx = rewriter.create(loc, 0); + auto vals = rewriter.getRemappedValue(storeOp.getValue()); + for (int64_t idx = 0; idx < numElems; idx += shape.back()) { + auto indices = delinearize(idx, strides); + auto ptr = + extractScalarPointer(loc, storeOp.getPtr(), indices, rewriter); + Value memRef = + rewriter.create(loc, memRefTy, ptr); + indices.pop_back(); + auto val = rewriter.create(loc, vals, indices); + + if (mask) { + Value subMask = mask; + if (shape.size() > 1) { + SmallVector subIndices = indices; + subIndices.pop_back(); + subMask = rewriter.create(loc, mask, indices); + } + rewriter.create(loc, memRef, zeroIdx, subMask, + val); + } else { + rewriter.create(loc, val, memRef, zeroIdx); + } + } + + rewriter.eraseOp(storeOp); + return success(); + } + + return lowerToScalarStores(storeOp, rewriter); + } + LogicalResult lowerToScalarStores(triton::StoreOp storeOp, ConversionPatternRewriter &rewriter) const { // Scalar stores and boundary checks are not expected.
@@ -226,6 +373,7 @@ class MemoryOpConversionTarget : public ConversionTarget { explicit MemoryOpConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { addLegalDialect(); addLegalDialect(); + addLegalDialect(); addLegalDialect(); addLegalDialect(); addLegalDialect(); @@ -251,11 +399,13 @@ struct ConvertMemoryOps MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); MemoryOpConversionTarget convTarget(*context); TritonToTritonCPUTypeConverter pointerConverter; RewritePatternSet patterns(context); - patterns.add(pointerConverter, context); - patterns.add(pointerConverter, context); + patterns.add(axisInfoAnalysis, pointerConverter, context); + patterns.add(axisInfoAnalysis, pointerConverter, + context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); From 14f82d792820216225772f350366dc10e8d275fc Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 20 May 2024 13:20:16 -0700 Subject: [PATCH 03/13] Mark test_ptx_cast as enabled for CPU. Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index b9bab8b3e21f..4080eda0b624 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5052,6 +5052,7 @@ def kernel(Input, Index, Out, N: int): # This test is used to test our own PTX codegen for float16 and int16 conversions # maybe delete it later after ptxas has been fixed +@pytest.mark.cpu @pytest.mark.parametrize("dtype_str", ['float16', 'int16']) def test_ptx_cast(dtype_str, device): From 6f4ae616a1929a361bc0135b2455f59a65fc20cd Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 20 May 2024 14:30:17 -0700 Subject: [PATCH 04/13] Support umulhi operation. 
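The conversion maps tl.umulhi to a widen-multiply-shift-truncate sequence on vectors: extend the 32-bit operands to 64 bits, multiply, shift right by 32, and truncate back to 32 bits. A rough NumPy sketch of the intended semantics (the helper below is illustrative and not part of the patch):

    import numpy as np

    def umulhi_ref(x: np.ndarray, y: np.ndarray) -> np.ndarray:
        # Reinterpret as unsigned 32-bit, zero-extend to 64 bits, multiply,
        # and keep the upper 32 bits of the product.
        prod = x.astype(np.uint32).astype(np.uint64) * y.astype(np.uint32).astype(np.uint64)
        return (prod >> 32).astype(np.uint32)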
Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 1 + .../ConvertElementwiseOps.cpp | 35 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 4080eda0b624..507909ea58b1 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1736,6 +1736,7 @@ def kernel(in_out_ptr): assert torch.all(x == 2) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", ['int32']) def test_umulhi(dtype_str, device): diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index f40ab9839084..d6ce052ebc35 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -55,6 +55,7 @@ class ElementwiseOpConversionTarget : public ConversionTarget { addIllegalOp(); addIllegalOp(); addIllegalOp(); + addIllegalOp(); } }; @@ -142,6 +143,39 @@ struct ReshapeOpConversion : public OpConversionPattern { } }; +struct MulhiUIOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::MulhiUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(isa(op.getType())); + auto loc = op.getLoc(); + auto lhs = rewriter.getRemappedValue(op.getX()); + auto rhs = rewriter.getRemappedValue(op.getY()); + auto lhsTy = dyn_cast(lhs.getType()); + auto rhsTy = dyn_cast(rhs.getType()); + auto vecI32Ty = lhsTy.cloneWith(std::nullopt, rewriter.getI32Type()); + auto vecI64Ty = lhsTy.cloneWith(std::nullopt, rewriter.getI64Type()); + assert(lhsTy.getElementType().isInteger()); + assert(rhsTy.getElementType().isInteger()); + // Cast to int64 + if (lhsTy.getElementTypeBitWidth() < 64) { + lhs = rewriter.create(loc, vecI64Ty, lhs); + } + if (rhsTy.getElementTypeBitWidth() < 64) { + rhs = rewriter.create(loc, vecI64Ty, rhs); + } + Value res = rewriter.create(loc, lhs, rhs); + Value cst32 = rewriter.create( + loc, DenseElementsAttr::get(vecI64Ty, 32LL)); + res = rewriter.create(loc, res, cst32); + res = rewriter.create(loc, vecI32Ty, res); + rewriter.replaceOp(op, res); + return success(); + } +}; + struct ConvertElementwiseOps : public triton::impl::ConvertElementwiseOpsBase { using ConvertElementwiseOpsBase::ConvertElementwiseOpsBase; @@ -214,6 +248,7 @@ struct ConvertElementwiseOps patterns.add>( typeConverter, context); patterns.add(typeConverter, context); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); From 0b2ac7c3cc22aab7f60083ab8848aece12006ffb Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 20 May 2024 15:32:28 -0700 Subject: [PATCH 05/13] Support tl.clamp, tl.minimum, tl.maximum. 
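ClampFOpConversion lowers tl.clamp as a max with the lower bound followed by a min with the upper bound, choosing a NaN-propagating or NaN-ignoring flavor of the vector min/max ops based on the propagate_nan attribute. A rough Python sketch of the two behaviors (illustrative only; the helper and its assumptions are not part of the patch):

    import math

    def clamp_ref(x: float, lo: float, hi: float, propagate_nan_all: bool) -> float:
        # Assumes lo <= hi and non-NaN bounds; only x may be NaN.
        if math.isnan(x):
            # 'ALL' propagates the NaN; 'NONE' lets the surviving bound win.
            return math.nan if propagate_nan_all else lo
        return min(max(x, lo), hi)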
Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 3 ++ .../ConvertElementwiseOps.cpp | 33 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 507909ea58b1..89ab0e656d63 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5215,6 +5215,7 @@ def mul_add(data): # ----------------------- +@pytest.mark.cpu @pytest.mark.parametrize("dtype", ['float16', 'float32']) @pytest.mark.parametrize("propagate_nan", ['NONE', 'ALL']) @pytest.mark.parametrize("func", ['minimum', 'maximum', 'clamp']) @@ -5253,6 +5254,7 @@ def kernel(A, B, C, propagate_nan: tl.constexpr, func: tl.constexpr): # ----------------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", ['float16', 'float32']) def test_clamp(dtype, device): @@ -5289,6 +5291,7 @@ def kernel(x_ptr, min_ptr, max_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexp # Test for symmetric clamp(x, -limit, limit), as it may go through optimized # codegen in the backends +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", ['float16', 'float32']) def test_clamp_symmetric(dtype, device): diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index d6ce052ebc35..1dd08d79fa6f 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -56,6 +56,7 @@ class ElementwiseOpConversionTarget : public ConversionTarget { addIllegalOp(); addIllegalOp(); addIllegalOp(); + addIllegalOp(); } }; @@ -176,6 +177,29 @@ struct MulhiUIOpConversion : public OpConversionPattern { } }; +struct ClampFOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto val = rewriter.getRemappedValue(op.getX()); + auto minVal = rewriter.getRemappedValue(op.getMin()); + auto maxVal = rewriter.getRemappedValue(op.getMax()); + Value res; + if (op.getPropagateNanAttr().getValue() == PropagateNan::ALL) { + res = rewriter.create(loc, val, minVal); + res = rewriter.create(loc, res, maxVal); + } else { + res = rewriter.create(loc, val, minVal); + res = rewriter.create(loc, res, maxVal); + } + rewriter.replaceOp(op, res); + return success(); + } +}; + struct ConvertElementwiseOps : public triton::impl::ConvertElementwiseOpsBase { using ConvertElementwiseOpsBase::ConvertElementwiseOpsBase; @@ -221,6 +245,14 @@ struct ConvertElementwiseOps patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); @@ -249,6 +281,7 @@ struct ConvertElementwiseOps typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) 
return signalPassFailure(); From 5e9facf83d371d4acdaf4c6eb5f9d33a84276a2b Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 20 May 2024 16:33:23 -0700 Subject: [PATCH 06/13] Add enable_fp_fusion opt for CPU (only affects ASM dump now). Signed-off-by: Ilya Enkovich --- python/src/llvm.cc | 9 +++++---- third_party/cpu/backend/compiler.py | 3 ++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/src/llvm.cc b/python/src/llvm.cc index ae5798ed1fb8..ef10b3c11a13 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -294,7 +294,7 @@ void init_triton_llvm(py::module &&m) { m.def( "translate_to_host_asm", - [](std::string llvmIR) -> py::object { + [](std::string llvmIR, bool enable_fp_fusion) -> py::object { std::string res; { // when allow_threads goes out of scope, gil will be released @@ -311,9 +311,10 @@ void init_triton_llvm(py::module &&m) { "failed to parse IR: " + error.getMessage() + "lineno: " + std::to_string(error.getLineNo())); } - res = translateLLVMIRToASM( - *module, llvm::sys::getDefaultTargetTriple(), - llvm::sys::getHostCPUName().str(), "", {}, false, false); + res = + translateLLVMIRToASM(*module, llvm::sys::getDefaultTargetTriple(), + llvm::sys::getHostCPUName().str(), "", {}, + enable_fp_fusion, false); } return py::str(res); }, diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 357b5f448fe9..73dbbe779a38 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -22,6 +22,7 @@ class CPUOptions: debug: bool = False allowed_dot_input_precisions: Tuple[str] = ("ieee",) allow_fp8e4nv: bool = False + enable_fp_fusion: bool = True # TODO: We may introduce CPU-specific options like # of cores. @@ -138,7 +139,7 @@ def make_llir(src, metadata, options): def make_bc(src, metadata, options): if os.environ.get("TRITON_CPU_ASM_DUMP", "0") == "1": print("********** Module ASM **********") - print(llvm.translate_to_host_asm(src)) + print(llvm.translate_to_host_asm(src, options.enable_fp_fusion)) ret = llvm.translate_to_bc(src) return ret From cad9174c33665a9c3bf60a9bffaf120bcf779a1f Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 20 May 2024 19:09:58 -0700 Subject: [PATCH 07/13] Fix kernel args passing for propagated constants. Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 1 + third_party/cpu/backend/driver.py | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 89ab0e656d63..463680504457 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2896,6 +2896,7 @@ def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) # TODO: bfloat16 diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 743684d2640f..3fe243fc262d 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -149,7 +149,6 @@ def make_launcher(constants, signature, ids): # Record the end of regular arguments; # subsequent arguments are architecture-specific descriptors. 
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) - arg_types = (', '.join(f"{ty_to_cpp(ty)}" for i, ty in signature.items()) + ", " if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t" def _extracted_type(ty): if ty[0] == '*': @@ -174,8 +173,10 @@ def format_of(ty): args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) format = "iiiOKOOOO" + args_format - args_list = ', '.join(f"arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' arg_ptrs_list = ', '.join(f"&arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + kernel_fn_args = [i for i in signature.keys() if i not in constants] + kernel_fn_args_list = ', '.join(f"arg{i}" for i in kernel_fn_args) if len(kernel_fn_args) > 0 else '' + kernel_fn_arg_types = (', '.join(f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args) + ", " if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t" # generate glue code src = f""" @@ -188,7 +189,7 @@ def format_of(ty): #include #include -using kernel_ptr_t = void(*)({arg_types}); +using kernel_ptr_t = void(*)({kernel_fn_arg_types}); typedef struct _DevicePtrInfo {{ void* dev_ptr; @@ -235,7 +236,7 @@ def format_of(ty): for (uint32_t z = 0; z < gridZ; ++z) {{ for (uint32_t y = 0; y < gridY; ++y) {{ for (uint32_t x = 0; x < gridX; ++x) {{ - (*kernel_ptr)({args_list + ', ' if len(arg_decls) > 0 else ''} x, y, z); + (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z); }} }} }} From 224943e61f982a9d5b7f5506b15a87605ad45227 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 20 May 2024 19:40:44 -0700 Subject: [PATCH 08/13] Add permutations support. Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 2 ++ .../TritonToTritonCPU/ConvertElementwiseOps.cpp | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 463680504457..21d4d22a7e69 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1114,6 +1114,7 @@ def kernel(): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) def test_transpose(dtype_x, device): @@ -2953,6 +2954,7 @@ def kernel(X, stride_xm, stride_xn, Z, stride_zm, stride_zn, BLOCK_M: tl.constex assert 'st.global.v4' in ptx +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", ["int32", "int8"]) @pytest.mark.parametrize("shape", [(2, 4), (16, 16)]) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index 1dd08d79fa6f..e5cd372fe3e4 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -57,6 +57,7 @@ class ElementwiseOpConversionTarget : public ConversionTarget { addIllegalOp(); addIllegalOp(); addIllegalOp(); + addIllegalOp(); } }; @@ -200,6 +201,21 @@ struct ClampFOpConversion : public OpConversionPattern { } }; +struct TransOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto val = rewriter.getRemappedValue(op.getSrc()); + auto order = op.getOrder(); + SmallVector 
permutation(order.begin(), order.end()); + rewriter.replaceOpWithNewOp(op, val, permutation); + return success(); + } +}; + struct ConvertElementwiseOps : public triton::impl::ConvertElementwiseOpsBase { using ConvertElementwiseOpsBase::ConvertElementwiseOpsBase; @@ -282,6 +298,7 @@ struct ConvertElementwiseOps patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); From 917bc8211e5cd16aa786b5d9c5e9bafdc1cc4bee Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Tue, 21 May 2024 13:43:53 -0700 Subject: [PATCH 09/13] Support 2-D transfer_read/transfer_write lowering. Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 7 +++++-- third_party/cpu/CMakeLists.txt | 2 +- third_party/cpu/backend/compiler.py | 2 ++ third_party/cpu/triton_cpu.cc | 11 +++++++++++ 4 files changed, 19 insertions(+), 3 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 21d4d22a7e69..839e57331744 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2978,6 +2978,7 @@ def kernel(In, Out, in_shape1: tl.constexpr, in_shape2: tl.constexpr, ou_shape1: np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", ["int32", "int8"]) @pytest.mark.parametrize("shape", [(2, 2, 8, 64), (4, 4, 4, 4)]) @@ -4094,6 +4095,7 @@ def kernel(): assert "reshape" in str(exc_info.value) +@pytest.mark.cpu def test_trans_reshape(device): @triton.jit @@ -4120,8 +4122,9 @@ def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.con actual = torch.zeros(expected.shape, dtype=torch.int32, device=device) k = kernel[(1, )](input, actual, shape[0], shape[1]) - assert k.asm['ttgir'].count( - 'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" + if not is_cpu(): + assert k.asm['ttgir'].count( + 'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index d8be71ad6c11..1b08addbc9b7 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -4,5 +4,5 @@ add_subdirectory(include) add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM) - target_link_libraries(TritonCPU PUBLIC MLIRMathToLibm) + target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm) endif() diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 73dbbe779a38..344cdd2f05ae 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -98,6 +98,8 @@ def make_llir(src, metadata, options): # TritonCPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() + cpu.passes.ttcpuir.add_vector_to_scf(pm, True, 1, False) + cpu.passes.ttcpuir.add_lower_affine(pm) passes.convert.add_scf_to_cf(pm) passes.convert.add_index_to_llvmir(pm) cpu.passes.ttcpuir.add_triton_cpu_to_llvmir_pipeline(pm) diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index efc949d6f4a1..8065098becbe 100644 
--- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -34,9 +34,20 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_triton_cpu_to_llvmir_pipeline", [](mlir::PassManager &pm) { mlir::triton::cpu::tritonCPUToLLVMPipelineBuilder(pm); }); + m.def("add_vector_to_scf", [](mlir::PassManager &pm, bool full_unroll, + unsigned target_rank, bool lower_tensors) { + mlir::VectorTransferToSCFOptions opts; + opts.setTargetRank(target_rank); + opts.enableFullUnroll(full_unroll); + opts.enableLowerTensors(lower_tensors); + pm.addPass(mlir::createConvertVectorToSCFPass(opts)); + }); m.def("add_vector_to_llvmir", [](mlir::PassManager &pm) { pm.addPass(mlir::createConvertVectorToLLVMPass()); }); + m.def("add_lower_affine", [](mlir::PassManager &pm) { + pm.addPass(mlir::createLowerAffinePass()); + }); m.def("add_memref_to_llvmir", [](mlir::PassManager &pm) { pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); }); From 36006b37c10a89cba47d15967d4eb33a612ab2fb Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 22 May 2024 13:30:26 -0700 Subject: [PATCH 10/13] Introduce shape info analysis and use it for loads/stores by block pointers. Delay scalar pointers lowering. Signed-off-by: Ilya Enkovich --- lib/Dialect/TritonCPU/IR/Dialect.cpp | 1 + .../cpu/include/Analysis/TensorPtrShapeInfo.h | 109 +++++++++ third_party/cpu/lib/Analysis/CMakeLists.txt | 11 + .../cpu/lib/Analysis/TensorPtrShapeInfo.cpp | 217 ++++++++++++++++++ third_party/cpu/lib/CMakeLists.txt | 1 + .../lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp | 56 +++++ .../ConvertElementwiseOps.cpp | 3 +- .../TritonToTritonCPU/ConvertMemoryOps.cpp | 37 ++- .../lib/TritonToTritonCPU/ConvertPtrOps.cpp | 24 +- .../lib/TritonToTritonCPU/TypeConverter.cpp | 25 +- 10 files changed, 446 insertions(+), 38 deletions(-) create mode 100644 third_party/cpu/include/Analysis/TensorPtrShapeInfo.h create mode 100644 third_party/cpu/lib/Analysis/CMakeLists.txt create mode 100644 third_party/cpu/lib/Analysis/TensorPtrShapeInfo.cpp diff --git a/lib/Dialect/TritonCPU/IR/Dialect.cpp b/lib/Dialect/TritonCPU/IR/Dialect.cpp index e5eb53caf686..acd31c07290f 100644 --- a/lib/Dialect/TritonCPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonCPU/IR/Dialect.cpp @@ -5,6 +5,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Transforms/DialectConversion.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" diff --git a/third_party/cpu/include/Analysis/TensorPtrShapeInfo.h b/third_party/cpu/include/Analysis/TensorPtrShapeInfo.h new file mode 100644 index 000000000000..3023aa5bc0ee --- /dev/null +++ b/third_party/cpu/include/Analysis/TensorPtrShapeInfo.h @@ -0,0 +1,109 @@ +#ifndef TRITON_CPU_ANALYSIS_TENSORPTRSHAPEINFO_H +#define TRITON_CPU_ANALYSIS_TENSORPTRSHAPEINFO_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include +#include + +namespace mlir::triton::cpu { + +// Lattice value to hold a shape and strides for a tensor pointer. +// If multiple size or stride values are possible for some dimension +// then ShapedType::kDynamic is used for that dimension. 
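+// For example, joining {shape = [128, 64], strides = [64, 1]} with
+// {shape = [128, 32], strides = [32, 1]} yields
+// {shape = [128, kDynamic], strides = [kDynamic, 1]} (illustrative values).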
+class TensorPtrShapeInfo { +public: + TensorPtrShapeInfo() = default; + + TensorPtrShapeInfo(ArrayRef shape, ArrayRef strides) + : shape(shape), strides(strides) { + assert(shape.size() == strides.size()); + } + + ArrayRef getShape() const { return shape; } + ArrayRef getStrides() const { return strides; } + + int64_t getRank() const { return static_cast(shape.size()); } + int64_t getSize(int64_t dim) const { return shape[dim]; } + int64_t getStride(int64_t dim) const { return strides[dim]; } + + bool operator==(const TensorPtrShapeInfo &other) const { + return shape == other.shape && strides == other.strides; + } + + static TensorPtrShapeInfo join(const TensorPtrShapeInfo &lhs, + const TensorPtrShapeInfo &rhs); + + static TensorPtrShapeInfo getPessimisticValueState(Value value); + + void print(raw_ostream &os) const { + os << "shape = ["; + llvm::interleaveComma(shape, os); + os << "], strides = ["; + llvm::interleaveComma(strides, os); + os << "]"; + } + +private: + SmallVector shape; + SmallVector strides; +}; + +using TensorPtrShapeInfoMapT = DenseMap; +class ModuleTensorPtrShapeInfoAnalysis + : public CallGraph { +public: + explicit ModuleTensorPtrShapeInfoAnalysis(ModuleOp moduleOp) + : CallGraph(moduleOp) { + SmallVector funcs; + for (auto root : getRoots()) { + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + funcs.push_back(funcOp); + funcMap.try_emplace(funcOp, TensorPtrShapeInfoMapT{}); + }); + } + SetVector sortedFuncs(funcs.begin(), funcs.end()); + SymbolTableCollection symbolTable; + for (auto funcOp : llvm::reverse(sortedFuncs)) { + initialize(funcOp); + funcOp.walk([&](CallOpInterface callOp) { + auto callee = + dyn_cast(callOp.resolveCallable(&symbolTable)); + update(callOp, callee); + }); + } + } + + TensorPtrShapeInfo *getPtrShapeInfo(Value value) { + auto funcOp = + value.getParentRegion()->getParentOfType(); + auto *axisInfoMap = getFuncData(funcOp); + if (!axisInfoMap) { + return nullptr; + } + auto it = axisInfoMap->find(value); + if (it == axisInfoMap->end()) { + return nullptr; + } + return &(it->second); + } + +private: + void initialize(FunctionOpInterface funcOp); + void update(CallOpInterface callOp, FunctionOpInterface funcOp); +}; + +} // namespace mlir::triton::cpu + +#endif // TRITON_CPU_ANALYSIS_TENSORPTRSHAPEINFO_H diff --git a/third_party/cpu/lib/Analysis/CMakeLists.txt b/third_party/cpu/lib/Analysis/CMakeLists.txt new file mode 100644 index 000000000000..d0ac08b9daf0 --- /dev/null +++ b/third_party/cpu/lib/Analysis/CMakeLists.txt @@ -0,0 +1,11 @@ +add_triton_library(TritonCPUAnalysis + TensorPtrShapeInfo.cpp + + DEPENDS + TritonCPUTableGen + + LINK_LIBS PUBLIC + MLIRAnalysis + TritonIR + TritonCPUIR +) diff --git a/third_party/cpu/lib/Analysis/TensorPtrShapeInfo.cpp b/third_party/cpu/lib/Analysis/TensorPtrShapeInfo.cpp new file mode 100644 index 000000000000..cba2e3acd4cc --- /dev/null +++ b/third_party/cpu/lib/Analysis/TensorPtrShapeInfo.cpp @@ -0,0 +1,217 @@ +#include "cpu/include/Analysis/TensorPtrShapeInfo.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" + +namespace mlir::triton::cpu { + +TensorPtrShapeInfo TensorPtrShapeInfo::join(const TensorPtrShapeInfo &lhs, + const TensorPtrShapeInfo &rhs) { + // If one argument is not initialized, return the other. 
+ if (lhs.getRank() == 0) + return rhs; + if (rhs.getRank() == 0) + return lhs; + assert(lhs.getRank() == rhs.getRank()); + + SmallVector shape(lhs.getShape()); + SmallVector strides(lhs.getStrides()); + for (int64_t i = 0; i < lhs.getRank(); ++i) { + if (shape[i] != rhs.getSize(i)) + shape[i] = ShapedType::kDynamic; + if (strides[i] != rhs.getStride(i)) + strides[i] = ShapedType::kDynamic; + } + return TensorPtrShapeInfo(shape, strides); +} + +namespace { + +template +void initPessimisticStateFromFunc(int argNumber, T funcOp, + SmallVectorImpl &shape, + SmallVectorImpl &strides) { + auto loadFromAttr = [&](std::string_view attrName, + SmallVectorImpl &out) { + Attribute attr = funcOp.getArgAttr(argNumber, attrName); + if (auto dense_attr = dyn_cast_or_null(attr)) { + auto vals = dense_attr.getValues(); + out = SmallVector(vals.begin(), vals.end()); + } + }; + loadFromAttr("tt.shape", shape); + loadFromAttr("tt.strides", strides); +} + +TensorPtrShapeInfo getPessimisticValueState(Value value) { + int rank = 0; + if (triton::isTensorPointerType(value.getType())) + rank = cast(getPointeeType(value.getType())).getRank(); + + SmallVector shape; + SmallVector strides; + + BlockArgument blockArg = dyn_cast(value); + + if (blockArg && blockArg.getOwner()->isEntryBlock()) { + Operation *op = blockArg.getOwner()->getParentOp(); + if (auto fun = dyn_cast(op)) + initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, shape, + strides); + // llvm codegen check alignment to generate vector load/store + // would be nice if this wasn't the case + else if (auto fun = dyn_cast(op)) + initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, shape, + strides); + } else if (Operation *op = value.getDefiningOp()) { + if (isa(op)) { + // scf::ForOp, scf::IfOp, scf::WhileOp + // Control flow operations are initialized with "unknown" state. + } else { + // Other operations are conservatively initialized with dynamic + // shape and strides unless they have specified. 
+ if (Attribute attr = op->getDiscardableAttr("tt.shape")) { + auto vals = cast(attr).getValues(); + shape = SmallVector(vals.begin(), vals.end()); + } else { + shape.insert(shape.end(), rank, ShapedType::kDynamic); + } + if (Attribute attr = op->getDiscardableAttr("tt.strides")) { + auto vals = cast(attr).getValues(); + strides = SmallVector(vals.begin(), vals.end()); + } else { + strides.insert(strides.end(), rank, ShapedType::kDynamic); + } + } + } + + return TensorPtrShapeInfo(shape, strides); +} + +class ShapeInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { +private: + void + setToEntryState(dataflow::Lattice *lattice) override { + propagateIfChanged( + lattice, lattice->join(getPessimisticValueState(lattice->getPoint()))); + } + +public: + ShapeInfoAnalysis(DataFlowSolver &solver); + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::getLatticeElement; + using FuncShapeInfoMapT = DenseMap; + + void visitOperation( + Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; +}; + +ShapeInfoAnalysis::ShapeInfoAnalysis(DataFlowSolver &solver) + : dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>(solver) {} + +SmallVector copyConstOrDynamic(OperandRange ops) { + SmallVector res; + for (auto op : ops) { + if (auto cstOp = op.getDefiningOp()) { + auto intAttr = dyn_cast(cstOp.getValue()); + assert(intAttr); + res.push_back(intAttr.getInt()); + } else { + res.push_back(ShapedType::kDynamic); + } + } + return res; +} + +void ShapeInfoAnalysis::visitOperation( + Operation *op, + ArrayRef *> operands, + ArrayRef *> results) { + // TODO: For sure not the right way to do this + // but why is scf.if not initialized otherwise? + for (auto op : operands) + if (op->getValue().getRank() == 0) + setToEntryState((dataflow::Lattice *)op); + + TensorPtrShapeInfo res; + // Tensor pointers are only produced by MakeTensorPtrOp which has + // shape and strides as its args, and AdvanceOp which preserves + // shape and strides of the input pointer. 
+ if (auto makePtrOp = dyn_cast(op)) { + SmallVector shape = copyConstOrDynamic(makePtrOp.getShape()); + SmallVector strides = copyConstOrDynamic(makePtrOp.getStrides()); + res = TensorPtrShapeInfo(shape, strides); + } else if (auto advOp = dyn_cast(op)) { + res = operands[0]->getValue(); + } + + // join all lattice elements + for (auto *result : results) + propagateIfChanged(result, result->join(res)); +} + +} // namespace + +void ModuleTensorPtrShapeInfoAnalysis::initialize(FunctionOpInterface funcOp) { + std::unique_ptr solver = createDataFlowSolver(); + ShapeInfoAnalysis *analysis = solver->load(); + if (failed(solver->initializeAndRun(funcOp))) + return; + auto *shapeInfoMap = getFuncData(funcOp); + auto updateShapeInfoMap = [&](Value value) { + auto shapeInfo = analysis->getLatticeElement(value)->getValue(); + TensorPtrShapeInfo curShapeInfo; + if (shapeInfoMap->count(value)) { + curShapeInfo = + TensorPtrShapeInfo::join(shapeInfo, shapeInfoMap->lookup(value)); + } else { + curShapeInfo = shapeInfo; + } + (*shapeInfoMap)[value] = curShapeInfo; + }; + funcOp.walk([&](Operation *op) { + for (auto value : op->getResults()) { + updateShapeInfoMap(value); + } + }); + funcOp.walk([&](Block *block) { + for (auto value : block->getArguments()) { + updateShapeInfoMap(value); + } + }); +} + +void ModuleTensorPtrShapeInfoAnalysis::update(CallOpInterface callOp, + FunctionOpInterface callee) { + auto caller = callOp->getParentOfType(); + auto *shapeInfoMap = getFuncData(caller); + for (auto entry : llvm::enumerate(callOp->getOperands())) { + auto index = entry.index(); + auto value = entry.value(); + auto setAttrFn = [&](StringRef attrName, ArrayRef value) { + SmallVector curValue(value); + if (auto attr = + callee.getArgAttrOfType(index, attrName)) { + auto oldValue = cast(attr).getValues(); + assert(oldValue.size() == curValue.size()); + for (size_t i = 0; i < curValue.size(); ++i) + if (curValue[i] != oldValue[i]) + curValue[i] = ShapedType::kDynamic; + } + auto attr = DenseElementsAttr::get( + VectorType::get(curValue.size(), + IntegerType::get(callee.getContext(), 64)), + ArrayRef(curValue)); + callee.setArgAttr(index, attrName, attr); + }; + auto shapeInfo = shapeInfoMap->lookup(value); + if (shapeInfo.getRank()) { + setAttrFn("tt.shape", shapeInfo.getShape()); + setAttrFn("tt.strides", shapeInfo.getStrides()); + } + } +} + +} // namespace mlir::triton::cpu diff --git a/third_party/cpu/lib/CMakeLists.txt b/third_party/cpu/lib/CMakeLists.txt index fc9a19e52b0d..1db64c58ec20 100644 --- a/third_party/cpu/lib/CMakeLists.txt +++ b/third_party/cpu/lib/CMakeLists.txt @@ -1,2 +1,3 @@ +add_subdirectory(Analysis) add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonToTritonCPU) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp index 7bd602dc81a7..68d7231039c5 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp @@ -252,6 +252,59 @@ struct IntToPtrOpConversion : public OpConversionPattern { } }; +struct AddPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Expect only scalar pointers here. 
+ assert(isa(op.getType())); + auto ptrTy = cast(op.getPtr().getType()); + Type elemTy = getTypeConverter()->convertType(ptrTy.getPointeeType()); + Type resTy = getTypeConverter()->convertType(ptrTy); + Value ptr = rewriter.getRemappedValue(op.getPtr()); + Value offset = rewriter.getRemappedValue(op.getOffset()); + rewriter.replaceOpWithNewOp(op, resTy, elemTy, ptr, offset); + return success(); + } +}; + +struct PtrBitcastConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // By this moment we expect tt.bitcast used only for scalar pointer casts. + // This cast becomes NOP for LLVM dialect, so simply return the source arg. + assert(isa(op.getType())); + assert(isa(op.getSrc().getType())); + Value src = rewriter.getRemappedValue(op.getSrc()); + rewriter.replaceOp(op, src); + return success(); + } +}; + +struct PtrSelectConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Handle selects on scalar pointers only; other selects are not converted + // here. Rebuild the select with the remapped (already converted) operands. + if (!isa(op.getType())) + return failure(); + + Value trueVal = rewriter.getRemappedValue(op.getTrueValue()); + Value falseVal = rewriter.getRemappedValue(op.getFalseValue()); + Value cond = rewriter.getRemappedValue(op.getCondition()); + rewriter.replaceOpWithNewOp(op, cond, trueVal, falseVal); + return success(); + } +}; + struct MemoryOpToLLVM : public triton::impl::MemoryOpToLLVMBase { using MemoryOpToLLVMBase::MemoryOpToLLVMBase; @@ -276,6 +329,9 @@ struct MemoryOpToLLVM patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index e5cd372fe3e4..3b87255a2300 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -49,7 +49,8 @@ class ElementwiseOpConversionTarget : public ConversionTarget { return converter.isLegal(op); }); - addIllegalOp(); + addDynamicallyLegalOp( + [](triton::BitcastOp op) { return isa(op.getType()); }); addIllegalOp(); addIllegalOp(); addIllegalOp(); diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp index 394289063f14..9a7797bb13ee 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -1,5 +1,6 @@ #include "TypeConverter.h" +#include "cpu/include/Analysis/TensorPtrShapeInfo.h" #include "cpu/include/TritonToTritonCPU/Passes.h" #include "mlir/Analysis/DataFlowFramework.h" @@ -35,11 +36,13 @@ namespace { template struct MemoryOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::getContext; MemoryOpConversion(ModuleAxisInfoAnalysis &axisInfoAnalysis, +
ModuleTensorPtrShapeInfoAnalysis &shapeInfoAnalysis, TypeConverter &typeConverter, MLIRContext *context) : OpConversionPattern(typeConverter, context), - axisAnalysis(axisInfoAnalysis) {} + axisAnalysis(axisInfoAnalysis), shapeAnalysis(shapeInfoAnalysis) {} Value extractScalarPointer(Location loc, Value ptrs, ArrayRef indices, @@ -52,8 +55,28 @@ struct MemoryOpConversion : public OpConversionPattern { return ptr; } + Value extractMemRef(Location loc, Value ptr, + ConversionPatternRewriter &rewriter) const { + auto tensorTy = dyn_cast( + dyn_cast(ptr.getType()).getPointeeType()); + auto elemTy = tensorTy.getElementType(); + auto shapeInfo = shapeAnalysis.getPtrShapeInfo(ptr); + Type memRefTy; + if (shapeInfo && shapeInfo->getRank() > 0) { + auto layout = + StridedLayoutAttr::get(getContext(), 0, shapeInfo->getStrides()); + memRefTy = MemRefType::get(shapeInfo->getShape(), elemTy, layout); + } else { + SmallVector dynVals(tensorTy.getRank(), ShapedType::kDynamic); + auto layout = StridedLayoutAttr::get(getContext(), 0, dynVals); + memRefTy = MemRefType::get(dynVals, elemTy, layout); + } + return rewriter.create(loc, memRefTy, ptr); + } + protected: ModuleAxisInfoAnalysis &axisAnalysis; + ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis; }; struct LoadOpConversion : public MemoryOpConversion { @@ -80,7 +103,7 @@ struct LoadOpConversion : public MemoryOpConversion { llvm_unreachable("unsupported load op"); } - auto memRef = rewriter.getRemappedValue(ptr); + auto memRef = extractMemRef(loc, ptr, rewriter); auto rank = dyn_cast(memRef.getType()).getRank(); auto resTy = dyn_cast( getTypeConverter()->convertType(loadOp.getResult().getType())); @@ -250,7 +273,7 @@ struct StoreOpConversion : public MemoryOpConversion { } auto value = rewriter.getRemappedValue(storeOp.getValue()); - auto memRef = rewriter.getRemappedValue(ptr); + auto memRef = extractMemRef(loc, ptr, rewriter); auto rank = dyn_cast(memRef.getType()).getRank(); auto indices = rewriter.create(loc, ptr).getResults(); SmallVector inBounds(rank, true); @@ -400,12 +423,14 @@ struct ConvertMemoryOps ModuleOp mod = getOperation(); ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + ModuleTensorPtrShapeInfoAnalysis shapeInfoAnalysis(mod); MemoryOpConversionTarget convTarget(*context); TritonToTritonCPUTypeConverter pointerConverter; RewritePatternSet patterns(context); - patterns.add(axisInfoAnalysis, pointerConverter, context); - patterns.add(axisInfoAnalysis, pointerConverter, - context); + patterns.add(axisInfoAnalysis, shapeInfoAnalysis, + pointerConverter, context); + patterns.add(axisInfoAnalysis, shapeInfoAnalysis, + pointerConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp index ade8b858bbfb..82123c376dc1 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp @@ -47,12 +47,14 @@ class PtrConversionTarget : public ConversionTarget { addLegalDialect(); addLegalOp(); - // Allow only scalar pointer conversion. + // Scalar pointer operations are translated directly to LLVM. 
addDynamicallyLegalOp( [](triton::PtrToIntOp op) { return op.getType().isInteger(); }); addDynamicallyLegalOp([](triton::IntToPtrOp op) { return op.getSrc().getType().isInteger(); }); + addDynamicallyLegalOp( + [](triton::AddPtrOp op) { return isa(op.getType()); }); } }; @@ -89,12 +91,9 @@ struct SplatOpConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); Value val = op.getSrc(); - Type dstValType = getTypeConverter()->convertType(val.getType()); // Cast pointer if (isa(val.getType())) - val = rewriter - .create( - loc, getTypeConverter()->convertType(val.getType()), val) + val = rewriter.create(loc, rewriter.getI64Type(), val) .getResult(); Type resType = getTypeConverter()->convertType(op.getType()); auto cast = rewriter.create(loc, resType, val); @@ -115,11 +114,16 @@ struct AddPtrOpConversion : public OpConversionPattern { Value offset = rewriter.getRemappedValue(op.getOffset()); unsigned offsetBitWidth = getElemBitWidth(offset.getType()); unsigned elemBitWidth = getPointeeBitWidth(op.getPtr().getType()); - // Compute scale. i1 elements take 1 byte. - Value scale = rewriter.create( - loc, (elemBitWidth + 7) / 8, offsetBitWidth); - if (isa(offset.getType())) - scale = rewriter.create(loc, offset.getType(), scale); + // Scalar case is not expected. + assert(isa(offset.getType())); + assert(isa(ptr.getType())); + VectorType offsetTy = cast(offset.getType()); + // Build scale vector. i1 elements take 1 byte. + Value scale = rewriter.create( + loc, offsetTy, + SplatElementsAttr::get( + offsetTy, rewriter.getIntegerAttr(offsetTy.getElementType(), + (elemBitWidth + 7) / 8))); offset = rewriter.create(loc, offset, scale); offset = rewriter.create(loc, ptr.getType(), offset); rewriter.replaceOpWithNewOp(op, ptr.getType(), ptr, offset); diff --git a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp index 07b2da0468ba..ce66f8faeb3e 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp @@ -8,44 +8,27 @@ using namespace mlir::triton::cpu; TritonToTritonCPUTypeConverter::TritonToTritonCPUTypeConverter() { addConversion([](Type type) { return type; }); - addConversion([](triton::PointerType ptrTy) -> Type { - if (triton::isTensorPointerType(ptrTy)) { - // Tensor pointer is translated into a memref - auto tensorTy = dyn_cast(ptrTy.getPointeeType()); - auto elemTy = tensorTy.getElementType(); - // TODO: use dynamic strides - SmallVector shape(tensorTy.getRank(), ShapedType::kDynamic); - return MemRefType::get(shape, elemTy); - } - return IntegerType::get(ptrTy.getContext(), 64); - }); addConversion([this](RankedTensorType tensorTy) -> Type { Type elemTy = convertType(tensorTy.getElementType()); + if (isa(elemTy)) + elemTy = IntegerType::get(tensorTy.getContext(), 64); return VectorType::get(tensorTy.getShape(), elemTy); }); // Converted ops produce vectors instead of tensors. Provide conversion - // here for users. Also, convert pointers when required. + // here for users. addSourceMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> std::optional { - if (isa(type)) - return builder.create(loc, type, inputs); return builder.create(loc, type, inputs) .getResult(0); }); - // Converted loads and stores consume memrefs instead of pointers, use extract - // op to get them. Also, provide conversion for vector users and pointer - // casts. 
+ // Provide conversion for vector users. addTargetMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> std::optional { - if (type.isInteger() && isa(inputs.front().getType())) - return builder.create(loc, type, inputs); if (isa(type)) return builder.create(loc, type, inputs) .getResult(0); - if (isa(type)) - return builder.create(loc, type, inputs); llvm_unreachable("Unexpected target materizalization"); }); } From 27a451cdec2a0f3d0fe88eb04c2c264e67bf452f Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 22 May 2024 15:18:41 -0700 Subject: [PATCH 11/13] Support 'other' arg for loads. Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 2 + .../TritonToTritonCPU/ConvertMemoryOps.cpp | 40 ++++++++++--------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 839e57331744..8090ee40cf55 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3632,6 +3632,7 @@ def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str, size, size_diff, other", [(dtype_str, size, size_diff, other) for dtype_str in torch_dtypes @@ -3670,6 +3671,7 @@ def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): torch.testing.assert_close(output, reference_out) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("num_ctas", num_ctas_list) @pytest.mark.parametrize("mask_val", [True, False]) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp index 9a7797bb13ee..2787247a731c 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -37,6 +37,7 @@ template struct MemoryOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; using OpConversionPattern::getContext; + using OpConversionPattern::getTypeConverter; MemoryOpConversion(ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleTensorPtrShapeInfoAnalysis &shapeInfoAnalysis, @@ -74,6 +75,19 @@ struct MemoryOpConversion : public OpConversionPattern { return rewriter.create(loc, memRefTy, ptr); } + Value convertOtherVal(triton::LoadOp loadOp, + ConversionPatternRewriter &rewriter) const { + if (loadOp.getOther()) + return rewriter.getRemappedValue(loadOp.getOther()); + + auto resTy = + dyn_cast(getTypeConverter()->convertType(loadOp.getType())); + return rewriter.create( + loadOp.getLoc(), resTy, + SplatElementsAttr::get(resTy, + rewriter.getZeroAttr(resTy.getElementType()))); + } + protected: ModuleAxisInfoAnalysis &axisAnalysis; ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis; @@ -139,13 +153,12 @@ struct LoadOpConversion : public MemoryOpConversion { ? 
rewriter.getRemappedValue(loadOp.getMask()) : nullptr; Value zeroIdx = rewriter.create(loc, 0); - Value defaultVal = loadOp.getOther(); - if (!defaultVal) - defaultVal = rewriter.create( - loc, rewriter.getZeroAttr(vecTy.getElementType())); - Value res = rewriter.create(loc, vecTy, defaultVal); + Value defaultVal = convertOtherVal(loadOp, rewriter); + Value res = defaultVal; for (int64_t idx = 0; idx < numElems; idx += shape.back()) { auto indices = delinearize(idx, strides); + SmallVector subIndices(indices.begin(), + indices.begin() + indices.size() - 1); auto ptr = extractScalarPointer(loc, loadOp.getPtr(), indices, rewriter); Value memRef = @@ -153,13 +166,12 @@ struct LoadOpConversion : public MemoryOpConversion { Value vec; if (mask) { Value subMask = mask; + Value passThru = defaultVal; if (shape.size() > 1) { - SmallVector subIndices = indices; - subIndices.pop_back(); subMask = rewriter.create(loc, mask, subIndices); + passThru = + rewriter.create(loc, defaultVal, subIndices); } - Value passThru = - rewriter.create(loc, subVecTy, defaultVal); vec = rewriter.create( loc, subVecTy, memRef, zeroIdx, subMask, passThru); } else { @@ -167,8 +179,6 @@ struct LoadOpConversion : public MemoryOpConversion { } if (shape.size() > 1) { - SmallVector subIndices = indices; - subIndices.pop_back(); res = rewriter.create(loc, vec, res, subIndices); } else { res = vec; @@ -199,13 +209,7 @@ struct LoadOpConversion : public MemoryOpConversion { auto cache = loadOp.getCache(); auto evict = loadOp.getEvict(); auto isVolatile = loadOp.getIsVolatile(); - - Value defaultVal = loadOp.getOther(); - if (!defaultVal) - defaultVal = rewriter.create( - loc, rewriter.getZeroAttr(vecTy.getElementType())); - Value dst = rewriter.create(loc, vecTy, defaultVal); - + Value dst = convertOtherVal(loadOp, rewriter); int64_t numElems = vecTy.getNumElements(); auto strides = computeStrides(vecTy.getShape()); for (auto idx = 0; idx < numElems; ++idx) { From ddcbcde4b479a8d3774c152bdeae33d1ad827f50 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 22 May 2024 15:55:36 -0700 Subject: [PATCH 12/13] Support tl.join. 
Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 4 ++++ .../ConvertElementwiseOps.cpp | 20 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 8090ee40cf55..633d8c4c49c6 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1775,6 +1775,7 @@ def umulhi32(a, b): np.testing.assert_equal(z_ref, to_numpy(z_tri)) +@pytest.mark.cpu @pytest.mark.interpreter def test_join(device): @@ -1795,6 +1796,7 @@ def kernel(X, Y, Z, N: tl.constexpr): np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) +@pytest.mark.cpu @pytest.mark.interpreter def test_join_scalars(device): @@ -1814,6 +1816,7 @@ def kernel(X, Y, Z): np.testing.assert_equal([42, 100], to_numpy(z)) +@pytest.mark.cpu @pytest.mark.interpreter def test_join_with_mma(device): @@ -1852,6 +1855,7 @@ def kernel(Z, N: tl.constexpr): np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) +@pytest.mark.cpu @pytest.mark.interpreter def test_interleave_scalars(device): diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index 3b87255a2300..218dd827619a 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -59,6 +59,7 @@ class ElementwiseOpConversionTarget : public ConversionTarget { addIllegalOp(); addIllegalOp(); addIllegalOp(); + addIllegalOp(); } }; @@ -217,6 +218,24 @@ struct TransOpConversion : public OpConversionPattern { } }; +struct JoinOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto lhs = rewriter.getRemappedValue(op.getLhs()); + auto rhs = rewriter.getRemappedValue(op.getRhs()); + auto interleave = rewriter.create(loc, lhs, rhs); + // JoinOp creates a new dimension, but InterleaveOp doubles the final one. + // Use ShapeCastOp to get the required shape. + auto resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, interleave); + return success(); + } +}; + struct ConvertElementwiseOps : public triton::impl::ConvertElementwiseOpsBase { using ConvertElementwiseOpsBase::ConvertElementwiseOpsBase; @@ -300,6 +319,7 @@ struct ConvertElementwiseOps patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); From 10b92e40c76f40b0809ecb6c3860f6db41a4fd2c Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Thu, 23 May 2024 08:33:20 -0700 Subject: [PATCH 13/13] Minor renaming. 
Signed-off-by: Ilya Enkovich --- third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp index b6fbb1893202..51a5f42fa63a 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp @@ -30,9 +30,9 @@ using namespace mlir::triton::cpu; namespace { -class PtrConversionTarget : public ConversionTarget { +class DotConversionTarget : public ConversionTarget { public: - explicit PtrConversionTarget(MLIRContext &ctx, TypeConverter &converter) + explicit DotConversionTarget(MLIRContext &ctx, TypeConverter &converter) : ConversionTarget(ctx) { addLegalDialect(); addLegalDialect(); @@ -78,7 +78,7 @@ struct ConvertDotOp : public triton::impl::ConvertDotOpBase { ModuleOp mod = getOperation(); TritonToTritonCPUTypeConverter typeConverter; - PtrConversionTarget convTarget(*context, typeConverter); + DotConversionTarget convTarget(*context, typeConverter); RewritePatternSet patterns(context); patterns.add(typeConverter, context);
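For readers following the series, below is a minimal, illustrative Triton kernel exercising the masked-load path touched by PATCH 11, where lanes disabled by the mask take the value passed via `other`. This is a sketch, not part of the patches: the kernel name, tensor names, and sizes are made up, and it assumes the CPU backend is active so plain CPU tensors can be passed at launch.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def masked_load_kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
    offs = tl.arange(0, out_size)
    # Lanes with offs >= in_size are masked off; they receive the 'other'
    # value (-1 here) instead of reading out of bounds.
    x = tl.load(in_ptr + offs, mask=offs < in_size, other=-1)
    tl.store(out_ptr + offs, x)


# Example launch: copy 7 valid elements into a 16-element output,
# padding the tail with -1 (assumes CPU tensors are accepted by the backend).
inp = torch.arange(7, dtype=torch.float32)
out = torch.empty(16, dtype=torch.float32)
masked_load_kernel[(1, )](inp, out, in_size=7, out_size=16)
```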
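Likewise, a small sketch of `tl.join` as enabled by PATCH 12 (illustrative only; it mirrors the shape of the `test_join` kernel in the test diff above, with a made-up output layout). `tl.join` stacks two tensors along a new trailing dimension, which `JoinOpConversion` lowers to `vector.interleave` followed by a shape cast.

```python
import triton
import triton.language as tl


@triton.jit
def join_kernel(X, Y, Z, N: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(X + offs)
    y = tl.load(Y + offs)
    # (N,) join (N,) -> (N, 2); on CPU this becomes an interleave of the two
    # input vectors plus a shape cast that restores the 2-D result shape.
    z = tl.join(x, y)
    # Store the (N, 2) result into a flat buffer of 2*N elements,
    # interleaved as x0, y0, x1, y1, ...
    out_offs = offs[:, None] * 2 + tl.arange(0, 2)[None, :]
    tl.store(Z + out_offs, z)
```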