Skip to content

Commit 3ef189d

Browse files
committed
fix
1 parent 401d991 commit 3ef189d

File tree

3 files changed

+16
-16
lines changed

3 files changed

+16
-16
lines changed

csrc/fusion_segmenter.h

-2
Original file line number | Diff line number | Diff line change
@@ -50,10 +50,8 @@ struct SegmentedEdge {
5050
}
5151
};
5252

53-
5453
std::vector<Expr*> groupExprPrintSorting(const std::vector<Expr*>& exprs);
5554

56-
5755
std::ostream& operator<<(std::ostream& os, const SegmentedEdge* edge);
5856

5957
//! Groups together expressions which create a segmented group

csrc/host_ir/container.cpp

+2 -3
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
*/
77
// clang-format on
88

9+
#include <fusion_segmenter.h>
910
#include <host_ir/container.h>
1011
#include <host_ir/host_ir.h>
1112
#include <ir/builder.h>
@@ -15,7 +16,6 @@
1516
#include <kernel_ir.h>
1617
#include <ops/all_ops.h>
1718
#include <runtime/executor.h>
18-
#include <fusion_segmenter.h>
1919

2020
namespace nvfuser {
2121

@@ -55,8 +55,7 @@ void HostIrContainer::setKernelExecutor(
5555
kernel_executors_.at(index) = std::move(ke);
5656
}
5757

58-
void HostIrContainer::sortExprs()
59-
{
58+
void HostIrContainer::sortExprs() {
6059
this->top_level_exprs_ = groupExprPrintSorting(this->top_level_exprs_);
6160
}
6261

csrc/runtime/fusion_kernel_runtime.cpp

+14 -11
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,11 @@
1010
#include <fusion.h>
1111
#include <fusion_profiler.h>
1212
#include <fusion_segmenter.h>
13+
#include <host_ir/lower.h>
1314
#include <instrumentation.h>
1415
#include <ir/base_nodes.h>
16+
#include <multidevice/communication.h>
17+
#include <multidevice/utils.h>
1518
#include <preseg_passes/pre_segmenter.h>
1619
#include <python_frontend/fusion_definition.h>
1720
#include <python_frontend/translation.h>
@@ -21,9 +24,6 @@
2124
#include <scheduler/heuristic.h>
2225
#include <serde/fusion_cache_generated.h>
2326
#include <type.h>
24-
#include <host_ir/lower.h>
25-
#include <multidevice/communication.h>
26-
#include <multidevice/utils.h>
2727

2828
#include <c10/cuda/CUDAGuard.h>
2929

@@ -450,34 +450,37 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) {
450450
hic->pushBackTopLevelExprs(launch_kernel);
451451
} else {
452452
const bool is_resharding = std::any_of(
453-
group_to_run->exprs().begin(), group_to_run->exprs().end(), [](auto expr) {
454-
return isResharding(expr);
455-
});
453+
group_to_run->exprs().begin(),
454+
group_to_run->exprs().end(),
455+
[](auto expr) { return isResharding(expr); });
456456
if (is_resharding) {
457+
auto deviceid = Communicator::getInstance().deviceId();
457458
NVF_ERROR(
458459
group_to_run->exprs().size() == 1,
459460
"Communication segments must contain only one Expr");
460461
HostIrLower lower;
461462
for (auto* expr :
462-
lower.lower(ir_cloner.clone(group_to_run->exprs().at(0)))) {
463+
lower.lower(ir_cloner.clone(group_to_run->exprs().at(0)), deviceid)) {
463464
// Allocate the recv buffers of communications
464465
if (expr->isA<Communication>()) {
465466
auto* communication = expr->as<Communication>();
466467
TensorView* tv = communication->out();
467-
if (tv->getDeviceMesh().has(Communicator::getInstance().deviceId())) {
468+
if (tv->getDeviceMesh().has(deviceid)) {
468469
auto* allocate =
469470
IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
470471
hic->pushBackTopLevelExprs(allocate);
471472
}
472473
}
473474
hic->pushBackTopLevelExprs(expr);
474475
if (expr->isA<Communication>()) {
475-
auto wait = IrBuilder::create<hir::Wait>(expr->as<Communication>());
476+
auto wait =
477+
IrBuilder::create<hir::Wait>(expr->as<Communication>());
476478
hic->pushBackTopLevelExprs(wait);
477479
}
478480
}
479481
} else {
480-
// push back segment's exprs into the container as top level expressions
482+
// push back segment's exprs into the container as top level
483+
// expressions
481484
for (auto* expr : group_to_run->exprs()) {
482485
auto cloned_expr = ir_cloner.clone(expr);
483486
hic->pushBackTopLevelExprs(cloned_expr);
@@ -491,7 +494,7 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) {
491494
for (const Val* out : segmented_fusion_->outputs()) {
492495
hic->addOutput(ir_cloner.clone(out));
493496
}
494-
497+
495498
hic->sortExprs();
496499
}
497500

0 commit comments

Comments
 (0)