diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 2b2f6afeb5ce..9da8e5628667 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -79,3 +79,20 @@ llvm_update_compile_flags(triton-translate) MLIRROCDLToLLVMIRTranslation ) mlir_check_all_link_libraries(triton-translate) + +add_llvm_executable(triton-llvm-opt + triton-llvm-opt.cpp + + DEPENDS + intrinsics_gen + SUPPORT_PLUGINS + ) +target_link_libraries(triton-llvm-opt PRIVATE + TritonLLVMIR + + LLVMCore + LLVMSupport + LLVMOption + LLVMCodeGen + ) +export_executable_symbols_for_plugins(triton-llvm-opt) diff --git a/bin/triton-llvm-opt.cpp b/bin/triton-llvm-opt.cpp new file mode 100644 index 000000000000..fe82a1dce28e --- /dev/null +++ b/bin/triton-llvm-opt.cpp @@ -0,0 +1,114 @@ +/// Trimmed down clone of llvm opt to be able to test triton custom llvm ir +/// passes. +#include "lib/Target/LLVMIR/LLVMPasses.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/SystemUtils.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/TargetParser/Triple.h" +#include + +using namespace llvm; + +static cl::opt InputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +static cl::opt ClDataLayout("data-layout", + cl::desc("data layout string to use"), + cl::value_desc("layout-string"), + cl::init("")); +static cl::opt + TargetTriple("mtriple", cl::desc("Override target triple for module")); + +static cl::opt + BreakStructPhiNodes("break-struct-phi-nodes", + llvm::cl::desc("run pass to break phi struct"), + cl::init(false)); + +namespace { +static std::function makeOptimizingPipeline() { + return [](Module *m) -> Error { + PipelineTuningOptions tuningOptions; + PassBuilder pb(nullptr, tuningOptions); + + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + llvm::FunctionPassManager fpm; + if (BreakStructPhiNodes) + fpm.addPass(BreakStructPhiNodesPass()); + mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm))); + mpm.run(*m, mam); + return Error::success(); + }; +} +} // namespace + +int main(int argc, char **argv) { + InitLLVM X(argc, argv); + cl::ParseCommandLineOptions( + argc, argv, "llvm .bc -> .bc modular optimizer and analysis printer\n"); + + LLVMContext Context; + SMDiagnostic Err; + + // Load the input module... + auto SetDataLayout = [](StringRef, StringRef) -> std::optional { + if (ClDataLayout.empty()) + return std::nullopt; + return ClDataLayout; + }; + std::unique_ptr M; + M = parseIRFile(InputFilename, Err, Context, ParserCallbacks(SetDataLayout)); + if (!M) { + Err.print(argv[0], errs()); + return 1; + } + // If we are supposed to override the target triple or data layout, do so now. + if (!TargetTriple.empty()) + M->setTargetTriple(Triple::normalize(TargetTriple)); + auto optPipeline = makeOptimizingPipeline(); + if (auto err = optPipeline(M.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + } + + if (verifyModule(*M, &errs())) { + errs() << argv[0] << ": " << InputFilename + << ": error: input module is broken!\n"; + return 1; + } + + // Write to standard output. + std::unique_ptr Out; + std::string OutputFilename = "-"; + std::error_code EC; + sys::fs::OpenFlags Flags = sys::fs::OF_TextWithCRLF; + Out.reset(new ToolOutputFile(OutputFilename, EC, Flags)); + if (EC) { + errs() << EC.message() << '\n'; + return 1; + } + Out->os() << *M << "\n"; + return 0; +} diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt index fbaefe68375c..9c0a6c26eea9 100644 --- a/lib/Target/LLVMIR/CMakeLists.txt +++ b/lib/Target/LLVMIR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_translation_library(TritonLLVMIR LLVMIRTranslation.cpp LLVMDIScope.cpp + LLVMIRBreakPhiStruct.cpp LINK_COMPONENTS Core diff --git a/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp b/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp new file mode 100644 index 000000000000..44afcfd21109 --- /dev/null +++ b/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +/// Implements a trivial pass breaking up 1 level deep structure in phi nodes. +/// This handles the common case generated by Triton and allow better +/// optimizations down the compiler pipeline. +//===----------------------------------------------------------------------===// +#include "LLVMPasses.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +static bool processPhiStruct(PHINode *phiNode) { + StructType *STy = dyn_cast(phiNode->getType()); + if (!STy) + return false; + IRBuilder<> builder(phiNode); + unsigned numOperands = phiNode->getNumIncomingValues(); + unsigned numScalarEl = STy->getNumElements(); + Value *newStruct = UndefValue::get(STy); + builder.SetInsertPoint(phiNode->getParent()->getFirstNonPHI()); + llvm::IRBuilderBase::InsertPoint insertInsertPt = builder.saveIP(); + for (unsigned i = 0; i < numScalarEl; i++) { + builder.SetInsertPoint(phiNode); + PHINode *newPhiNode = + builder.CreatePHI(STy->getElementType(i), numOperands); + for (unsigned j = 0; j < numOperands; ++j) { + Value *operand = phiNode->getIncomingValue(j); + builder.SetInsertPoint(phiNode->getIncomingBlock(j)->getTerminator()); + newPhiNode->addIncoming(builder.CreateExtractValue(operand, i), + phiNode->getIncomingBlock(j)); + } + builder.restoreIP(insertInsertPt); + newStruct = builder.CreateInsertValue(newStruct, newPhiNode, i); + insertInsertPt = builder.saveIP(); + } + phiNode->replaceAllUsesWith(newStruct); + return true; +} + +static bool runOnFunction(Function &F) { + bool Changed = false; + SmallVector PhiNodes; + for (BasicBlock &BB : F) { + for (Instruction &inst : BB) { + if (PHINode *phiNode = dyn_cast(&inst)) { + Changed |= processPhiStruct(phiNode); + continue; + } + break; + } + } + return Changed; +} + +PreservedAnalyses BreakStructPhiNodesPass::run(Function &F, + FunctionAnalysisManager &AM) { + + bool b = runOnFunction(F); + return b ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 6d64fcbb1a29..3acc6a92e09c 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -1,7 +1,8 @@ #include "triton/Target/LLVMIR/LLVMIRTranslation.h" - +#include "LLVMPasses.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/Passes.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" @@ -26,11 +27,20 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/CallingConv.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/Module.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Linker/Linker.h" +#include "llvm/Passes/OptimizationLevel.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" +#include #ifdef _WIN32 #define WIN32_LEAN_AND_MEAN #include @@ -42,6 +52,89 @@ namespace fs = std::filesystem; +namespace { +using namespace llvm; + +static std::optional mapToLevel(unsigned optLevel, + unsigned sizeLevel) { + switch (optLevel) { + case 0: + return OptimizationLevel::O0; + + case 1: + return OptimizationLevel::O1; + + case 2: + switch (sizeLevel) { + case 0: + return OptimizationLevel::O2; + + case 1: + return OptimizationLevel::Os; + + case 2: + return OptimizationLevel::Oz; + } + break; + case 3: + return OptimizationLevel::O3; + } + return std::nullopt; +} + +// Create and return a lambda that uses LLVM pass manager builder to set up +// optimizations based on the given level. +static std::function +makeOptimizingPipeline(unsigned optLevel, unsigned sizeLevel, + TargetMachine *targetMachine) { + return [optLevel, sizeLevel, targetMachine](Module *m) -> Error { + std::optional ol = mapToLevel(optLevel, sizeLevel); + if (!ol) { + return make_error( + formatv("invalid optimization/size level {0}/{1}", optLevel, + sizeLevel) + .str(), + inconvertibleErrorCode()); + } + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + + PipelineTuningOptions tuningOptions; + tuningOptions.LoopUnrolling = true; + tuningOptions.LoopInterleaving = true; + tuningOptions.LoopVectorization = true; + // TODO: currently we run SLP vectorizer with an empty target machine. This + // cause the vectorizer to create larger vector which could be bad. + // Disabling it would currently cause regressions as this pass also applies + // some scheduling that helps performance in some cases. We should work on + // using NVPTX target instead and address the performance regressions with + // some scheduling solution. + tuningOptions.SLPVectorization = true; + + PassBuilder pb(targetMachine, tuningOptions); + + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + llvm::FunctionPassManager fpm; + // Triton generates large structure of scalars which may pessimise + // optimizations, we run a pass to break up phi of struct to make sure all + // the struct are removed for the following passes. + fpm.addPass(BreakStructPhiNodesPass()); + mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm))); + mpm.addPass(pb.buildPerModuleDefaultPipeline(*ol)); + mpm.run(*m, mam); + return Error::success(); + }; +} +} // namespace + namespace mlir { namespace triton { @@ -308,7 +401,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module, return nullptr; } - auto optPipeline = mlir::makeOptimizingTransformer( + auto optPipeline = makeOptimizingPipeline( /*optLevel=*/3, /*sizeLevel=*/0, /*targetMachine=*/nullptr); diff --git a/lib/Target/LLVMIR/LLVMPasses.h b/lib/Target/LLVMIR/LLVMPasses.h new file mode 100644 index 000000000000..1dcdb2992c02 --- /dev/null +++ b/lib/Target/LLVMIR/LLVMPasses.h @@ -0,0 +1,16 @@ +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/CodeGen.h" + +namespace llvm { + +// Pass to pre-process LLVM IR before optimization and break up phi of struct. +// Breaking up those phis into elementary types allows better optimizations +// downstream. +struct BreakStructPhiNodesPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + + static StringRef name() { return "BreakStructPhiNodesPass"; } +}; + +} // namespace llvm diff --git a/test/LLVMIR/break-phi-struct.ll b/test/LLVMIR/break-phi-struct.ll new file mode 100644 index 000000000000..b27c87588f63 --- /dev/null +++ b/test/LLVMIR/break-phi-struct.ll @@ -0,0 +1,33 @@ +; RUN: triton-llvm-opt -break-struct-phi-nodes %s | FileCheck %s + +; CHECK-LABEL: struct +define {i32, i32} @struct(i1 %c) { +; CHECK: br i1 %{{.*}}, label [[TRUE:%.*]], label [[FALSE:%.*]] + br i1 %c, label %true, label %false + +true: + %s.1 = insertvalue {i32, i32} undef, i32 20, 0 + %s.2 = insertvalue {i32, i32} %s.1, i32 200, 1 + +; CHECK-DAG: [[E0:%.*]] = extractvalue { i32, i32 } %{{.*}}, 0 +; CHECK-DAG: [[E1:%.*]] = extractvalue { i32, i32 } %{{.*}}, 1 +; CHECK: br + br label %exit + +false: + %s.3 = insertvalue {i32, i32} undef, i32 30, 0 + %s.4 = insertvalue {i32, i32} %s.3, i32 300, 1 +; CHECK-DAG: [[E2:%.*]] = extractvalue { i32, i32 } %{{.*}}, 0 +; CHECK-DAG: [[E3:%.*]] = extractvalue { i32, i32 } %{{.*}}, 1 +; CHECK: br + br label %exit + +exit: +; CHECK-DAG: [[PHI0:%.*]] = phi i32 [ [[E0]], [[TRUE]] ], [ [[E2]], [[FALSE]] ] +; CHECK-DAG: [[PHI1:%.*]] = phi i32 [ [[E1]], [[TRUE]] ], [ [[E3]], [[FALSE]] ] +; CHECK: [[S0:%.*]] = insertvalue { i32, i32 } undef, i32 [[PHI0]], 0 +; CHECK: [[S1:%.*]] = insertvalue { i32, i32 } [[S0]], i32 [[PHI1]], 1 +; CHECK: ret { i32, i32 } [[S1]] + %r = phi {i32, i32} [ %s.2, %true], [ %s.4, %false ] + ret {i32, i32} %r +} diff --git a/test/lit.cfg.py b/test/lit.cfg.py index db65d3e4f172..5ea9c458dcd0 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -19,7 +19,7 @@ config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) # suffixes: A list of file extensions to treat as test files. -config.suffixes = ['.mlir'] +config.suffixes = ['.mlir', '.ll'] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) @@ -62,6 +62,7 @@ llvm_config.with_environment('PATH', d, append_path=True) tools = [ 'triton-opt', + 'triton-llvm-opt', ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'), ]