@@ -778,6 +778,56 @@ class AdjointGenerator
778778 }
779779
780780 void visitAtomicRMWInst (llvm::AtomicRMWInst &I) {
781+ if (Mode == DerivativeMode::ForwardMode) {
782+ IRBuilder<> BuilderZ (&I);
783+ getForwardBuilder (BuilderZ);
784+ switch (I.getOperation ()) {
785+ case AtomicRMWInst::FAdd:
786+ case AtomicRMWInst::FSub: {
787+ auto rule = [&](Value *ptr, Value *dif) -> Value * {
788+ if (!gutils->isConstantInstruction (&I)) {
789+ assert (ptr);
790+ AtomicRMWInst *rmw = nullptr ;
791+ #if LLVM_VERSION_MAJOR >= 13
792+ rmw = BuilderZ.CreateAtomicRMW (I.getOperation (), ptr, dif,
793+ I.getAlign (), I.getOrdering (),
794+ I.getSyncScopeID ());
795+ #elif LLVM_VERSION_MAJOR >= 11
796+ rmw = BuilderZ.CreateAtomicRMW (I.getOperation (), ptr, dif,
797+ I.getOrdering (), I.getSyncScopeID ());
798+ rmw->setAlignment (I.getAlign ());
799+ #else
800+ rmw = BuilderZ.CreateAtomicRMW (
801+ I.getOperation (), ptr, dif, I.getOrdering (),
802+ I.getSyncScopeID ());
803+ #endif
804+ rmw->setVolatile (I.isVolatile ());
805+ if (gutils->isConstantValue (&I))
806+ return Constant::getNullValue (dif->getType ());
807+ else
808+ return rmw;
809+ } else {
810+ assert (gutils->isConstantValue (&I));
811+ return Constant::getNullValue (dif->getType ());
812+ }
813+ };
814+
815+ Value *diff = applyChainRule (
816+ I.getType (), BuilderZ, rule,
817+ gutils->isConstantValue (I.getPointerOperand ())
818+ ? nullptr
819+ : gutils->invertPointerM (I.getPointerOperand (), BuilderZ),
820+ gutils->isConstantValue (I.getValOperand ())
821+ ? Constant::getNullValue (I.getType ())
822+ : gutils->invertPointerM (I.getValOperand (), BuilderZ));
823+ if (!gutils->isConstantValue (&I))
824+ setDiffe (&I, diff, BuilderZ);
825+ return ;
826+ }
827+ default :
828+ break ;
829+ }
830+ }
781831 if (!gutils->isConstantInstruction (&I) || !gutils->isConstantValue (&I)) {
782832 TR.dump ();
783833 llvm::errs () << " oldFunc: " << *gutils->newFunc << " \n " ;
@@ -11083,7 +11133,8 @@ class AdjointGenerator
1108311133 auto rule = [&args](Value *tofree) { args.push_back (tofree); };
1108411134 applyChainRule (Builder2, rule, tofree);
1108511135
11086- Builder2.CreateCall (free->getFunctionType (), free, args);
11136+ auto frees = Builder2.CreateCall (free->getFunctionType (), free, args);
11137+ frees->setDebugLoc (gutils->getNewFromOriginal (orig->getDebugLoc ()));
1108711138
1108811139 return ;
1108911140 }
0 commit comments