Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions bin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,20 @@ llvm_update_compile_flags(triton-translate)
MLIRROCDLToLLVMIRTranslation
)
mlir_check_all_link_libraries(triton-translate)

# triton-llvm-opt: small opt-like driver built from triton-llvm-opt.cpp so that
# Triton's custom LLVM IR passes can be run from the command line in tests.
add_llvm_executable(triton-llvm-opt
triton-llvm-opt.cpp

DEPENDS
intrinsics_gen
SUPPORT_PLUGINS
)
# TritonLLVMIR is the library that contains the custom pass sources; the
# remaining entries are the LLVM component libraries the driver links against.
target_link_libraries(triton-llvm-opt PRIVATE
TritonLLVMIR

LLVMCore
LLVMSupport
LLVMOption
LLVMCodeGen
)
# Pairs with SUPPORT_PLUGINS above.
export_executable_symbols_for_plugins(triton-llvm-opt)
114 changes: 114 additions & 0 deletions bin/triton-llvm-opt.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/// Trimmed down clone of llvm opt to be able to test triton custom llvm ir
/// passes.
#include "lib/Target/LLVMIR/LLVMPasses.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/SystemUtils.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/TargetParser/Triple.h"
#include <optional>

using namespace llvm;

// Positional argument naming the IR file that main() loads; defaults to "-".
static cl::opt<std::string> InputFilename(cl::Positional,
cl::desc("<input bitcode file>"),
cl::init("-"),
cl::value_desc("filename"));

// -data-layout: when non-empty, main() returns this string from the
// data-layout callback it passes to parseIRFile. Empty by default.
static cl::opt<std::string> ClDataLayout("data-layout",
cl::desc("data layout string to use"),
cl::value_desc("layout-string"),
cl::init(""));
// -mtriple: when non-empty, main() normalizes the value and sets it as the
// loaded module's target triple.
static cl::opt<std::string>
TargetTriple("mtriple", cl::desc("Override target triple for module"));

// -break-struct-phi-nodes: off by default; when set, makeOptimizingPipeline()
// adds BreakStructPhiNodesPass to the function pass manager.
static cl::opt<bool>
BreakStructPhiNodes("break-struct-phi-nodes",
llvm::cl::desc("run pass to break phi struct"),
cl::init(false));

namespace {
static std::function<Error(Module *)> makeOptimizingPipeline() {
return [](Module *m) -> Error {
PipelineTuningOptions tuningOptions;
PassBuilder pb(nullptr, tuningOptions);

LoopAnalysisManager lam;
FunctionAnalysisManager fam;
CGSCCAnalysisManager cgam;
ModuleAnalysisManager mam;
pb.registerModuleAnalyses(mam);
pb.registerCGSCCAnalyses(cgam);
pb.registerFunctionAnalyses(fam);
pb.registerLoopAnalyses(lam);
pb.crossRegisterProxies(lam, fam, cgam, mam);

ModulePassManager mpm;
llvm::FunctionPassManager fpm;
if (BreakStructPhiNodes)
fpm.addPass(BreakStructPhiNodesPass());
mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm)));
mpm.run(*m, mam);
return Error::success();
};
}
} // namespace

