3232#include " SCEV/ScalarEvolution.h"
3333#include " SCEV/ScalarEvolutionExpander.h"
3434
35+ #include " llvm/Analysis/DependenceAnalysis.h"
3536#include < deque>
3637
3738#include " llvm/IR/BasicBlock.h"
@@ -92,6 +93,7 @@ bool is_load_uncacheable(
9293struct CacheAnalysis {
9394 AAResults &AA;
9495 Function *oldFunc;
96+ ScalarEvolution &SE;
9597 LoopInfo &OrigLI;
9698 DominatorTree &DT;
9799 TargetLibraryInfo &TLI;
@@ -100,11 +102,11 @@ struct CacheAnalysis {
100102 bool topLevel;
101103 std::map<Value *, bool > seen;
102104 CacheAnalysis (
103- AAResults &AA, Function *oldFunc, LoopInfo &OrigLI, DominatorTree &OrigDT ,
104- TargetLibraryInfo &TLI,
105+ AAResults &AA, Function *oldFunc, ScalarEvolution &SE, LoopInfo &OrigLI ,
106+ DominatorTree &OrigDT, TargetLibraryInfo &TLI,
105107 const SmallPtrSetImpl<const Instruction *> &unnecessaryInstructions,
106108 const std::map<Argument *, bool > &uncacheable_args, bool topLevel)
107- : AA(AA), oldFunc(oldFunc), OrigLI(OrigLI), DT(OrigDT), TLI(TLI),
109+ : AA(AA), oldFunc(oldFunc), SE(SE), OrigLI(OrigLI), DT(OrigDT), TLI(TLI),
108110 unnecessaryInstructions (unnecessaryInstructions),
109111 uncacheable_args(uncacheable_args), topLevel(topLevel) {}
110112
@@ -252,6 +254,123 @@ struct CacheAnalysis {
252254 if (!writesToMemoryReadBy (AA, &li, inst2)) {
253255 return false ;
254256 }
257+
258+ if (auto SI = dyn_cast<StoreInst>(inst2)) {
259+
260+ const SCEV *LS = SE.getSCEV (li.getPointerOperand ());
261+ const SCEV *SS = SE.getSCEV (SI->getPointerOperand ());
262+ if (SS != SE.getCouldNotCompute ()) {
263+
264+ // llvm::errs() << *inst2 << " - " << li << "\n";
265+ // llvm::errs() << *SS << " - " << *LS << "\n";
266+ const auto &DL = li.getModule ()->getDataLayout ();
267+
268+ #if LLVM_VERSION_MAJOR >= 10
269+ auto TS = SE.getConstant (
270+ APInt (64 , DL.getTypeStoreSize (li.getType ()).getFixedSize ()));
271+ #else
272+ auto TS = SE.getConstant (
273+ APInt (64 , DL.getTypeStoreSize (li.getType ())));
274+ #endif
275+ for (auto lim = LS; lim != SE.getCouldNotCompute ();) {
276+ // [start load, L+Size] [S, S+Size]
277+ for (auto slim = SS; slim != SE.getCouldNotCompute ();) {
278+ auto lsub = SE.getMinusSCEV (slim, SE.getAddExpr (lim, TS));
279+ // llvm::errs() << " *** " << *lsub << "|" << *slim << "|" <<
280+ // *lim << "\n";
281+ if (SE.isKnownNonNegative (lsub)) {
282+ return false ;
283+ }
284+ if (auto arL = dyn_cast<SCEVAddRecExpr>(slim)) {
285+ if (SE.isKnownNonNegative (arL->getStepRecurrence (SE))) {
286+ slim = arL->getStart ();
287+ continue ;
288+ } else if (SE.isKnownNonPositive (
289+ arL->getStepRecurrence (SE))) {
290+ #if LLVM_VERSION_MAJOR >= 12
291+ auto bd =
292+ SE.getSymbolicMaxBackedgeTakenCount (arL->getLoop ());
293+ #else
294+ auto bd = SE.getBackedgeTakenCount (arL->getLoop ());
295+ #endif
296+ if (bd == SE.getCouldNotCompute ())
297+ break ;
298+ slim = arL->evaluateAtIteration (bd, SE);
299+ continue ;
300+ }
301+ }
302+ break ;
303+ }
304+
305+ if (auto arL = dyn_cast<SCEVAddRecExpr>(lim)) {
306+ if (SE.isKnownNonNegative (arL->getStepRecurrence (SE))) {
307+ #if LLVM_VERSION_MAJOR >= 12
308+ auto bd = SE.getSymbolicMaxBackedgeTakenCount (arL->getLoop ());
309+ #else
310+ auto bd = SE.getBackedgeTakenCount (arL->getLoop ());
311+ #endif
312+ if (bd == SE.getCouldNotCompute ())
313+ break ;
314+ lim = arL->evaluateAtIteration (bd, SE);
315+ continue ;
316+ } else if (SE.isKnownNonPositive (arL->getStepRecurrence (SE))) {
317+ lim = arL->getStart ();
318+ continue ;
319+ }
320+ }
321+ break ;
322+ }
323+ for (auto lim = LS; lim != SE.getCouldNotCompute ();) {
324+ // [S, S+Size][start load, L+Size]
325+ for (auto slim = SS; slim != SE.getCouldNotCompute ();) {
326+ auto lsub = SE.getMinusSCEV (lim, SE.getAddExpr (slim, TS));
327+ // llvm::errs() << " $$$ " << *lsub << "|" << *slim << "|" <<
328+ // *lim << "\n";
329+ if (SE.isKnownNonNegative (lsub)) {
330+ return false ;
331+ }
332+ if (auto arL = dyn_cast<SCEVAddRecExpr>(slim)) {
333+ if (SE.isKnownNonNegative (arL->getStepRecurrence (SE))) {
334+ #if LLVM_VERSION_MAJOR >= 12
335+ auto bd =
336+ SE.getSymbolicMaxBackedgeTakenCount (arL->getLoop ());
337+ #else
338+ auto bd = SE.getBackedgeTakenCount (arL->getLoop ());
339+ #endif
340+ if (bd == SE.getCouldNotCompute ())
341+ break ;
342+ slim = arL->evaluateAtIteration (bd, SE);
343+ continue ;
344+ } else if (SE.isKnownNonPositive (
345+ arL->getStepRecurrence (SE))) {
346+ slim = arL->getStart ();
347+ continue ;
348+ }
349+ }
350+ break ;
351+ }
352+
353+ if (auto arL = dyn_cast<SCEVAddRecExpr>(lim)) {
354+ if (SE.isKnownNonNegative (arL->getStepRecurrence (SE))) {
355+ lim = arL->getStart ();
356+ continue ;
357+ } else if (SE.isKnownNonPositive (arL->getStepRecurrence (SE))) {
358+ #if LLVM_VERSION_MAJOR >= 12
359+ auto bd = SE.getSymbolicMaxBackedgeTakenCount (arL->getLoop ());
360+ #else
361+ auto bd = SE.getBackedgeTakenCount (arL->getLoop ());
362+ #endif
363+ if (bd == SE.getCouldNotCompute ())
364+ break ;
365+ lim = arL->evaluateAtIteration (bd, SE);
366+ continue ;
367+ }
368+ }
369+ break ;
370+ }
371+ }
372+ }
373+
255374 if (auto II = dyn_cast<IntrinsicInst>(inst2)) {
256375 if (II->getIntrinsicID () == Intrinsic::nvvm_barrier0 ||
257376 II->getIntrinsicID () == Intrinsic::amdgcn_s_barrier) {
@@ -1309,9 +1428,10 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
13091428 for (auto &I : *BB)
13101429 unnecessaryInstructionsTmp.insert (&I);
13111430 }
1312- CacheAnalysis CA (gutils->OrigAA , gutils->oldFunc , gutils->OrigLI ,
1313- gutils->OrigDT , TLI, unnecessaryInstructionsTmp,
1314- _uncacheable_argsPP,
1431+ CacheAnalysis CA (gutils->OrigAA , gutils->oldFunc ,
1432+ PPC.FAM .getResult <ScalarEvolutionAnalysis>(*gutils->oldFunc ),
1433+ gutils->OrigLI , gutils->OrigDT , TLI,
1434+ unnecessaryInstructionsTmp, _uncacheable_argsPP,
13151435 /* topLevel*/ false );
13161436 const std::map<CallInst *, const std::map<Argument *, bool >>
13171437 uncacheable_args_map = CA.compute_uncacheable_args_for_callsites ();
@@ -2434,9 +2554,10 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
24342554 for (auto &I : *BB)
24352555 unnecessaryInstructionsTmp.insert (&I);
24362556 }
2437- CacheAnalysis CA (gutils->OrigAA , gutils->oldFunc , gutils->OrigLI ,
2438- gutils->OrigDT , TLI, unnecessaryInstructionsTmp,
2439- _uncacheable_argsPP, topLevel);
2557+ CacheAnalysis CA (gutils->OrigAA , gutils->oldFunc ,
2558+ PPC.FAM .getResult <ScalarEvolutionAnalysis>(*gutils->oldFunc ),
2559+ gutils->OrigLI , gutils->OrigDT , TLI,
2560+ unnecessaryInstructionsTmp, _uncacheable_argsPP, topLevel);
24402561 const std::map<CallInst *, const std::map<Argument *, bool >>
24412562 uncacheable_args_map =
24422563 (augmenteddata) ? augmenteddata->uncacheable_args_map
0 commit comments