Skip to content

Commit 10daa92

Browse files
committed
add host IR support for set, reduce, and binary ops
1 parent 045d02b commit 10daa92

File tree

5 files changed

+285
-131
lines changed

5 files changed

+285
-131
lines changed

csrc/host_ir/executor.cpp

+99
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,31 @@ void HostIrEvaluator::handle(LinearOp* linear) {
570570
}
571571
}
572572

573+
// Evaluates a LoadStoreOp (Set): materializes the input tensor, applies the
// root->logical permutation when the output has a root domain, and then
// either copies into the pre-allocated output or binds a new one.
void HostIrEvaluator::handle(LoadStoreOp* load_store_op) {
  Val* out_val = load_store_op->out();
  NVF_ERROR(out_val->isA<TensorView>(), "out must be a TensorView");
  auto* out_tv = out_val->as<TensorView>();

  at::Tensor result =
      getKnownConcreteData(load_store_op->in()).as<at::Tensor>();

  if (out_tv->hasRoot()) {
    // A root domain here indicates a Set.Permute: the logical domain must be
    // a permutation of the root domain.
    auto permutation = ir_utils::computePermutation(
        out_tv->getRootDomain(), out_tv->getLogicalDomain());
    NVF_ERROR(
        permutation.has_value(),
        "The logical domain of a Set.Permute is supposed to be a permutation of the root domain: ",
        out_tv->toString());
    result = result.permute(*permutation).contiguous();
  }

  if (isKnown(out_val)) {
    // Output was pre-allocated by the caller: write the result in place.
    auto out_tensor = getKnownConcreteData(out_val).as<at::Tensor>();
    out_tensor.copy_(result);
  } else {
    bind(out_val, result);
  }
}
597+
573598
void HostIrEvaluator::handle(kir::Allocate* allocate) {
574599
NVF_ERROR(
575600
allocate->buffer()->isA<TensorView>(),
@@ -593,6 +618,80 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) {
593618
bind(tv, tensor);
594619
}
595620

621+
// Evaluates a BinaryOp in place when its output tensor is pre-allocated
// (already known); otherwise defers to the generic fallback, which computes
// and binds the output itself.
void HostIrEvaluator::handle(BinaryOp* binary_op) {
  Val* out_val = binary_op->outputs().at(0);
  if (!isKnown(out_val)) {
    return unhandled(binary_op);
  }

  auto lhs = getKnownConcreteData(binary_op->inputs().at(0)).as<at::Tensor>();
  auto rhs = getKnownConcreteData(binary_op->inputs().at(1)).as<at::Tensor>();
  auto output = getKnownConcreteData(out_val).as<at::Tensor>();

  // Dispatch to the matching ATen out-variant so the pre-allocated output
  // buffer is reused instead of allocating a new tensor.
  const BinaryOpType op_type = binary_op->getBinaryOpType();
  switch (op_type) {
    case BinaryOpType::Add:
      at::add_out(output, lhs, rhs);
      break;
    case BinaryOpType::Sub:
      at::sub_out(output, lhs, rhs);
      break;
    case BinaryOpType::Mul:
      at::mul_out(output, lhs, rhs);
      break;
    case BinaryOpType::Div:
      at::div_out(output, lhs, rhs);
      break;
    default:
      NVF_CHECK(
          false, "Unexpected operator type: ", op_type, " in ", binary_op);
  }
}
653+
654+
// Evaluates a ReductionOp into a pre-allocated output tensor via ATen
// out-variant reductions; falls back to the generic evaluator when the
// output is not already known.
void HostIrEvaluator::handle(ReductionOp* reduction_op) {
  auto* in_tv = reduction_op->in()->as<TensorView>();
  auto* out_tv = reduction_op->out()->as<TensorView>();
  if (!isKnown(out_tv)) {
    return unhandled(reduction_op);
  }

  // A root domain on the output would indicate an rFactored reduction.
  NVF_ERROR(
      !out_tv->hasRoot(),
      "Evaluation for rFactored reductions is not supported.");
  auto in_tensor = getKnownConcreteData(in_tv).as<at::Tensor>();
  auto out_tensor = getKnownConcreteData(out_tv).as<at::Tensor>();

  // Collect the positions of the reduced axes in the output's logical
  // domain; these are the dims passed to the ATen reduction.
  const auto& logical = out_tv->getLogicalDomain();
  std::vector<int64_t> reduced_dims;
  for (const auto axis : c10::irange(static_cast<int64_t>(logical.size()))) {
    if (logical.at(axis)->isReduction()) {
      reduced_dims.push_back(axis);
    }
  }

  const BinaryOpType op_type = reduction_op->getReductionOpType();
  switch (op_type) {
    case BinaryOpType::Add:
      at::sum_out(out_tensor, in_tensor, reduced_dims);
      return;
    case BinaryOpType::Max:
      at::amax_out(out_tensor, in_tensor, reduced_dims);
      return;
    case BinaryOpType::Min:
      at::amin_out(out_tensor, in_tensor, reduced_dims);
      return;
    default:
      NVF_CHECK(
          false, "Unexpected operator type: ", op_type, " in ", reduction_op);
  }
}
694+
596695
void HostIrEvaluator::unhandled(Statement* stmt) {
597696
NVF_ERROR(stmt->isA<Expr>(), stmt, " must be an Expr");
598697
auto* expr = stmt->as<Expr>();

csrc/host_ir/executor.h

+3
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ class HostIrEvaluator final : public OptOutDispatch {
129129
void handle(MatmulOp* matmul) override;
130130
void handle(LinearOp* linear) override;
131131
void handle(kir::Allocate* allocate) override;
132+
void handle(LoadStoreOp* load_store_op) override;
133+
void handle(BinaryOp* binary_op) override;
134+
void handle(ReductionOp* reduction_op) override;
132135
void unhandled(Statement* stmt) override;
133136

134137
c10::cuda::CUDAStream getCUDAStream(Stream* stream);

csrc/host_ir/lower.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,9 @@ bool HostIrLower::isLoweredAsStandaloneHostOp(Expr* expr) {
597597
SliceOp,
598598
SelectOp,
599599
LinearOp,
600+
LoadStoreOp,
601+
BinaryOp,
602+
ReductionOp,
600603
Communication,
601604
P2PCommunication>();
602605
}

tests/cpp/test_host_irs.cpp

+180
Original file line numberDiff line numberDiff line change
@@ -1276,6 +1276,186 @@ TEST_F(HirAlias, ThrowOnInputAlias) {
12761276
EXPECT_ANY_THROW(HostIrEvaluator hie(std::move(hic)));
12771277
}
12781278

1279+
using HirSetTest = NVFuserTest;
1280+
1281+
TEST_F(HirSetTest, HostIr) {
1282+
const std::vector<int64_t> sizes = {8, 64};
1283+
1284+
auto hic = std::make_unique<HostIrContainer>();
1285+
FusionGuard fg(hic.get());
1286+
1287+
auto* in = makeConcreteTensor(sizes);
1288+
auto* out = makeConcreteTensor(sizes);
1289+
auto* set = IrBuilder::create<LoadStoreOp>(LoadStoreOpType::Set, out, in);
1290+
hic->addInput(in);
1291+
hic->addInput(out);
1292+
hic->pushBackTopLevelExprs(set);
1293+
1294+
HostIrEvaluator hie(std::move(hic));
1295+
1296+
auto options = at::TensorOptions().device(at::kCUDA, 0);
1297+
auto in_aten = at::randn(sizes, options);
1298+
auto out_aten = at::empty(sizes, options);
1299+
1300+
hie.runWithInput({{in, in_aten}, {out, out_aten}});
1301+
1302+
EXPECT_TRUE(out_aten.equal(in_aten))
1303+
<< "Obtained output: " << out_aten << "\n"
1304+
<< "Expected output: " << in_aten;
1305+
}
1306+
1307+
// Fixture parameterized over the BinaryOpType evaluated by the host IR.
class HirBinaryOpTest : public NVFuserFixtureParamTest<BinaryOpType> {
 protected:
  // Reference implementation: applies the parameterized op directly in ATen.
  at::Tensor executeBinaryOp(at::Tensor lhs, at::Tensor rhs) {
    const BinaryOpType op_type = GetParam();
    switch (op_type) {
      case BinaryOpType::Add:
        return lhs + rhs;
      case BinaryOpType::Sub:
        return lhs - rhs;
      case BinaryOpType::Mul:
        return lhs * rhs;
      case BinaryOpType::Div:
        return lhs / rhs;
      default:
        NVF_ERROR("Unsupported binary op type ", op_type);
        return at::Tensor();
    }
  }
};
1325+
1326+
// Evaluates a single BinaryOp whose output tensor is pre-bound by the
// caller; the handler should write the result into it in place.
TEST_P(HirBinaryOpTest, PreAllocatedOutputs) {
  const std::vector<int64_t> sizes = {8, 64};
  const auto& binary_op_type = GetParam();

  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());

  TensorView* lhs = makeConcreteTensor(sizes);
  TensorView* rhs = makeConcreteTensor(sizes);
  TensorView* out = makeConcreteTensor(sizes);
  auto* binary_op = IrBuilder::create<BinaryOp>(binary_op_type, out, lhs, rhs);
  hic->addInput(lhs);
  hic->addInput(rhs);
  hic->addInput(out);
  hic->pushBackTopLevelExprs(binary_op);

  HostIrEvaluator hie(std::move(hic));

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor lhs_tensor = at::randn(sizes, options);
  at::Tensor rhs_tensor = at::randn(sizes, options);
  at::Tensor out_tensor = at::empty(sizes, options);

  hie.runWithInput({{lhs, lhs_tensor}, {rhs, rhs_tensor}, {out, out_tensor}});

  at::Tensor ref = executeBinaryOp(lhs_tensor, rhs_tensor);
  EXPECT_TRUE(ref.equal(out_tensor))
      << "Obtained output: " << out_tensor << "\n"
      << "Expected output: " << ref;
}
1356+
1357+
// Evaluates a BinaryOp without binding the output beforehand; the evaluator
// allocates the result itself and returns it from runWithInput.
TEST_P(HirBinaryOpTest, NonPreAllocatedOutputs) {
  const std::vector<int64_t> sizes = {8, 64};
  const auto& binary_op_type = GetParam();

  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());

  TensorView* lhs = makeConcreteTensor(sizes);
  TensorView* rhs = makeConcreteTensor(sizes);
  auto* out = binaryOp(binary_op_type, lhs, rhs);
  hic->addInput(lhs);
  hic->addInput(rhs);
  hic->addOutput(out);
  hic->pushBackTopLevelExprs(out->definition());

  HostIrEvaluator hie(std::move(hic));

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor lhs_tensor = at::randn(sizes, options);
  at::Tensor rhs_tensor = at::randn(sizes, options);

  at::Tensor out_tensor =
      hie.runWithInput({{lhs, lhs_tensor}, {rhs, rhs_tensor}})[0]
          .as<at::Tensor>();

  at::Tensor ref = executeBinaryOp(lhs_tensor, rhs_tensor);
  EXPECT_TRUE(ref.equal(out_tensor))
      << "Obtained output: " << out_tensor << "\n"
      << "Expected output: " << ref;
}
1386+
1387+
// Instantiates HirBinaryOpTest for the four op types the host IR evaluator
// handles natively; test names are suffixed "BinaryOpType_<op>".
INSTANTIATE_TEST_SUITE_P(
    ,
    HirBinaryOpTest,
    testing::Values(
        BinaryOpType::Add,
        BinaryOpType::Sub,
        BinaryOpType::Mul,
        BinaryOpType::Div),
    [](const testing::TestParamInfo<BinaryOpType>& info) -> std::string {
      std::stringstream name;
      name << "BinaryOpType_" << info.param;
      return name.str();
    });
1400+
1401+
using HirReductionOpTest = NVFuserTest;
1402+
1403+
TEST_F(HirReductionOpTest, PreAllocatedOutputs) {
1404+
constexpr int64_t size0 = 8, size1 = 64;
1405+
constexpr int64_t reduction_axis = 1;
1406+
1407+
auto hic = std::make_unique<HostIrContainer>();
1408+
FusionGuard fg(hic.get());
1409+
1410+
auto* in = makeConcreteTensor({size0, size1});
1411+
auto* out = newForReduction(in, {reduction_axis}, in->dtype());
1412+
auto* reduction_op = IrBuilder::create<ReductionOp>(
1413+
BinaryOpType::Add, hic->zeroVal(), out, in);
1414+
hic->addInput(in);
1415+
hic->addOutput(out);
1416+
hic->pushBackTopLevelExprs(reduction_op);
1417+
1418+
HostIrEvaluator hie(std::move(hic));
1419+
1420+
auto options = at::TensorOptions().device(at::kCUDA, 0);
1421+
auto in_aten = at::randn({size0, size1}, options);
1422+
auto out_aten = at::empty({size0}, options);
1423+
1424+
hie.runWithInput({{in, in_aten}, {out, out_aten}});
1425+
1426+
at::Tensor expected_out = in_aten.sum(reduction_axis);
1427+
EXPECT_TRUE(expected_out.equal(out_aten))
1428+
<< "Obtained output: " << out_aten << "\n"
1429+
<< "Expected output: " << expected_out;
1430+
}
1431+
1432+
// Sum-reduction whose output is NOT pre-bound: the evaluator's fallback path
// must compute, allocate, and bind the output, which we read back from
// runWithInput's return value.
//
// Bug fix: the original test pre-bound `out` with an empty tensor in
// runWithInput, which made isKnown(out) true and re-exercised the
// pre-allocated path — duplicating PreAllocatedOutputs and leaving the
// non-pre-allocated fallback untested, contrary to the test's name. This
// now mirrors HirBinaryOpTest.NonPreAllocatedOutputs.
TEST_F(HirReductionOpTest, NonPreAllocatedOutputs) {
  constexpr int64_t size0 = 8, size1 = 64;
  constexpr int64_t reduction_axis = 1;

  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());

  auto* in = makeConcreteTensor({size0, size1});
  auto* out = sum(in, {reduction_axis});
  hic->addInput(in);
  hic->addOutput(out);
  hic->pushBackTopLevelExprs(out->definition());

  HostIrEvaluator hie(std::move(hic));

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  auto in_aten = at::randn({size0, size1}, options);

  auto out_aten = hie.runWithInput({{in, in_aten}})[0].as<at::Tensor>();

  at::Tensor expected_out = in_aten.sum(reduction_axis);
  EXPECT_TRUE(expected_out.equal(out_aten))
      << "Obtained output: " << out_aten << "\n"
      << "Expected output: " << expected_out;
}
1458+
12791459
} // namespace hir
12801460

12811461
} // namespace nvfuser

0 commit comments

Comments
 (0)