@@ -1242,46 +1242,77 @@ class AdjointGenerator
12421242 eraseIfUnused (SVI);
12431243 if (gutils->isConstantInstruction (&SVI))
12441244 return ;
1245- if (Mode == DerivativeMode::ReverseModePrimal)
1246- return ;
12471245
1248- IRBuilder<> Builder2 (SVI.getParent ());
1249- getReverseBuilder (Builder2);
1246+ switch (Mode) {
1247+ case DerivativeMode::ForwardMode: {
1248+ IRBuilder<> Builder2 (&SVI);
1249+ getForwardBuilder (Builder2);
1250+
1251+ Value *orig_vector1 = SVI.getOperand (0 );
1252+ Value *orig_vector2 = SVI.getOperand (1 );
1253+ Value *orig_mask = SVI.getOperand (0 );
1254+
1255+ auto diffe_vector1 =
1256+ gutils->isConstantValue (orig_vector1)
1257+ ? ConstantVector::getNullValue (orig_vector1->getType ())
1258+ : diffe (orig_vector1, Builder2);
1259+ auto diffe_vector2 =
1260+ gutils->isConstantValue (orig_vector2)
1261+ ? ConstantVector::getNullValue (orig_vector2->getType ())
1262+ : diffe (orig_vector2, Builder2);
1263+
1264+ auto diffe = Builder2.CreateShuffleVector (
1265+ diffe_vector1, diffe_vector2, gutils->getNewFromOriginal (orig_mask));
12501266
1251- auto loaded = diffe (&SVI, Builder2);
1267+ setDiffe (&SVI, diffe, Builder2);
1268+ return ;
1269+ }
1270+ case DerivativeMode::ReverseModeGradient:
1271+ case DerivativeMode::ReverseModeCombined: {
1272+ IRBuilder<> Builder2 (SVI.getParent ());
1273+ getReverseBuilder (Builder2);
1274+
1275+ auto loaded = diffe (&SVI, Builder2);
12521276#if LLVM_VERSION_MAJOR >= 12
1253- auto count =
1254- cast<VectorType>(SVI.getOperand (0 )->getType ())->getElementCount ();
1255- assert (!count.isScalable ());
1256- size_t l1 = count.getKnownMinValue ();
1277+ auto count =
1278+ cast<VectorType>(SVI.getOperand (0 )->getType ())->getElementCount ();
1279+ assert (!count.isScalable ());
1280+ size_t l1 = count.getKnownMinValue ();
12571281#else
1258- size_t l1 =
1259- cast<VectorType>(SVI.getOperand (0 )->getType ())->getNumElements ();
1282+ size_t l1 =
1283+ cast<VectorType>(SVI.getOperand (0 )->getType ())->getNumElements ();
12601284#endif
1261- uint64_t instidx = 0 ;
1285+ uint64_t instidx = 0 ;
12621286
1263- for (size_t idx : SVI.getShuffleMask ()) {
1264- auto opnum = (idx < l1) ? 0 : 1 ;
1265- auto opidx = (idx < l1) ? idx : (idx - l1);
1266- SmallVector<Value *, 4 > sv;
1267- sv.push_back (ConstantInt::get (Type::getInt32Ty (SVI.getContext ()), opidx));
1268- if (!gutils->isConstantValue (SVI.getOperand (opnum))) {
1269- size_t size = 1 ;
1270- if (SVI.getOperand (opnum)->getType ()->isSized ())
1271- size =
1272- (gutils->newFunc ->getParent ()->getDataLayout ().getTypeSizeInBits (
1273- SVI.getOperand (opnum)->getType ()) +
1274- 7 ) /
1275- 8 ;
1276- ((DiffeGradientUtils *)gutils)
1277- ->addToDiffe (SVI.getOperand (opnum),
1278- Builder2.CreateExtractElement (loaded, instidx),
1279- Builder2, TR.addingType (size, SVI.getOperand (opnum)),
1280- sv);
1287+ for (size_t idx : SVI.getShuffleMask ()) {
1288+ auto opnum = (idx < l1) ? 0 : 1 ;
1289+ auto opidx = (idx < l1) ? idx : (idx - l1);
1290+ SmallVector<Value *, 4 > sv;
1291+ sv.push_back (
1292+ ConstantInt::get (Type::getInt32Ty (SVI.getContext ()), opidx));
1293+ if (!gutils->isConstantValue (SVI.getOperand (opnum))) {
1294+ size_t size = 1 ;
1295+ if (SVI.getOperand (opnum)->getType ()->isSized ())
1296+ size = (gutils->newFunc ->getParent ()
1297+ ->getDataLayout ()
1298+ .getTypeSizeInBits (SVI.getOperand (opnum)->getType ()) +
1299+ 7 ) /
1300+ 8 ;
1301+ ((DiffeGradientUtils *)gutils)
1302+ ->addToDiffe (SVI.getOperand (opnum),
1303+ Builder2.CreateExtractElement (loaded, instidx),
1304+ Builder2, TR.addingType (size, SVI.getOperand (opnum)),
1305+ sv);
1306+ }
1307+ ++instidx;
12811308 }
1282- ++instidx;
1309+ setDiffe (&SVI, Constant::getNullValue (SVI.getType ()), Builder2);
1310+ return ;
1311+ }
1312+ case DerivativeMode::ReverseModePrimal: {
1313+ return ;
1314+ }
12831315 }
1284- setDiffe (&SVI, Constant::getNullValue (SVI.getType ()), Builder2);
12851316 }
12861317
12871318 void visitExtractValueInst (llvm::ExtractValueInst &EVI) {
0 commit comments