Skip to content

Commit 360b232

Browse files
authored
Do not promote broadcast only groups (#4154)
This PR attempts to not promote a loop group when it only consists of broadcast IDs. For example, consider a fusion as shown below: ``` auto tv0 = makeContigConcreteTensor({-1, 1}); fusion.addInput(tv0); auto tv1 = makeContigTensor(2); fusion.addInput(tv1); auto tv2 = set(tv0); auto tv3 = add(tv2, tv1); fusion.addOutput(tv3); for (auto tv : fusion.allTvs()) { tv->split(1, 1, false); tv->reorder({{0, 1}, {1, 0}}); } for (auto tv : fusion.allTvs()) { tv->inlineAt(2); } ``` ``` Inputs: T0_g_float[bS8{1}, iS0{i0}, bS9{1}] T1_g_float[iS12{1}, iS2{i4}, iS13{i5}] Outputs: T3_g_float[iS14{1}, iS6{i0}, iS15{i5}] ca_pos( 2 ) produce_pos( 2 ) %kernel { T2_l_float[bS10{1}, iS4{i0}, bS11{1}] ca_pos( 2 ) = Set( T0_g_float[bS8{1}, iS0{i0}, bS9{1}], cache_op=Streaming ) T3_g_float[iS14{1}, iS6{i0}, iS15{i5}] ca_pos( 2 ) produce_pos( 2 ) = T2_l_float[bS10{1}, iS4{i0}, bS11{1}] ca_pos( 2 ) + T1_g_float[iS12{1}, iS2{i4}, iS13{i5}]; TransformPrinter : T0_g_float[bS8{1}, iS0{i0}, bS9{1}] logical domain : (iS0{i0}, bS1{1}) contiguity: t n Outer split: bS1{1} by factor 1 -> bS8{1}, bS9{1} loop domain : (bS8{1}, iS0{i0}, bS9{1}) T2_l_float[bS10{1}, iS4{i0}, bS11{1}] ca_pos( 2 ) logical domain : (iS4{i0}, bS5{1}) contiguity: t n Outer split: bS5{1} by factor 1 -> bS10{1}, bS11{1} loop domain : (bS10{1}, iS4{i0}, bS11{1}) T1_g_float[iS12{1}, iS2{i4}, iS13{i5}] logical domain : (iS2{i4}, iS3{i5}) contiguity: t t Outer split: iS3{i5} by factor 1 -> iS12{1}, iS13{i5} loop domain : (iS12{1}, iS2{i4}, iS13{i5}) T3_g_float[iS14{1}, iS6{i0}, iS15{i5}] ca_pos( 2 ) produce_pos( 2 ) logical domain : (iS6{i0}, iS7{i5}) contiguity: t t Outer split: iS7{i5} by factor 1 -> iS14{1}, iS15{i5} loop domain : (iS14{1}, iS6{i0}, iS15{i5}) } // %kernel ``` Here, the interesting part is the innermost loop ID of `T2`, `bS11{1}`. Because `bS5` is promoted to `iS3` (or `iS7`), `bS11` is also promoted to a non-broadcast ID that is exact mapped with `iS13` and `iS15`. However, in this case, `bS11` doesn't really need to be promoted. More specifically, as long as a loop group only consists of broadcast IDs, the group should not need to be promoted. Currently, the generated CUDA kernel with `NVFUSER_ENABLE=id_model(all)` looks like below: ``` __global__ void CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T1, Tensor<float, 2, 2> T3) { #pragma unroll 1 for(nvfuser_index_t i0 = 0LL; i0 < T0.logical_size[0LL]; ++i0) { nvfuser_index_t i1; i1 = T1.logical_size[1LL] * i0; Array<float, 1LL, 1> T2; #pragma unroll 1 for(nvfuser_index_t i2 = 0LL; i2 < T1.logical_size[1LL]; ++i2) { T2[0LL] = T0[i0]; } #pragma unroll 1 for(nvfuser_index_t i3 = 0LL; i3 < T1.logical_size[1LL]; ++i3) { nvfuser_index_t i4; i4 = i1 + i3; T3[i4] = T2[0LL] + T1[i4]; } } } ``` The code is not incorrect, but `T2` is redundantly defined over the loop of `T1.logical_size[1]` because of the promotion of `bS11`. Note that the allocation of `T2` is not affected because [broadcast IDs are excluded](https://github.com/NVIDIA/Fuser/blob/main/csrc/device_lower/pass/allocation.cpp#L299) before promotion. In this PR, for loop groups that only consist of broadcast IDs, promotion to non-broadcast is reverted. With the change, the above fusion results in the kernel below: ``` __global__ void CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T1, Tensor<float, 2, 2> T3) { #pragma unroll 1 for(nvfuser_index_t i0 = 0LL; i0 < T0.logical_size[0LL]; ++i0) { nvfuser_index_t i1; i1 = T1.logical_size[1LL] * i0; Array<float, 1LL, 1> T2; T2[0LL] = T0[i0]; #pragma unroll 1 for(nvfuser_index_t i2 = 0LL; i2 < T1.logical_size[1LL]; ++i2) { nvfuser_index_t i3; i3 = i1 + i2; T3[i3] = T2[0LL] + T1[i3]; } } } ```
1 parent fa03ce6 commit 360b232

File tree

3 files changed

+130
-25
lines changed

3 files changed

+130
-25
lines changed

csrc/id_model/loop_promotion.cpp

+79-25
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include <options.h>
1515
#include <val_graph_visitor.h>
1616

17+
#include <algorithm>
18+
1719
namespace nvfuser {
1820

1921
std::string toString(const CoveredGroups& covered_groups) {
@@ -106,10 +108,8 @@ bool isEqualToOrSuperSetOf(
106108
covered_groups_y.begin(),
107109
covered_groups_y.end(),
108110
[&](const CoveredGroup& covered_group_y) {
109-
return std::any_of(
110-
covered_groups_x.begin(),
111-
covered_groups_x.end(),
112-
[&](const CoveredGroup& covered_group_x) {
111+
return std::ranges::any_of(
112+
covered_groups_x, [&](const CoveredGroup& covered_group_x) {
113113
return covered_group_x.isEqualToOrSuperSetOf(covered_group_y);
114114
});
115115
});
@@ -167,10 +167,9 @@ bool isDependencyOf(
167167
return true;
168168
}
169169

170-
if (std::any_of(
171-
of->begin(), of->end(), [&](const CoveredGroup& covered_group) {
172-
return isDependencyOf(dependency, covered_group);
173-
})) {
170+
if (std::ranges::any_of(*of, [&](const CoveredGroup& covered_group) {
171+
return isDependencyOf(dependency, covered_group);
172+
})) {
174173
return true;
175174
}
176175

@@ -380,6 +379,29 @@ std::unordered_map<ValGroup, IterDomain*> LoopPromotionMapBuilder::build() {
380379
return buildWithNoBroadcast();
381380
}
382381

382+
// Keep track of IDs whose loop groups only have broadcast
383+
// IDs. These IDs should not need to be promoted to non-broadcastg
384+
// IDs. Note that we can't just remember these loop ValGroups as
385+
// they might be replaced during the following analysis.
386+
for (const auto& loop_group :
387+
idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) {
388+
if (std::ranges::any_of(*loop_group, [](Val* val) {
389+
return !val->as<IterDomain>()->isBroadcast();
390+
})) {
391+
continue;
392+
}
393+
394+
// Currently, only exact-mapped loop groups are considered. This
395+
// condition is required as we are going to replace promotion IDs
396+
// with an arbitrary member ID.
397+
if (idGraph(IdMappingMode::EXACT).toGroups(*loop_group).size() != 1) {
398+
continue;
399+
}
400+
401+
broadcast_only_loop_group_ids_.insert(
402+
loop_group->begin(), loop_group->end());
403+
}
404+
383405
// Make an intersection of the exact and loop map. This will group together
384406
// entries in each loop group that are exact with each other. This provides a
385407
// better graph to do promotion and replays.
@@ -470,6 +492,7 @@ std::unordered_map<ValGroup, IterDomain*> LoopPromotionMapBuilder::build() {
470492
if (loop_promotion_map_to_propagate.empty()) {
471493
auto final_loop_promotion_map = updateValGroupIdMap(
472494
initial_loop_promotion_map, idGraph(IdMappingMode::LOOP));
495+
revertBroadcastOnlyLoopGroups(final_loop_promotion_map);
473496
sanityCheckLoopPromotionMap(final_loop_promotion_map);
474497
return final_loop_promotion_map;
475498
}
@@ -537,6 +560,7 @@ std::unordered_map<ValGroup, IterDomain*> LoopPromotionMapBuilder::build() {
537560
final_loop_promotion_map = updateValGroupIdMap(
538561
final_loop_promotion_map, idGraph(IdMappingMode::LOOP));
539562

563+
revertBroadcastOnlyLoopGroups(final_loop_promotion_map);
540564
sanityCheckLoopPromotionMap(final_loop_promotion_map);
541565

542566
if (callback_) {
@@ -692,10 +716,8 @@ Expr* findMatchingExpr(
692716
// iel_graph, it means the domain is just replayed and by definition
693717
// has no mapping with any existing domain, which means there's no
694718
// matching expr.
695-
if (std::any_of(
696-
maybe_promoted_inputs.begin(),
697-
maybe_promoted_inputs.end(),
698-
[&](IterDomain* maybe_promoted_input) -> bool {
719+
if (std::ranges::any_of(
720+
maybe_promoted_inputs, [&](IterDomain* maybe_promoted_input) -> bool {
699721
return !iel_graph.hasGroup(maybe_promoted_input);
700722
})) {
701723
return nullptr;
@@ -972,20 +994,14 @@ LoopPromotionMapBuilder::computeCoveredGroups(
972994

973995
// Initialize broadcast groups to empty since broadcast domains
974996
// don't matter for indexing
975-
if (std::any_of(id_group->begin(), id_group->end(), [&](Val* id) {
997+
if (std::ranges::any_of(*id_group, [&](Val* id) {
976998
return id->as<IterDomain>()->isBroadcast();
977999
})) {
9781000
covered_group_map[id_group] = std::make_shared<CoveredGroups>();
9791001
}
9801002
}
9811003

9821004
ValGraphStmtSort exact_stmt_sort(exact_graph, input_groups_of_graph);
983-
#if 0
984-
std::cerr << "Sorted exprs:\n";
985-
for (const ExprGroup& exact_expr : exact_stmt_sort.exprs()) {
986-
std::cerr << exact_expr->front()->toString();
987-
}
988-
#endif
9891005
for (const ExprGroup& exact_expr : exact_stmt_sort.exprs()) {
9901006
const std::vector<ValGroup> input_groups =
9911007
exact_graph.inputGroups(exact_expr);
@@ -1180,12 +1196,9 @@ VectorOfUniqueEntries<IterDomain*> LoopPromotionMapBuilder::
11801196
// then it's a terminal ID
11811197
bool all_outs_in_loop_group = true;
11821198
for (auto use : uses_it->second) {
1183-
if (std::any_of(
1184-
use->outputs().begin(),
1185-
use->outputs().end(),
1186-
[&](Val* out) -> bool {
1187-
return group != idGraph(IdMappingMode::LOOP).toGroup(out);
1188-
})) {
1199+
if (std::ranges::any_of(use->outputs(), [&](Val* out) -> bool {
1200+
return group != idGraph(IdMappingMode::LOOP).toGroup(out);
1201+
})) {
11891202
all_outs_in_loop_group = false;
11901203
break;
11911204
}
@@ -1293,4 +1306,45 @@ std::unordered_map<ValGroup, IterDomain*> LoopPromotionMapBuilder::
12931306
return map;
12941307
}
12951308

1309+
void LoopPromotionMapBuilder::revertBroadcastOnlyLoopGroups(
1310+
std::unordered_map<ValGroup, IterDomain*>& loop_promotion_map) const {
1311+
// If a loop group originally only consisted of broadcast IDs
1312+
// and now is promoted to a non-broadcast ID, it should not need to
1313+
// be promoted.
1314+
for (auto& [loop_group, current_promotion_id] : loop_promotion_map) {
1315+
if (current_promotion_id->isBroadcast()) {
1316+
continue;
1317+
}
1318+
1319+
// As long as there's a single ID marked as broadcast only, this
1320+
// group originally consisted of broadcast IDs only. Note that,
1321+
// since new IDs were added as part of the promotion analysis, not
1322+
// all of the IDs are included in the broadcast only set.
1323+
IterDomain* original_broadcast_id = nullptr;
1324+
for (auto val : *loop_group) {
1325+
if (broadcast_only_loop_group_ids_.contains(val)) {
1326+
original_broadcast_id = val->as<IterDomain>();
1327+
break;
1328+
}
1329+
}
1330+
if (original_broadcast_id == nullptr) {
1331+
continue;
1332+
}
1333+
1334+
// Note that this promotion should be valid for the existing
1335+
// IDs that originate from the fusion, but the loop group also
1336+
// contains other non-broadcast IDs for loop promotion, e.g.,
1337+
// current_promotion_id. This replacement means those
1338+
// non-broadcast IDs are also promoted to the broadcast ID, which
1339+
// does not make sense. For example, in the case of
1340+
// IdModelTest.BroadcastOnlyNoLoopPromotion, the innermost loop ID
1341+
// of tv2 has no mapping in the original fusion, but its loop
1342+
// group gets additional IDs, iS17 and iS19, both of which are
1343+
// exact mapped with the innermost loop IDs of tv1 and tv3.
1344+
//
1345+
// TODO: Consider cleaning up the unused non-broadcast IDs.
1346+
current_promotion_id = original_broadcast_id;
1347+
}
1348+
}
1349+
12961350
} // namespace nvfuser

csrc/id_model/loop_promotion.h

+6
Original file line numberDiff line numberDiff line change
@@ -293,10 +293,16 @@ class LoopPromotionMapBuilder {
293293
const std::unordered_map<ValGroup, IterDomain*>& loop_promotion_map)
294294
const;
295295

296+
// Revert unnecessary promotions to non-broadcast IDs
297+
void revertBroadcastOnlyLoopGroups(
298+
std::unordered_map<ValGroup, IterDomain*>& loop_promotion_map) const;
299+
296300
private:
297301
IdModel& id_model_;
298302
const StatefulInliningInfo& inlining_info_;
299303
LoopPromotionMapBuilderCallback* callback_ = nullptr;
304+
// Keep track of IDs of broadcast only loop groups
305+
std::unordered_set<Val*> broadcast_only_loop_group_ids_;
300306

301307
// (For debugging only) When force_full_loop_promotion_analysis_ is
302308
// true, it always performs the full loop promotion analysis even

tests/cpp/test_id_model.cpp

+45
Original file line numberDiff line numberDiff line change
@@ -3084,4 +3084,49 @@ TEST_F(IdModelTest, InvalidLoopPromotion) {
30843084
testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
30853085
}
30863086

3087+
// When a loop group only includes broadcast IDs, the group should not
3088+
// need to be promoted
3089+
TEST_F(IdModelTest, BroadcastOnlyNoLoopPromotion) {
3090+
auto fusion_ptr = std::make_unique<Fusion>();
3091+
auto& fusion = *fusion_ptr;
3092+
FusionGuard fg(fusion_ptr.get());
3093+
3094+
auto tv0 = makeContigConcreteTensor({-1, 1});
3095+
fusion.addInput(tv0);
3096+
auto tv1 = makeContigTensor(2);
3097+
fusion.addInput(tv1);
3098+
3099+
auto tv2 = set(tv0);
3100+
auto tv3 = add(tv2, tv1);
3101+
fusion.addOutput(tv3);
3102+
3103+
for (auto tv : fusion.allTvs()) {
3104+
tv->split(1, 1, false);
3105+
tv->reorder({{0, 1}, {1, 0}});
3106+
}
3107+
3108+
for (auto tv : fusion.allTvs()) {
3109+
tv->inlineAt(2);
3110+
}
3111+
3112+
// T2_l_float[bS10{1}, iS4{i0}, bS11{1}] ca_pos( 2 )
3113+
// = Set( T0_g_float[bS8{1}, iS0{i0}, bS9{1}], cache_op=Streaming )
3114+
// T3_g_float[iS14{1}, iS6{i0}, iS15{i5}] ca_pos( 2 ) produce_pos( 2 )
3115+
// = T2_l_float[bS10{1}, iS4{i0}, bS11{1}] ca_pos( 2 )
3116+
// + T1_g_float[iS12{1}, iS2{i4}, iS13{i5}];
3117+
3118+
// In this fusion, the innermost loop ID of tv2 is broadcast and is
3119+
// not inlined. While its producer ID is promoted to the concrete
3120+
// logical ID of tv3, it should not need to promote the loop ID as
3121+
// it's just a broadcast.
3122+
3123+
IdModel id_model(&fusion, /*build_graphs=*/true);
3124+
3125+
auto promotion_id = id_model.loopPromotionMap().at(
3126+
id_model.idGraph(IdMappingMode::LOOP).toGroup(tv2->axis(-1)));
3127+
EXPECT_TRUE(promotion_id->isBroadcast())
3128+
<< "Should not be promoted a non-broadcast ID: "
3129+
<< promotion_id->toString();
3130+
}
3131+
30873132
} // namespace nvfuser

0 commit comments

Comments
 (0)