10
10
#include < fusion.h>
11
11
#include < fusion_profiler.h>
12
12
#include < fusion_segmenter.h>
13
+ #include < host_ir/lower.h>
13
14
#include < instrumentation.h>
14
15
#include < ir/base_nodes.h>
16
+ #include < multidevice/communication.h>
17
+ #include < multidevice/utils.h>
15
18
#include < preseg_passes/pre_segmenter.h>
16
19
#include < python_frontend/fusion_definition.h>
17
20
#include < python_frontend/translation.h>
21
24
#include < scheduler/heuristic.h>
22
25
#include < serde/fusion_cache_generated.h>
23
26
#include < type.h>
24
- #include < host_ir/lower.h>
25
- #include < multidevice/communication.h>
26
- #include < multidevice/utils.h>
27
27
28
28
#include < c10/cuda/CUDAGuard.h>
29
29
@@ -450,34 +450,37 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) {
450
450
hic->pushBackTopLevelExprs (launch_kernel);
451
451
} else {
452
452
const bool is_resharding = std::any_of (
453
- group_to_run->exprs ().begin (), group_to_run-> exprs (). end (), []( auto expr) {
454
- return isResharding (expr);
455
- });
453
+ group_to_run->exprs ().begin (),
454
+ group_to_run-> exprs (). end (),
455
+ []( auto expr) { return isResharding (expr); });
456
456
if (is_resharding) {
457
+ auto deviceid = Communicator::getInstance ().deviceId ();
457
458
NVF_ERROR (
458
459
group_to_run->exprs ().size () == 1 ,
459
460
" Communication segments must contain only one Expr" );
460
461
HostIrLower lower;
461
462
for (auto * expr :
462
- lower.lower (ir_cloner.clone (group_to_run->exprs ().at (0 )))) {
463
+ lower.lower (ir_cloner.clone (group_to_run->exprs ().at (0 )), deviceid )) {
463
464
// Allocate the recv buffers of communications
464
465
if (expr->isA <Communication>()) {
465
466
auto * communication = expr->as <Communication>();
466
467
TensorView* tv = communication->out ();
467
- if (tv->getDeviceMesh ().has (Communicator::getInstance (). deviceId () )) {
468
+ if (tv->getDeviceMesh ().has (deviceid )) {
468
469
auto * allocate =
469
470
IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
470
471
hic->pushBackTopLevelExprs (allocate);
471
472
}
472
473
}
473
474
hic->pushBackTopLevelExprs (expr);
474
475
if (expr->isA <Communication>()) {
475
- auto wait = IrBuilder::create<hir::Wait>(expr->as <Communication>());
476
+ auto wait =
477
+ IrBuilder::create<hir::Wait>(expr->as <Communication>());
476
478
hic->pushBackTopLevelExprs (wait );
477
479
}
478
480
}
479
481
} else {
480
- // push back segment's exprs into the container as top level expressions
482
+ // push back segment's exprs into the container as top level
483
+ // expressions
481
484
for (auto * expr : group_to_run->exprs ()) {
482
485
auto cloned_expr = ir_cloner.clone (expr);
483
486
hic->pushBackTopLevelExprs (cloned_expr);
@@ -491,7 +494,7 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) {
491
494
for (const Val* out : segmented_fusion_->outputs ()) {
492
495
hic->addOutput (ir_cloner.clone (out));
493
496
}
494
-
497
+
495
498
hic->sortExprs ();
496
499
}
497
500
0 commit comments