int main(int argc, char **argv) {
InitLLVM X(argc, argv);
cl::ParseCommandLineOptions(
argc, argv, "llvm .bc -> .bc modular optimizer and analysis printer\n");

LLVMContext Context;
SMDiagnostic Err;

// Load the input module...
auto SetDataLayout = [](StringRef, StringRef) -> std::optional<std::string> {
if (ClDataLayout.empty())
return std::nullopt;
return ClDataLayout;
};
std::unique_ptr<Module> M;
M = parseIRFile(InputFilename, Err, Context, ParserCallbacks(SetDataLayout));
if (!M) {
Err.print(argv[0], errs());
return 1;
}
// If we are supposed to override the target triple or data layout, do so now.
if (!TargetTriple.empty())
M->setTargetTriple(Triple::normalize(TargetTriple));
auto optPipeline = makeOptimizingPipeline();
if (auto err = optPipeline(M.get())) {
llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
}

if (verifyModule(*M, &errs())) {
errs() << argv[0] << ": " << InputFilename
<< ": error: input module is broken!\n";
return 1;
}

// Write to standard output.
std::unique_ptr<ToolOutputFile> Out;
std::string OutputFilename = "-";
std::error_code EC;
sys::fs::OpenFlags Flags = sys::fs::OF_TextWithCRLF;
Out.reset(new ToolOutputFile(OutputFilename, EC, Flags));
if (EC) {
errs() << EC.message() << '\n';
return 1;
}
Out->os() << *M << "\n";
return 0;
}
1 change: 1 addition & 0 deletions lib/Target/LLVMIR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_mlir_translation_library(TritonLLVMIR
LLVMIRTranslation.cpp
LLVMDIScope.cpp
LLVMIRBreakPhiStruct.cpp

LINK_COMPONENTS
Core
Expand Down
60 changes: 60 additions & 0 deletions lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//===----------------------------------------------------------------------===//
/// Implements a trivial pass breaking up one-level-deep structures in phi
/// nodes. This handles the common case generated by Triton and allows better
/// optimizations down the compiler pipeline.
//===----------------------------------------------------------------------===//
#include "LLVMPasses.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

/// Splits a PHI node of struct type into one PHI per struct element.
///
/// For every element an `extractvalue` is emitted in front of each
/// predecessor's terminator and fed into a new PHI created next to the
/// original one. The struct is then rebuilt with an `insertvalue` chain placed
/// at the first non-PHI position of the block, and all uses of the original
/// PHI are redirected to the rebuilt value. Only the outermost struct level is
/// split. The original PHI is not erased here; it is left behind without
/// uses.
///
/// NOTE(review): the extract is inserted before the predecessor's terminator,
/// so an incoming value produced by that terminator itself would not be
/// handled; confirm such IR cannot reach this pass.
///
/// \param phiNode PHI node to inspect.
/// \returns true if \p phiNode has struct type and was rewritten, false
/// otherwise.
static bool processPhiStruct(PHINode *phiNode) {
  StructType *STy = dyn_cast<StructType>(phiNode->getType());
  if (!STy)
    return false;
  IRBuilder<> builder(phiNode);
  unsigned numOperands = phiNode->getNumIncomingValues();
  unsigned numScalarEl = STy->getNumElements();
  Value *newStruct = UndefValue::get(STy);
  // Remember where the insertvalue chain goes: right after the PHI group.
  builder.SetInsertPoint(phiNode->getParent()->getFirstNonPHI());
  llvm::IRBuilderBase::InsertPoint insertInsertPt = builder.saveIP();
  for (unsigned i = 0; i < numScalarEl; i++) {
    builder.SetInsertPoint(phiNode);
    PHINode *newPhiNode =
        builder.CreatePHI(STy->getElementType(i), numOperands);
    for (unsigned j = 0; j < numOperands; ++j) {
      BasicBlock *incomingBlock = phiNode->getIncomingBlock(j);
      // A predecessor may be listed several times (e.g. multiple switch edges
      // to this block). All entries for one block must carry the same value,
      // so reuse the extract already recorded for it instead of creating a
      // second, distinct instruction.
      int seenIdx = newPhiNode->getBasicBlockIndex(incomingBlock);
      if (seenIdx >= 0) {
        newPhiNode->addIncoming(newPhiNode->getIncomingValue(seenIdx),
                                incomingBlock);
        continue;
      }
      Value *operand = phiNode->getIncomingValue(j);
      builder.SetInsertPoint(incomingBlock->getTerminator());
      newPhiNode->addIncoming(builder.CreateExtractValue(operand, i),
                              incomingBlock);
    }
    builder.restoreIP(insertInsertPt);
    newStruct = builder.CreateInsertValue(newStruct, newPhiNode, i);
    insertInsertPt = builder.saveIP();
  }
  phiNode->replaceAllUsesWith(newStruct);
  return true;
}

/// Applies processPhiStruct to the leading PHI nodes of every basic block.
///
/// \param F function to transform in place.
/// \returns true if at least one PHI node was rewritten.
static bool runOnFunction(Function &F) {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    for (Instruction &inst : BB) {
      // PHI nodes sit at the top of a block; stop scanning at the first
      // instruction that is not one.
      PHINode *phiNode = dyn_cast<PHINode>(&inst);
      if (!phiNode)
        break;
      Changed |= processPhiStruct(phiNode);
    }
  }
  return Changed;
}

