@@ -1260,33 +1260,56 @@ class AdjointGenerator
12601260 if (EVI.getType ()->isPointerTy ())
12611261 return ;
12621262
1263- if (Mode == DerivativeMode::ReverseModePrimal)
1263+ switch (Mode) {
1264+ case DerivativeMode::ForwardMode: {
1265+ IRBuilder<> Builder2 (&EVI);
1266+ getForwardBuilder (Builder2);
1267+
1268+ Value *orig_aggregate = EVI.getAggregateOperand ();
1269+
1270+ Value *diffe_aggregate =
1271+ gutils->isConstantValue (orig_aggregate)
1272+ ? ConstantAggregate::getNullValue (orig_aggregate->getType ())
1273+ : diffe (orig_aggregate, Builder2);
1274+ Value *diffe =
1275+ Builder2.CreateExtractValue (diffe_aggregate, EVI.getIndices ());
1276+
1277+ setDiffe (&EVI, diffe, Builder2);
12641278 return ;
1279+ }
1280+ case DerivativeMode::ReverseModeGradient:
1281+ case DerivativeMode::ReverseModeCombined: {
1282+ IRBuilder<> Builder2 (EVI.getParent ());
1283+ getReverseBuilder (Builder2);
12651284
1266- Value *orig_op0 = EVI.getOperand (0 );
1285+ Value *orig_op0 = EVI.getOperand (0 );
12671286
1268- IRBuilder<> Builder2 (EVI.getParent ());
1269- getReverseBuilder (Builder2);
1287+ auto prediff = diffe (&EVI, Builder2);
12701288
1271- auto prediff = diffe (&EVI, Builder2);
1289+ // todo const
1290+ if (!gutils->isConstantValue (orig_op0)) {
1291+ SmallVector<Value *, 4 > sv;
1292+ for (auto i : EVI.getIndices ())
1293+ sv.push_back (ConstantInt::get (Type::getInt32Ty (EVI.getContext ()), i));
1294+ size_t size = 1 ;
1295+ if (EVI.getType ()->isSized ())
1296+ size =
1297+ (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1298+ EVI.getType ()) +
1299+ 7 ) /
1300+ 8 ;
1301+ ((DiffeGradientUtils *)gutils)
1302+ ->addToDiffe (orig_op0, prediff, Builder2, TR.addingType (size, &EVI),
1303+ sv);
1304+ }
12721305
1273- // todo const
1274- if (!gutils->isConstantValue (orig_op0)) {
1275- SmallVector<Value *, 4 > sv;
1276- for (auto i : EVI.getIndices ())
1277- sv.push_back (ConstantInt::get (Type::getInt32Ty (EVI.getContext ()), i));
1278- size_t size = 1 ;
1279- if (EVI.getType ()->isSized ())
1280- size = (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1281- EVI.getType ()) +
1282- 7 ) /
1283- 8 ;
1284- ((DiffeGradientUtils *)gutils)
1285- ->addToDiffe (orig_op0, prediff, Builder2, TR.addingType (size, &EVI),
1286- sv);
1306+ setDiffe (&EVI, Constant::getNullValue (EVI.getType ()), Builder2);
1307+ return ;
1308+ }
1309+ case DerivativeMode::ReverseModePrimal: {
1310+ return ;
1311+ }
12871312 }
1288-
1289- setDiffe (&EVI, Constant::getNullValue (EVI.getType ()), Builder2);
12901313 }
12911314
12921315 void visitInsertValueInst (llvm::InsertValueInst &IVI) {
0 commit comments