@@ -1166,45 +1166,76 @@ class AdjointGenerator
11661166 eraseIfUnused (IEI);
11671167 if (gutils->isConstantInstruction (&IEI))
11681168 return ;
1169- if (Mode == DerivativeMode::ReverseModePrimal)
1170- return ;
11711169
1172- IRBuilder<> Builder2 (IEI.getParent ());
1173- getReverseBuilder (Builder2);
1170+ switch (Mode) {
1171+ case DerivativeMode::ForwardMode: {
1172+ IRBuilder<> Builder2 (&IEI);
1173+ getForwardBuilder (Builder2);
11741174
1175- Value *dif1 = diffe (&IEI, Builder2);
1175+ Value *orig_vector = IEI.getOperand (0 );
1176+ Value *orig_inserted = IEI.getOperand (1 );
1177+ Value *orig_index = IEI.getOperand (2 );
11761178
1177- Value *orig_op0 = IEI.getOperand (0 );
1178- Value *orig_op1 = IEI.getOperand (1 );
1179- Value *op1 = gutils->getNewFromOriginal (orig_op1);
1180- Value *op2 = gutils->getNewFromOriginal (IEI.getOperand (2 ));
1179+ Value *diff_inserted = gutils->isConstantValue (orig_inserted)
1180+ ? ConstantFP::get (orig_inserted->getType (), 0 )
1181+ : diffe (orig_inserted, Builder2);
11811182
1182- size_t size0 = 1 ;
1183- if (orig_op0->getType ()->isSized ())
1184- size0 = (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1185- orig_op0->getType ()) +
1186- 7 ) /
1187- 8 ;
1188- size_t size1 = 1 ;
1189- if (orig_op1->getType ()->isSized ())
1190- size1 = (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1191- orig_op1->getType ()) +
1192- 7 ) /
1193- 8 ;
1183+ Value *prediff =
1184+ gutils->isConstantValue (orig_vector)
1185+ ? diffe (orig_vector, Builder2)
1186+ : ConstantVector::getNullValue (orig_vector->getType ());
1187+ auto dindex = Builder2.CreateInsertElement (
1188+ prediff, diff_inserted, gutils->getNewFromOriginal (orig_index));
1189+ setDiffe (&IEI, dindex, Builder2);
11941190
1195- if (!gutils-> isConstantValue (orig_op0))
1196- addToDiffe (orig_op0,
1197- Builder2. CreateInsertElement (
1198- dif1, Constant::getNullValue (op1-> getType ()),
1199- lookup (op2, Builder2)),
1200- Builder2, TR. addingType (size0, orig_op0) );
1191+ return ;
1192+ }
1193+ case DerivativeMode::ReverseModeGradient:
1194+ case DerivativeMode::ReverseModeCombined: {
1195+ IRBuilder<> Builder2 (IEI. getParent ());
1196+ getReverseBuilder (Builder2 );
12011197
1202- if (!gutils->isConstantValue (orig_op1))
1203- addToDiffe (orig_op1,
1204- Builder2.CreateExtractElement (dif1, lookup (op2, Builder2)),
1205- Builder2, TR.addingType (size1, orig_op1));
1198+ Value *dif1 = diffe (&IEI, Builder2);
1199+
1200+ Value *orig_op0 = IEI.getOperand (0 );
1201+ Value *orig_op1 = IEI.getOperand (1 );
1202+ Value *op1 = gutils->getNewFromOriginal (orig_op1);
1203+ Value *op2 = gutils->getNewFromOriginal (IEI.getOperand (2 ));
12061204
1207- setDiffe (&IEI, Constant::getNullValue (IEI.getType ()), Builder2);
1205+ size_t size0 = 1 ;
1206+ if (orig_op0->getType ()->isSized ())
1207+ size0 =
1208+ (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1209+ orig_op0->getType ()) +
1210+ 7 ) /
1211+ 8 ;
1212+ size_t size1 = 1 ;
1213+ if (orig_op1->getType ()->isSized ())
1214+ size1 =
1215+ (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1216+ orig_op1->getType ()) +
1217+ 7 ) /
1218+ 8 ;
1219+
1220+ if (!gutils->isConstantValue (orig_op0))
1221+ addToDiffe (orig_op0,
1222+ Builder2.CreateInsertElement (
1223+ dif1, Constant::getNullValue (op1->getType ()),
1224+ lookup (op2, Builder2)),
1225+ Builder2, TR.addingType (size0, orig_op0));
1226+
1227+ if (!gutils->isConstantValue (orig_op1))
1228+ addToDiffe (orig_op1,
1229+ Builder2.CreateExtractElement (dif1, lookup (op2, Builder2)),
1230+ Builder2, TR.addingType (size1, orig_op1));
1231+
1232+ setDiffe (&IEI, Constant::getNullValue (IEI.getType ()), Builder2);
1233+ return ;
1234+ }
1235+ case DerivativeMode::ReverseModePrimal: {
1236+ return ;
1237+ }
1238+ }
12081239 }
12091240
12101241 void visitShuffleVectorInst (llvm::ShuffleVectorInst &SVI) {
0 commit comments