/// New-pass-manager entry point: rewrites the struct PHI nodes of \p F.
///
/// \returns PreservedAnalyses::none() when the function was modified and
/// PreservedAnalyses::all() when it was left untouched.
PreservedAnalyses BreakStructPhiNodesPass::run(Function &F,
                                               FunctionAnalysisManager &AM) {
  if (runOnFunction(F))
    return PreservedAnalyses::none();
  return PreservedAnalyses::all();
}
97 changes: 95 additions & 2 deletions lib/Target/LLVMIR/LLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"

#include "LLVMPasses.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
Expand All @@ -26,11 +27,20 @@
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Passes/OptimizationLevel.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include <optional>
#ifdef _WIN32
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
Expand All @@ -42,6 +52,89 @@

namespace fs = std::filesystem;

namespace {
using namespace llvm;

/// Translates a numeric (optimization, size) level pair into an LLVM
/// OptimizationLevel.
///
/// Levels 0, 1 and 3 map to O0, O1 and O3 whatever \p sizeLevel is. Level 2
/// maps to O2, Os or Oz for size levels 0, 1 and 2 respectively.
///
/// \returns the matching level, or std::nullopt for any other combination.
static std::optional<OptimizationLevel> mapToLevel(unsigned optLevel,
                                                   unsigned sizeLevel) {
  if (optLevel == 0)
    return OptimizationLevel::O0;
  if (optLevel == 1)
    return OptimizationLevel::O1;
  if (optLevel == 3)
    return OptimizationLevel::O3;
  if (optLevel != 2)
    return std::nullopt;

  // Only optimization level 2 takes the size level into account.
  if (sizeLevel == 0)
    return OptimizationLevel::O2;
  if (sizeLevel == 1)
    return OptimizationLevel::Os;
  if (sizeLevel == 2)
    return OptimizationLevel::Oz;
  return std::nullopt;
}

// Create and return a lambda that uses LLVM pass manager builder to set up
// optimizations based on the given level.
static std::function<Error(Module *)>
makeOptimizingPipeline(unsigned optLevel, unsigned sizeLevel,
TargetMachine *targetMachine) {
return [optLevel, sizeLevel, targetMachine](Module *m) -> Error {
std::optional<OptimizationLevel> ol = mapToLevel(optLevel, sizeLevel);
if (!ol) {
return make_error<StringError>(
formatv("invalid optimization/size level {0}/{1}", optLevel,
sizeLevel)
.str(),
inconvertibleErrorCode());
}
LoopAnalysisManager lam;
FunctionAnalysisManager fam;
CGSCCAnalysisManager cgam;
ModuleAnalysisManager mam;

PipelineTuningOptions tuningOptions;
tuningOptions.LoopUnrolling = true;
tuningOptions.LoopInterleaving = true;
tuningOptions.LoopVectorization = true;
// TODO: currently we run SLP vectorizer with an empty target machine. This
// cause the vectorizer to create larger vector which could be bad.
// Disabling it would currently cause regressions as this pass also applies
// some scheduling that helps performance in some cases. We should work on
// using NVPTX target instead and address the performance regressions with
// some scheduling solution.
tuningOptions.SLPVectorization = true;

PassBuilder pb(targetMachine, tuningOptions);

pb.registerModuleAnalyses(mam);
pb.registerCGSCCAnalyses(cgam);
pb.registerFunctionAnalyses(fam);
pb.registerLoopAnalyses(lam);
pb.crossRegisterProxies(lam, fam, cgam, mam);

ModulePassManager mpm;
llvm::FunctionPassManager fpm;
// Triton generates large structure of scalars which may pessimise
// optimizations, we run a pass to break up phi of struct to make sure all
// the struct are removed for the following passes.
fpm.addPass(BreakStructPhiNodesPass());
mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm)));
mpm.addPass(pb.buildPerModuleDefaultPipeline(*ol));
mpm.run(*m, mam);
return Error::success();
};
}
} // namespace

