Skip to content

Commit a2f7718

Browse files
authored
ForwardMode shufflevector inst (rust-lang#356)
1 parent d0e0331 commit a2f7718

File tree

1 file changed

+63
-32
lines changed

1 file changed

+63
-32
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 63 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,46 +1242,77 @@ class AdjointGenerator
12421242
eraseIfUnused(SVI);
12431243
if (gutils->isConstantInstruction(&SVI))
12441244
return;
1245-
if (Mode == DerivativeMode::ReverseModePrimal)
1246-
return;
12471245

1248-
IRBuilder<> Builder2(SVI.getParent());
1249-
getReverseBuilder(Builder2);
1246+
switch (Mode) {
1247+
case DerivativeMode::ForwardMode: {
1248+
IRBuilder<> Builder2(&SVI);
1249+
getForwardBuilder(Builder2);
1250+
1251+
Value *orig_vector1 = SVI.getOperand(0);
1252+
Value *orig_vector2 = SVI.getOperand(1);
1253+
Value *orig_mask = SVI.getOperand(0);
1254+
1255+
auto diffe_vector1 =
1256+
gutils->isConstantValue(orig_vector1)
1257+
? ConstantVector::getNullValue(orig_vector1->getType())
1258+
: diffe(orig_vector1, Builder2);
1259+
auto diffe_vector2 =
1260+
gutils->isConstantValue(orig_vector2)
1261+
? ConstantVector::getNullValue(orig_vector2->getType())
1262+
: diffe(orig_vector2, Builder2);
1263+
1264+
auto diffe = Builder2.CreateShuffleVector(
1265+
diffe_vector1, diffe_vector2, gutils->getNewFromOriginal(orig_mask));
12501266

1251-
auto loaded = diffe(&SVI, Builder2);
1267+
setDiffe(&SVI, diffe, Builder2);
1268+
return;
1269+
}
1270+
case DerivativeMode::ReverseModeGradient:
1271+
case DerivativeMode::ReverseModeCombined: {
1272+
IRBuilder<> Builder2(SVI.getParent());
1273+
getReverseBuilder(Builder2);
1274+
1275+
auto loaded = diffe(&SVI, Builder2);
12521276
#if LLVM_VERSION_MAJOR >= 12
1253-
auto count =
1254-
cast<VectorType>(SVI.getOperand(0)->getType())->getElementCount();
1255-
assert(!count.isScalable());
1256-
size_t l1 = count.getKnownMinValue();
1277+
auto count =
1278+
cast<VectorType>(SVI.getOperand(0)->getType())->getElementCount();
1279+
assert(!count.isScalable());
1280+
size_t l1 = count.getKnownMinValue();
12571281
#else
1258-
size_t l1 =
1259-
cast<VectorType>(SVI.getOperand(0)->getType())->getNumElements();
1282+
size_t l1 =
1283+
cast<VectorType>(SVI.getOperand(0)->getType())->getNumElements();
12601284
#endif
1261-
uint64_t instidx = 0;
1285+
uint64_t instidx = 0;
12621286

1263-
for (size_t idx : SVI.getShuffleMask()) {
1264-
auto opnum = (idx < l1) ? 0 : 1;
1265-
auto opidx = (idx < l1) ? idx : (idx - l1);
1266-
SmallVector<Value *, 4> sv;
1267-
sv.push_back(ConstantInt::get(Type::getInt32Ty(SVI.getContext()), opidx));
1268-
if (!gutils->isConstantValue(SVI.getOperand(opnum))) {
1269-
size_t size = 1;
1270-
if (SVI.getOperand(opnum)->getType()->isSized())
1271-
size =
1272-
(gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1273-
SVI.getOperand(opnum)->getType()) +
1274-
7) /
1275-
8;
1276-
((DiffeGradientUtils *)gutils)
1277-
->addToDiffe(SVI.getOperand(opnum),
1278-
Builder2.CreateExtractElement(loaded, instidx),
1279-
Builder2, TR.addingType(size, SVI.getOperand(opnum)),
1280-
sv);
1287+
for (size_t idx : SVI.getShuffleMask()) {
1288+
auto opnum = (idx < l1) ? 0 : 1;
1289+
auto opidx = (idx < l1) ? idx : (idx - l1);
1290+
SmallVector<Value *, 4> sv;
1291+
sv.push_back(
1292+
ConstantInt::get(Type::getInt32Ty(SVI.getContext()), opidx));
1293+
if (!gutils->isConstantValue(SVI.getOperand(opnum))) {
1294+
size_t size = 1;
1295+
if (SVI.getOperand(opnum)->getType()->isSized())
1296+
size = (gutils->newFunc->getParent()
1297+
->getDataLayout()
1298+
.getTypeSizeInBits(SVI.getOperand(opnum)->getType()) +
1299+
7) /
1300+
8;
1301+
((DiffeGradientUtils *)gutils)
1302+
->addToDiffe(SVI.getOperand(opnum),
1303+
Builder2.CreateExtractElement(loaded, instidx),
1304+
Builder2, TR.addingType(size, SVI.getOperand(opnum)),
1305+
sv);
1306+
}
1307+
++instidx;
12811308
}
1282-
++instidx;
1309+
setDiffe(&SVI, Constant::getNullValue(SVI.getType()), Builder2);
1310+
return;
1311+
}
1312+
case DerivativeMode::ReverseModePrimal: {
1313+
return;
1314+
}
12831315
}
1284-
setDiffe(&SVI, Constant::getNullValue(SVI.getType()), Builder2);
12851316
}
12861317

12871318
void visitExtractValueInst(llvm::ExtractValueInst &EVI) {

0 commit comments

Comments
 (0)