Skip to content

Commit

Permalink
Faster activity datastructures and handle version breaking flang (rus…
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jan 30, 2022
1 parent f379fc4 commit 318e268
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 8 deletions.
2 changes: 2 additions & 0 deletions enzyme/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ message("LLVM_INCLUDE_DIRS: ${LLVM_INCLUDE_DIRS}")
message("found llvm definitions " ${LLVM_DEFINITIONS})
message("found llvm version " ${LLVM_VERSION_MAJOR})

option(ENZYME_FLANG "Build for non-version compliant FLANG" OFF)
add_definitions(-DFLANG=1)

# Offer the user the choice of overriding the installation directories
set(INSTALL_INCLUDE_DIR include CACHE PATH "Installation directory for header files")
Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/ActivityAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,12 @@ class ActivityAnalyzer {
bool isConstantValue(TypeResults &TR, llvm::Value *val);

private:
std::map<llvm::Instruction *, std::set<llvm::Value *>>
llvm::DenseMap<llvm::Instruction *, llvm::SmallPtrSet<llvm::Value *, 4>>
ReEvaluateValueIfInactiveInst;
std::map<llvm::Value *, std::set<llvm::Value *>>
llvm::DenseMap<llvm::Value *, llvm::SmallPtrSet<llvm::Value *, 4>>
ReEvaluateValueIfInactiveValue;

std::map<llvm::Value *, std::set<llvm::Instruction *>>
llvm::DenseMap<llvm::Value *, llvm::SmallPtrSet<llvm::Instruction *, 4>>
ReEvaluateInstIfInactiveValue;

void InsertConstantInstruction(TypeResults &TR, llvm::Instruction *I);
Expand Down
4 changes: 4 additions & 0 deletions enzyme/Enzyme/CacheUtility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,11 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,
#if LLVM_VERSION_MAJOR >= 14
malloccall->addDereferenceableRetAttr(
ci->getLimitedValue() * byteSizeOfType->getLimitedValue());
#ifndef FLANG
AttrBuilder B(ci->getContext());
#else
AttrBuilder B;
#endif
B.addDereferenceableOrNullAttr(ci->getLimitedValue() *
byteSizeOfType->getLimitedValue());
malloccall->setAttributes(
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1619,7 +1619,7 @@ class Enzyme : public ModulePass {
// code left here to re-enable upon Attributor patch
Logic.PPC.FAM.clear(F, F.getName());

#if LLVM_VERSION_MAJOR >= 13
#if LLVM_VERSION_MAJOR >= 13 && !defined(FLANG)

AnalysisGetter AG(Logic.PPC.FAM);
SetVector<Function *> Functions;
Expand Down
4 changes: 4 additions & 0 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2180,7 +2180,11 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
}
#if LLVM_VERSION_MAJOR >= 14
malloccall->addDereferenceableRetAttr(size->getLimitedValue());
#ifndef FLANG
AttrBuilder B(malloccall->getContext());
#else
AttrBuilder B;
#endif
B.addDereferenceableOrNullAttr(size->getLimitedValue());
malloccall->setAttributes(malloccall->getAttributes().addRetAttributes(
malloccall->getContext(), B));
Expand Down
8 changes: 4 additions & 4 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1433,7 +1433,7 @@ Function *PreProcessCache::preprocessForClone(Function *F,
}

{
#if LLVM_VERSION_MAJOR >= 14
#if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG)
auto PA = SROAPass().run(*NewF, FAM);
#else
auto PA = SROA().run(*NewF, FAM);
Expand All @@ -1444,7 +1444,7 @@ Function *PreProcessCache::preprocessForClone(Function *F,
ReplaceReallocs(NewF);

{
#if LLVM_VERSION_MAJOR >= 14
#if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG)
auto PA = SROAPass().run(*NewF, FAM);
#else
auto PA = SROA().run(*NewF, FAM);
Expand Down Expand Up @@ -1973,12 +1973,12 @@ void SelectOptimization(Function *F) {
}
void PreProcessCache::optimizeIntermediate(Function *F) {
PromotePass().run(*F, FAM);
#if LLVM_VERSION_MAJOR >= 14
#if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG)
GVNPass().run(*F, FAM);
#else
GVN().run(*F, FAM);
#endif
#if LLVM_VERSION_MAJOR >= 14
#if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG)
SROAPass().run(*F, FAM);
#else
SROA().run(*F, FAM);
Expand Down
4 changes: 4 additions & 0 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,11 @@ class GradientUtils : public CacheUtility {
#if LLVM_VERSION_MAJOR >= 14
cast<CallInst>(anti)->addDereferenceableRetAttr(ci->getLimitedValue());
cal->addDereferenceableRetAttr(ci->getLimitedValue());
#ifndef FLANG
AttrBuilder B(Fn->getContext());
#else
AttrBuilder B;
#endif
B.addDereferenceableOrNullAttr(ci->getLimitedValue());
cast<CallInst>(anti)->setAttributes(
cast<CallInst>(anti)->getAttributes().addRetAttributes(
Expand Down

0 comments on commit 318e268

Please sign in to comment.