namespace mlir {
namespace triton {

Expand Down Expand Up @@ -308,7 +401,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module,
return nullptr;
}

auto optPipeline = mlir::makeOptimizingTransformer(
auto optPipeline = makeOptimizingPipeline(
/*optLevel=*/3, /*sizeLevel=*/0,
/*targetMachine=*/nullptr);

Expand Down
16 changes: 16 additions & 0 deletions lib/Target/LLVMIR/LLVMPasses.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef TRITON_TARGET_LLVM_IR_PASSES_H
#define TRITON_TARGET_LLVM_IR_PASSES_H

#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/CodeGen.h"

namespace llvm {

/// Function pass that pre-processes LLVM IR before optimization by breaking up
/// PHI nodes of struct type. Splitting those PHIs into elementary types allows
/// better optimizations downstream.
struct BreakStructPhiNodesPass : PassInfoMixin<BreakStructPhiNodesPass> {
  /// Runs the transformation on \p F and reports which analyses remain valid.
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);

  /// Name reported for this pass.
  static StringRef name() { return "BreakStructPhiNodesPass"; }
};

} // namespace llvm

#endif // TRITON_TARGET_LLVM_IR_PASSES_H
33 changes: 33 additions & 0 deletions test/LLVMIR/break-phi-struct.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
; RUN: triton-llvm-opt -break-struct-phi-nodes %s | FileCheck %s

; A phi over a two-element struct is expected to be replaced by one scalar phi
; per element: each predecessor extracts both elements ahead of its branch, and
; the merge block rebuilds the struct from the new phis with insertvalue.

; CHECK-LABEL: struct
define {i32, i32} @struct(i1 %c) {
; CHECK: br i1 %{{.*}}, label [[TRUE:%.*]], label [[FALSE:%.*]]
br i1 %c, label %true, label %false

true:
%s.1 = insertvalue {i32, i32} undef, i32 20, 0
%s.2 = insertvalue {i32, i32} %s.1, i32 200, 1

; Element extracts expected in this predecessor, ahead of its branch.
; CHECK-DAG: [[E0:%.*]] = extractvalue { i32, i32 } %{{.*}}, 0
; CHECK-DAG: [[E1:%.*]] = extractvalue { i32, i32 } %{{.*}}, 1
; CHECK: br
br label %exit

false:
%s.3 = insertvalue {i32, i32} undef, i32 30, 0
%s.4 = insertvalue {i32, i32} %s.3, i32 300, 1
; Same expectation for the second predecessor.
; CHECK-DAG: [[E2:%.*]] = extractvalue { i32, i32 } %{{.*}}, 0
; CHECK-DAG: [[E3:%.*]] = extractvalue { i32, i32 } %{{.*}}, 1
; CHECK: br
br label %exit

exit:
; One scalar phi per element, then the struct rebuilt and returned.
; CHECK-DAG: [[PHI0:%.*]] = phi i32 [ [[E0]], [[TRUE]] ], [ [[E2]], [[FALSE]] ]
; CHECK-DAG: [[PHI1:%.*]] = phi i32 [ [[E1]], [[TRUE]] ], [ [[E3]], [[FALSE]] ]
; CHECK: [[S0:%.*]] = insertvalue { i32, i32 } undef, i32 [[PHI0]], 0
; CHECK: [[S1:%.*]] = insertvalue { i32, i32 } [[S0]], i32 [[PHI1]], 1
; CHECK: ret { i32, i32 } [[S1]]
%r = phi {i32, i32} [ %s.2, %true], [ %s.4, %false ]
ret {i32, i32} %r
}
Loading