-
Notifications
You must be signed in to change notification settings - Fork 79
Cleaning up scheduling of static repeat #4325
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
447c23a
3a4f62d
13965fe
bdbcaa8
a6e3a20
14fcc45
a5aaebf
3c9c249
300da7a
e803fa0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -315,6 +315,47 @@ void prepareForBackwardTransformPropagation(TensorView* ref_tv) { | |
| tvs_with_extra_transforms, ref_tv->getLogicalDomain()); | ||
| } | ||
|
|
||
| // Partition a given set of tensors to two disjoint sets based on a | ||
| // given iter domain and reachability from the iter domain. Returns two | ||
| // vectors of tensors, first of which contains all tensors that has an | ||
| // iter domain that is reachable from the given iter domain, whereas | ||
| // the rest of tensors are all grouped into the second | ||
| // list. Reachability is determined by using the permissive BFS | ||
| // traversal on a given graph. | ||
| std::pair<std::vector<TensorView*>, std::vector<TensorView*>> partitionTvsById( | ||
| const std::vector<TensorView*> tvs, | ||
| IterDomain* id, | ||
| const ValGraph& graph) { | ||
| ValGroups target_groups; | ||
| for (auto tv : tvs) { | ||
| target_groups.pushBack(graph.toGroups(tv->getLogicalDomain())); | ||
| } | ||
|
|
||
| const auto reachable_groups = getReachableValsFrom<ValGraphPermissiveBFS>( | ||
| {graph.toGroup(id)}, | ||
| target_groups.vector(), | ||
| /*allowed_direction=*/Direction::Undefined, | ||
| graph); | ||
| const std::unordered_set<ValGroup> reachable_group_set{ | ||
| reachable_groups.begin(), reachable_groups.end()}; | ||
|
|
||
| std::vector<TensorView*> reachable_tvs; | ||
| std::vector<TensorView*> unreachable_tvs; | ||
|
|
||
| for (auto tv : tvs) { | ||
| if (std::ranges::any_of( | ||
| tv->getLogicalDomain(), [&](IterDomain* logical_id) { | ||
| return reachable_group_set.contains(graph.toGroup(logical_id)); | ||
| })) { | ||
| reachable_tvs.push_back(tv); | ||
| } else { | ||
| unreachable_tvs.push_back(tv); | ||
| } | ||
| } | ||
|
|
||
| return std::make_pair(reachable_tvs, unreachable_tvs); | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { | ||
|
|
@@ -384,7 +425,7 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { | |
| id_model->buildExactGraph(); | ||
|
|
||
| // Detect an ending repeat | ||
| auto static_repeat_info = scheduler_tools::getMaybeStaticRepeatInfo(ref_tv); | ||
| auto repeat_info = scheduler_tools::getMaybeStaticRepeatInfo(ref_tv); | ||
|
|
||
| // Just simple scheduling for now. | ||
| // TODO: Do something smarter. Can just use the pointwise scheduler? | ||
|
|
@@ -425,35 +466,33 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { | |
| // detected. The repeat ID then just remains there with no | ||
| // scheduling. | ||
| bool repeat_id_moved_to_outermost = false; | ||
| if (static_repeat_info.has_value()) { | ||
| NVF_ERROR(ref_tv == static_repeat_info->repeat_output_tv); | ||
| auto ref_repeat_id_it = std::find_if( | ||
| if (repeat_info.has_value()) { | ||
| auto ref_factor_id_it = std::find_if( | ||
| ref_tv->getLoopDomain().begin(), | ||
| ref_tv->getLoopDomain().end(), | ||
| [&](IterDomain* loop_id) { | ||
| return id_model->idGraph(IdMappingMode::EXACT) | ||
| .disjointValSets() | ||
| .strictAreMapped(loop_id, static_repeat_info->reshape_repeat_id); | ||
| .strictAreMapped(loop_id, repeat_info->factor_id); | ||
| }); | ||
| // Gives up if the repeat ID is not found. Unclear if this could | ||
| // actually happen, though. | ||
| if (ref_repeat_id_it != ref_tv->getLoopDomain().end()) { | ||
| auto repeat_id_pos = | ||
| std::distance(ref_tv->getLoopDomain().begin(), ref_repeat_id_it); | ||
| // The factor ID should be found in the loop domain as the | ||
| // reshape should be cancelled at this point. Gives up if not. | ||
| if (ref_factor_id_it != ref_tv->getLoopDomain().end()) { | ||
| auto factor_id_pos = | ||
| std::distance(ref_tv->getLoopDomain().begin(), ref_factor_id_it); | ||
| NVF_ERROR( | ||
| repeat_id_pos >= outermost_pos, | ||
| "Unexpected to have DID-parallelized repeat axis: ", | ||
| static_repeat_info->reshape_repeat_id->toString()); | ||
| factor_id_pos >= outermost_pos, | ||
| "Unexpected to have DID-parallelized repeat factor axis: ", | ||
| repeat_info->factor_id->toString()); | ||
naoyam marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| // [DID, ..., repeat_id, ...] | ||
| // [DID, ..., repeat_factor_id, ...] | ||
| // ^ | ||
| // +--- outermost_pos | ||
| ref_tv->reorder(std::unordered_map<int64_t, int64_t>{{repeat_id_pos, 0}}); | ||
| ref_tv->reorder(std::unordered_map<int64_t, int64_t>{{factor_id_pos, 0}}); | ||
| ++outermost_pos; | ||
| // [repeat_id, DID, ...] | ||
| // ^ | ||
| // +--- outermost_pos | ||
|
|
||
| // [repeat_factor_id, DID, ...] | ||
| // ^ | ||
| // +--- outermost_pos | ||
| repeat_id_moved_to_outermost = true; | ||
| } | ||
| } | ||
|
|
@@ -514,34 +553,25 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { | |
| // post-repeat group, where only the latter group has the repeat | ||
| // IDs. When propagating the loop domain of the reference tensor, | ||
| // which has the repeat ID, the full loop domain is propagated only | ||
| // to the post-repeat group. For the pre-repeat group, the repeat ID | ||
| // is dropped and only the remaining loop domain is propagated. | ||
| // to the tensors that have IDs that are mapped with the repeat | ||
| ID. For the rest of the tensors, the repeat ID is dropped and | ||
| // only the remaining loop domain is propagated. | ||
| if (repeat_id_moved_to_outermost) { | ||
| // Divide all tvs into the pre and post repeat groups | ||
| auto all_tvs = fusion->allTvs(); | ||
| std::vector<TensorView*> post_repeat_tvs; | ||
| post_repeat_tvs.reserve(static_repeat_info->repeat_tvs.size()); | ||
| std::vector<TensorView*> pre_repeat_tvs; | ||
| pre_repeat_tvs.reserve( | ||
| all_tvs.size() - static_repeat_info->repeat_tvs.size()); | ||
| for (auto tv : all_tvs) { | ||
| if (static_repeat_info->repeat_tvs.count(tv)) { | ||
| post_repeat_tvs.push_back(tv); | ||
| } else { | ||
| pre_repeat_tvs.push_back(tv); | ||
| } | ||
| } | ||
| const auto& [tvs_with_repeat_id, tvs_without_repeat_id] = partitionTvsById( | ||
| fusion->allTvs(), | ||
| repeat_info->factor_id, | ||
| id_model->maybeBuildGraph(IdMappingMode::BROADCAST)); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. naive question for my own curiosity, do we need to map In the added example Say for Q1. IIUC, mapping with broadcast would allow us map those two together?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| // The repeat ID should be located at the outermost position | ||
| std::vector<IterDomain*> non_repeated_loop{ | ||
| ref_tv->getLoopDomain().begin() + 1, ref_tv->getLoopDomain().end()}; | ||
|
|
||
| scheduler_tools::scheduleLoopDomainsLike( | ||
| pre_repeat_tvs, | ||
| tvs_without_repeat_id, | ||
| non_repeated_loop, | ||
| /*update_loop_domain_only=*/true); | ||
| scheduler_tools::scheduleLoopDomainsLike( | ||
| post_repeat_tvs, | ||
| tvs_with_repeat_id, | ||
| ref_tv->getLoopDomain(), | ||
| /*update_loop_domain_only=*/true); | ||
| } else { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,126 +8,40 @@ | |
|
|
||
| #include <ir/all_nodes.h> | ||
| #include <ir/utils.h> | ||
| #include <logical_domain_map.h> | ||
| #include <scheduler/tools/static_repeat.h> | ||
|
|
||
| namespace nvfuser { | ||
| namespace scheduler_tools { | ||
|
|
||
| std::optional<StaticRepeatInfo> getMaybeStaticRepeatInfo( | ||
| TensorView* maybe_repeat_out) { | ||
| // The pattern to detect: | ||
| // | ||
| // broadcast_out = broadcast(input) | ||
| // expand_out = expand(broadcast_out) | ||
| // repeat_out = reshape(expand_out) | ||
| // | ||
| // Additionally, since maybe_repeat_out is commonly a fusion | ||
| // output, it is likely there's a cache tv between expand_out and | ||
| // repeat_out, so the following pattern should also be detected. | ||
| // | ||
| // broadcast_out = broadcast(input) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These 4 lines were helpful.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure which 4 lines, but broadcast is no longer required. |
||
| // expand_out = expand(broadcast_out) | ||
| // cache_of_repeat_out = reshape(expand_out) | ||
| // repeat_out = set(cache_of_repeat_out) | ||
|
|
||
| std::unordered_set<TensorView*> repeat_tvs; | ||
| repeat_tvs.insert(maybe_repeat_out); | ||
|
|
||
| auto reshape_out = maybe_repeat_out; | ||
|
|
||
| // Check if there's a cache | ||
| if (auto ldst = dynamic_cast<LoadStoreOp*>(maybe_repeat_out->definition()); | ||
| ldst != nullptr && ldst->opType() == LoadStoreOpType::Set) { | ||
| reshape_out = ldst->in()->as<TensorView>(); | ||
| repeat_tvs.insert(reshape_out); | ||
| TensorView* maybe_repeat_out_tv) { | ||
| // Skip set ops if any (e.g., inserted by caching). Only Set | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought the deleted comment here was helpful - the bit about skipping caching ops. |
||
| // or SegmenterSet are considered. | ||
| while (auto ldst = | ||
| dynamic_cast<LoadStoreOp*>(maybe_repeat_out_tv->definition())) { | ||
| if (ldst->opType() != LoadStoreOpType::Set && | ||
| ldst->opType() != LoadStoreOpType::SegmenterSet) { | ||
| break; | ||
| } | ||
| maybe_repeat_out_tv = ldst->in()->as<TensorView>(); | ||
| } | ||
|
|
||
| // Detect reshape | ||
| auto reshape = dynamic_cast<ViewOp*>(reshape_out->definition()); | ||
| auto reshape = dynamic_cast<ViewOp*>(maybe_repeat_out_tv->definition()); | ||
| if (reshape == nullptr) { | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| // Detect expand | ||
| auto expand_out = reshape->in(); | ||
| repeat_tvs.insert(expand_out); | ||
| auto expand = dynamic_cast<ExpandOp*>(expand_out->definition()); | ||
| if (expand == nullptr) { | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| // Detect broadcast | ||
| auto broadcast_out = expand->in(); | ||
| repeat_tvs.insert(broadcast_out); | ||
| auto broadcast = dynamic_cast<BroadcastOp*>(broadcast_out->definition()); | ||
| if (broadcast == nullptr) { | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| auto inp_tv = broadcast->in(); | ||
|
|
||
| // Not sure if this is really necessary to check, but assume there's | ||
| // only single chain of the ops and tensors from inp_tv to | ||
| // maybe_reshape_out | ||
| if (inp_tv->uses().size() > 1 && | ||
| std::any_of(repeat_tvs.begin(), repeat_tvs.end(), [](TensorView* tv) { | ||
| return tv->uses().size() > 1; | ||
| })) { | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| // Check if the ops match with the repeat pattern. Currently only | ||
| // one iter domain can be repeated | ||
| IterDomain* broadcast_id = nullptr; | ||
| int64_t broadcast_pos = -1; | ||
| for (const auto i : arange(broadcast_out->getLogicalDomain().size())) { | ||
| if (broadcast->getBroadcastDimFlags().at(i)) { | ||
| if (broadcast_id != nullptr) { | ||
| // Multiple broadcast IDs not supported | ||
| return std::nullopt; | ||
| } | ||
| broadcast_id = broadcast_out->getLogicalDomain().at(i); | ||
| broadcast_pos = (int64_t)i; | ||
| } | ||
| } | ||
| auto reshape_in = reshape->input(0)->as<TensorView>(); | ||
| auto reshape_out = reshape->output(0)->as<TensorView>(); | ||
|
|
||
| if (broadcast_id == nullptr) { | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| // Check if and only if the broadcast ID is expanded | ||
| IterDomain* expanded_id = nullptr; | ||
| for (const auto i : arange(broadcast_out->getLogicalDomain().size())) { | ||
| auto p_id = broadcast_out->getLogicalDomain().at(i); | ||
| auto c_id = expand_out->getLogicalDomain().at(i); | ||
| if (p_id == broadcast_id && c_id->isBroadcast() && | ||
| c_id->hasExpandedExtent()) { | ||
| expanded_id = c_id; | ||
| } else if ( | ||
| p_id->isBroadcast() && !p_id->hasExpandedExtent() && | ||
| c_id->isBroadcast() && c_id->hasExpandedExtent()) { | ||
| // Expanded but this broadcast was not introduced by the | ||
| // preceding broadcast op | ||
| return std::nullopt; | ||
| } | ||
| } | ||
|
|
||
| if (expanded_id == nullptr) { | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| // Only a static repeat factor is considered | ||
| if (!expanded_id->expandedExtent()->isConstInt()) { | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| // The expanded ID should be merged with the iter domain next to it, | ||
| // and that should be the only reshape expr | ||
| auto reshape_exprs = DependencyCheck::getAllExprsBetween( | ||
| {reshape_out->getRootDomain().begin(), | ||
| reshape_out->getRootDomain().end()}, | ||
| {reshape_out->getLogicalDomain().begin(), | ||
| reshape_out->getLogicalDomain().end()}); | ||
|
|
||
| if (reshape_exprs.size() != 1) { | ||
| return std::nullopt; | ||
| } | ||
|
|
@@ -137,14 +51,6 @@ std::optional<StaticRepeatInfo> getMaybeStaticRepeatInfo( | |
| return std::nullopt; | ||
| } | ||
|
|
||
| // The corresponding root ID of the output tv should be one of the | ||
| // inputs of the merge | ||
| auto reshape_root_broadcast = reshape_out->getRootDomain().at(broadcast_pos); | ||
| if (reshape_merge->outer() != reshape_root_broadcast && | ||
| reshape_merge->inner() != reshape_root_broadcast) { | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| // Reshape of an expanded broadcast always generates a concrete | ||
| // non-broadcast ID, so this check is not necessary, but just in | ||
| // case in the future that may change. | ||
|
|
@@ -154,10 +60,39 @@ std::optional<StaticRepeatInfo> getMaybeStaticRepeatInfo( | |
| } | ||
|
|
||
| StaticRepeatInfo info; | ||
| info.repeat_output_tv = maybe_repeat_out; | ||
| info.reshape_output_tv = reshape_out; | ||
| info.reshape_repeat_id = reshape_out->getRootDomain().at(broadcast_pos); | ||
| info.repeat_tvs = repeat_tvs; | ||
|
|
||
| info.output_id = reshape_merge->out(); | ||
|
|
||
| const auto c2p = | ||
| PairwiseLogicalDomainMap(reshape_in, reshape_out).mapConsumerToProducer(); | ||
|
|
||
| auto producer_merge_outer = c2p.at(reshape_merge->outer()); | ||
| auto producer_merge_inner = c2p.at(reshape_merge->inner()); | ||
| IterDomain* producer_factor_id = nullptr; | ||
|
|
||
| if (producer_merge_outer->isBroadcast() && | ||
| producer_merge_outer->hasExpandedExtent() && | ||
| !producer_merge_inner->isBroadcast()) { | ||
jjsjann123 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| // Inner ID is repeated by the factor of the outer extent | ||
| info.input_id = reshape_merge->inner(); | ||
| info.factor_id = reshape_merge->outer(); | ||
| producer_factor_id = producer_merge_outer; | ||
| } else if ( | ||
| producer_merge_inner->isBroadcast() && | ||
| producer_merge_inner->hasExpandedExtent() && | ||
| !producer_merge_outer->isBroadcast()) { | ||
| // Outer ID is repeated by the factor of the inner extent | ||
| info.input_id = reshape_merge->outer(); | ||
| info.factor_id = reshape_merge->inner(); | ||
| producer_factor_id = producer_merge_inner; | ||
| } else { | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| // Check if the expanded ID has a static expanded extent | ||
| if (!producer_factor_id->expandedExtent()->isConstInt()) { | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| return info; | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: move the def of
repeat_infohere - near the use.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd keep it there as
ref_tvis going to be transformed after that. It shouldn't affect the analysis, but there's no need to introduce an additional complexity.