|
14 | 14 | #include <options.h>
|
15 | 15 | #include <val_graph_visitor.h>
|
16 | 16 |
|
| 17 | +#include <algorithm> |
| 18 | + |
17 | 19 | namespace nvfuser {
|
18 | 20 |
|
19 | 21 | std::string toString(const CoveredGroups& covered_groups) {
|
@@ -106,10 +108,8 @@ bool isEqualToOrSuperSetOf(
|
106 | 108 | covered_groups_y.begin(),
|
107 | 109 | covered_groups_y.end(),
|
108 | 110 | [&](const CoveredGroup& covered_group_y) {
|
109 |
| - return std::any_of( |
110 |
| - covered_groups_x.begin(), |
111 |
| - covered_groups_x.end(), |
112 |
| - [&](const CoveredGroup& covered_group_x) { |
| 111 | + return std::ranges::any_of( |
| 112 | + covered_groups_x, [&](const CoveredGroup& covered_group_x) { |
113 | 113 | return covered_group_x.isEqualToOrSuperSetOf(covered_group_y);
|
114 | 114 | });
|
115 | 115 | });
|
@@ -167,10 +167,9 @@ bool isDependencyOf(
|
167 | 167 | return true;
|
168 | 168 | }
|
169 | 169 |
|
170 |
| - if (std::any_of( |
171 |
| - of->begin(), of->end(), [&](const CoveredGroup& covered_group) { |
172 |
| - return isDependencyOf(dependency, covered_group); |
173 |
| - })) { |
| 170 | + if (std::ranges::any_of(*of, [&](const CoveredGroup& covered_group) { |
| 171 | + return isDependencyOf(dependency, covered_group); |
| 172 | + })) { |
174 | 173 | return true;
|
175 | 174 | }
|
176 | 175 |
|
@@ -380,6 +379,29 @@ std::unordered_map<ValGroup, IterDomain*> LoopPromotionMapBuilder::build() {
|
380 | 379 | return buildWithNoBroadcast();
|
381 | 380 | }
|
382 | 381 |
|
| 382 | + // Keep track of IDs whose loop groups only have broadcast |
| 383 | + // IDs. These IDs should not need to be promoted to non-broadcast |
| 384 | + // IDs. Note that we can't just remember these loop ValGroups as |
| 385 | + // they might be replaced during the following analysis. |
| 386 | + for (const auto& loop_group : |
| 387 | + idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { |
| 388 | + if (std::ranges::any_of(*loop_group, [](Val* val) { |
| 389 | + return !val->as<IterDomain>()->isBroadcast(); |
| 390 | + })) { |
| 391 | + continue; |
| 392 | + } |
| 393 | + |
| 394 | + // Currently, only exact-mapped loop groups are considered. This |
| 395 | + // condition is required as we are going to replace promotion IDs |
| 396 | + // with an arbitrary member ID. |
| 397 | + if (idGraph(IdMappingMode::EXACT).toGroups(*loop_group).size() != 1) { |
| 398 | + continue; |
| 399 | + } |
| 400 | + |
| 401 | + broadcast_only_loop_group_ids_.insert( |
| 402 | + loop_group->begin(), loop_group->end()); |
| 403 | + } |
| 404 | + |
383 | 405 | // Make an intersection of the exact and loop map. This will group together
|
384 | 406 | // entries in each loop group that are exact with each other. This provides a
|
385 | 407 | // better graph to do promotion and replays.
|
@@ -470,6 +492,7 @@ std::unordered_map<ValGroup, IterDomain*> LoopPromotionMapBuilder::build() {
|
470 | 492 | if (loop_promotion_map_to_propagate.empty()) {
|
471 | 493 | auto final_loop_promotion_map = updateValGroupIdMap(
|
472 | 494 | initial_loop_promotion_map, idGraph(IdMappingMode::LOOP));
|
| 495 | + revertBroadcastOnlyLoopGroups(final_loop_promotion_map); |
473 | 496 | sanityCheckLoopPromotionMap(final_loop_promotion_map);
|
474 | 497 | return final_loop_promotion_map;
|
475 | 498 | }
|
@@ -537,6 +560,7 @@ std::unordered_map<ValGroup, IterDomain*> LoopPromotionMapBuilder::build() {
|
537 | 560 | final_loop_promotion_map = updateValGroupIdMap(
|
538 | 561 | final_loop_promotion_map, idGraph(IdMappingMode::LOOP));
|
539 | 562 |
|
| 563 | + revertBroadcastOnlyLoopGroups(final_loop_promotion_map); |
540 | 564 | sanityCheckLoopPromotionMap(final_loop_promotion_map);
|
541 | 565 |
|
542 | 566 | if (callback_) {
|
@@ -692,10 +716,8 @@ Expr* findMatchingExpr(
|
692 | 716 | // iel_graph, it means the domain is just replayed and by definition
|
693 | 717 | // has no mapping with any existing domain, which means there's no
|
694 | 718 | // matching expr.
|
695 |
| - if (std::any_of( |
696 |
| - maybe_promoted_inputs.begin(), |
697 |
| - maybe_promoted_inputs.end(), |
698 |
| - [&](IterDomain* maybe_promoted_input) -> bool { |
| 719 | + if (std::ranges::any_of( |
| 720 | + maybe_promoted_inputs, [&](IterDomain* maybe_promoted_input) -> bool { |
699 | 721 | return !iel_graph.hasGroup(maybe_promoted_input);
|
700 | 722 | })) {
|
701 | 723 | return nullptr;
|
@@ -972,20 +994,14 @@ LoopPromotionMapBuilder::computeCoveredGroups(
|
972 | 994 |
|
973 | 995 | // Initialize broadcast groups to empty since broadcast domains
|
974 | 996 | // don't matter for indexing
|
975 |
| - if (std::any_of(id_group->begin(), id_group->end(), [&](Val* id) { |
| 997 | + if (std::ranges::any_of(*id_group, [&](Val* id) { |
976 | 998 | return id->as<IterDomain>()->isBroadcast();
|
977 | 999 | })) {
|
978 | 1000 | covered_group_map[id_group] = std::make_shared<CoveredGroups>();
|
979 | 1001 | }
|
980 | 1002 | }
|
981 | 1003 |
|
982 | 1004 | ValGraphStmtSort exact_stmt_sort(exact_graph, input_groups_of_graph);
|
983 |
| -#if 0 |
984 |
| - std::cerr << "Sorted exprs:\n"; |
985 |
| - for (const ExprGroup& exact_expr : exact_stmt_sort.exprs()) { |
986 |
| - std::cerr << exact_expr->front()->toString(); |
987 |
| - } |
988 |
| -#endif |
989 | 1005 | for (const ExprGroup& exact_expr : exact_stmt_sort.exprs()) {
|
990 | 1006 | const std::vector<ValGroup> input_groups =
|
991 | 1007 | exact_graph.inputGroups(exact_expr);
|
@@ -1180,12 +1196,9 @@ VectorOfUniqueEntries<IterDomain*> LoopPromotionMapBuilder::
|
1180 | 1196 | // then it's a terminal ID
|
1181 | 1197 | bool all_outs_in_loop_group = true;
|
1182 | 1198 | for (auto use : uses_it->second) {
|
1183 |
| - if (std::any_of( |
1184 |
| - use->outputs().begin(), |
1185 |
| - use->outputs().end(), |
1186 |
| - [&](Val* out) -> bool { |
1187 |
| - return group != idGraph(IdMappingMode::LOOP).toGroup(out); |
1188 |
| - })) { |
| 1199 | + if (std::ranges::any_of(use->outputs(), [&](Val* out) -> bool { |
| 1200 | + return group != idGraph(IdMappingMode::LOOP).toGroup(out); |
| 1201 | + })) { |
1189 | 1202 | all_outs_in_loop_group = false;
|
1190 | 1203 | break;
|
1191 | 1204 | }
|
@@ -1293,4 +1306,45 @@ std::unordered_map<ValGroup, IterDomain*> LoopPromotionMapBuilder::
|
1293 | 1306 | return map;
|
1294 | 1307 | }
|
1295 | 1308 |
|
| 1309 | +void LoopPromotionMapBuilder::revertBroadcastOnlyLoopGroups( |
| 1310 | + std::unordered_map<ValGroup, IterDomain*>& loop_promotion_map) const { |
| 1311 | + // If a loop group originally only consisted of broadcast IDs |
| 1312 | + // and now is promoted to a non-broadcast ID, it should not need to |
| 1313 | + // be promoted. |
| 1314 | + for (auto& [loop_group, current_promotion_id] : loop_promotion_map) { |
| 1315 | + if (current_promotion_id->isBroadcast()) { |
| 1316 | + continue; |
| 1317 | + } |
| 1318 | + |
| 1319 | + // As long as there's a single ID marked as broadcast only, this |
| 1320 | + // group originally consisted of broadcast IDs only. Note that, |
| 1321 | + // since new IDs were added as part of the promotion analysis, not |
| 1322 | + // all of the IDs are included in the broadcast only set. |
| 1323 | + IterDomain* original_broadcast_id = nullptr; |
| 1324 | + for (auto val : *loop_group) { |
| 1325 | + if (broadcast_only_loop_group_ids_.contains(val)) { |
| 1326 | + original_broadcast_id = val->as<IterDomain>(); |
| 1327 | + break; |
| 1328 | + } |
| 1329 | + } |
| 1330 | + if (original_broadcast_id == nullptr) { |
| 1331 | + continue; |
| 1332 | + } |
| 1333 | + |
| 1334 | + // Note that this promotion should be valid for the existing |
| 1335 | + // IDs that originate from the fusion, but the loop group also |
| 1336 | + // contains other non-broadcast IDs for loop promotion, e.g., |
| 1337 | + // current_promotion_id. This replacement means those |
| 1338 | + // non-broadcast IDs are also promoted to the broadcast ID, which |
| 1339 | + // does not make sense. For example, in the case of |
| 1340 | + // IdModelTest.BroadcastOnlyNoLoopPromotion, the innermost loop ID |
| 1341 | + // of tv2 has no mapping in the original fusion, but its loop |
| 1342 | + // group gets additional IDs, iS17 and iS19, both of which are |
| 1343 | + // exact mapped with the innermost loop IDs of tv1 and tv3. |
| 1344 | + // |
| 1345 | + // TODO: Consider cleaning up the unused non-broadcast IDs. |
| 1346 | + current_promotion_id = original_broadcast_id; |
| 1347 | + } |
| 1348 | +} |
| 1349 | + |
1296 | 1350 | } // namespace nvfuser
|
0 commit comments