Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 67 additions & 37 deletions csrc/scheduler/resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,47 @@ void prepareForBackwardTransformPropagation(TensorView* ref_tv) {
tvs_with_extra_transforms, ref_tv->getLogicalDomain());
}

// Partition a given set of tensors into two disjoint sets based on a
// given iter domain and reachability from that iter domain. Returns
// two vectors of tensors: the first contains all tensors that have at
// least one logical iter domain reachable from the given iter domain,
// whereas the rest of the tensors are grouped into the second
// vector. Reachability is determined by the permissive BFS traversal
// on the given graph, in either direction.
std::pair<std::vector<TensorView*>, std::vector<TensorView*>> partitionTvsById(
    const std::vector<TensorView*>& tvs,
    IterDomain* id,
    const ValGraph& graph) {
  // Collect the ValGroups of the logical domains of all given tensors
  // as the candidate targets of the reachability analysis.
  ValGroups target_groups;
  for (auto tv : tvs) {
    target_groups.pushBack(graph.toGroups(tv->getLogicalDomain()));
  }

  // Both directions are allowed (Direction::Undefined) since the
  // given ID may be a producer or consumer of the target groups.
  const auto reachable_groups = getReachableValsFrom<ValGraphPermissiveBFS>(
      {graph.toGroup(id)},
      target_groups.vector(),
      /*allowed_direction=*/Direction::Undefined,
      graph);
  const std::unordered_set<ValGroup> reachable_group_set{
      reachable_groups.begin(), reachable_groups.end()};

  std::vector<TensorView*> reachable_tvs;
  std::vector<TensorView*> unreachable_tvs;

  for (auto tv : tvs) {
    // A tensor counts as reachable if any of its logical IDs belongs
    // to a reachable group.
    if (std::ranges::any_of(
            tv->getLogicalDomain(), [&](IterDomain* logical_id) {
              return reachable_group_set.contains(graph.toGroup(logical_id));
            })) {
      reachable_tvs.push_back(tv);
    } else {
      unreachable_tvs.push_back(tv);
    }
  }

  return std::make_pair(reachable_tvs, unreachable_tvs);
}

} // namespace

void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
Expand Down Expand Up @@ -384,7 +425,7 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
id_model->buildExactGraph();

// Detect an ending repeat
auto static_repeat_info = scheduler_tools::getMaybeStaticRepeatInfo(ref_tv);
auto repeat_info = scheduler_tools::getMaybeStaticRepeatInfo(ref_tv);

// Just simple scheduling for now.
// TODO: Do something smarter. Can just use the pointwise scheduler?
Expand Down Expand Up @@ -425,35 +466,33 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
// detected. The repeat ID then just remains there with no
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: move the def of repeat_info here - near the use.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd keep it there as ref_tv is going to be transformed after that. It shouldn't affect the analysis, but there's no need to introduce an additional complexity.

// scheduling.
bool repeat_id_moved_to_outermost = false;
if (static_repeat_info.has_value()) {
NVF_ERROR(ref_tv == static_repeat_info->repeat_output_tv);
auto ref_repeat_id_it = std::find_if(
if (repeat_info.has_value()) {
auto ref_factor_id_it = std::find_if(
ref_tv->getLoopDomain().begin(),
ref_tv->getLoopDomain().end(),
[&](IterDomain* loop_id) {
return id_model->idGraph(IdMappingMode::EXACT)
.disjointValSets()
.strictAreMapped(loop_id, static_repeat_info->reshape_repeat_id);
.strictAreMapped(loop_id, repeat_info->factor_id);
});
// Gives up if the repeat ID is not found. Unclear if this could
// actually happen, though.
if (ref_repeat_id_it != ref_tv->getLoopDomain().end()) {
auto repeat_id_pos =
std::distance(ref_tv->getLoopDomain().begin(), ref_repeat_id_it);
// The factor ID should be found in the loop domain as the
// reshape should be cancelled at this point. Gives up if not.
if (ref_factor_id_it != ref_tv->getLoopDomain().end()) {
auto factor_id_pos =
std::distance(ref_tv->getLoopDomain().begin(), ref_factor_id_it);
NVF_ERROR(
repeat_id_pos >= outermost_pos,
"Unexpected to have DID-parallelized repeat axis: ",
static_repeat_info->reshape_repeat_id->toString());
factor_id_pos >= outermost_pos,
"Unexpected to have DID-parallelized repeat factor axis: ",
repeat_info->factor_id->toString());

// [DID, ..., repeat_id, ...]
// [DID, ..., repeat_factor_id, ...]
// ^
// +--- outermost_pos
ref_tv->reorder(std::unordered_map<int64_t, int64_t>{{repeat_id_pos, 0}});
ref_tv->reorder(std::unordered_map<int64_t, int64_t>{{factor_id_pos, 0}});
++outermost_pos;
// [repeat_id, DID, ...]
// ^
// +--- outermost_pos

// [repeat_factor_id, DID, ...]
// ^
// +--- outermost_pos
repeat_id_moved_to_outermost = true;
}
}
Expand Down Expand Up @@ -514,34 +553,25 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
// post-repeat group, where only the latter group has the repeat
// IDs. When propagating the loop domain of the reference tensor,
// which has the repeat ID, the full loop domain is propagated only
// to the post-repeat group. For the pre-repeat group, the repeat ID
// is dropped and only the remaining loop domain is propagated.
// to the tensors that have IDs that are mapped with the repeat
// ID. For the rest of the tensors, the repeat ID is dropped and
// only the remaining loop domain is propagated.
if (repeat_id_moved_to_outermost) {
// Divide all tvs to the pre and posgt repeat groups
auto all_tvs = fusion->allTvs();
std::vector<TensorView*> post_repeat_tvs;
post_repeat_tvs.reserve(static_repeat_info->repeat_tvs.size());
std::vector<TensorView*> pre_repeat_tvs;
pre_repeat_tvs.reserve(
all_tvs.size() - static_repeat_info->repeat_tvs.size());
for (auto tv : all_tvs) {
if (static_repeat_info->repeat_tvs.count(tv)) {
post_repeat_tvs.push_back(tv);
} else {
pre_repeat_tvs.push_back(tv);
}
}
const auto& [tvs_with_repeat_id, tvs_without_repeat_id] = partitionTvsById(
fusion->allTvs(),
repeat_info->factor_id,
id_model->maybeBuildGraph(IdMappingMode::BROADCAST));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

naive question for my own curiosity, do we need to map BROADCAST?

In the added example

1708   std::vector<int64_t> shape1{3, 1, 200};
1709 
1710   auto tv0 = makeContigConcreteTensor(shape1);
1711   fusion.addInput(tv0);
1712 
1713   auto tv1 = pad(tv0, {fusion.oneVal(), fusion.oneVal()});
1714   auto tv2 = expand(
1715       tv1,
1716       {IrBuilder::create<Val>(-1),
1717        IrBuilder::create<Val>(2),
1718        IrBuilder::create<Val>(-1)});
1719   auto tv3 =
1720       reshape(tv2, {IrBuilder::create<Val>(6), IrBuilder::create<Val>(-1)});
1721   fusion.addOutput(tv3);

Say for tv1 [i0, b(1), i2], after the expand, we would have tv2 [i0, b(2), i2]
The two broadcast ID in tv1 and tv2 would have different extent.

Q1. IIUC, mapping with broadcast would allow us map those two together?
Q2. Does it matter for us to group tv1 with the tvs_with_repeat_id, even though it only contains the non-expanded factor_id?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EXACT should work too. Previously, we schedule those tensors like tv1 and tv2 together with tv3, so using BROADCAST keeps the same behavior. I don't think there should be any actual difference in final performances.


// The repeat ID should be located at the outermost position
std::vector<IterDomain*> non_repeated_loop{
ref_tv->getLoopDomain().begin() + 1, ref_tv->getLoopDomain().end()};

scheduler_tools::scheduleLoopDomainsLike(
pre_repeat_tvs,
tvs_without_repeat_id,
non_repeated_loop,
/*update_loop_domain_only=*/true);
scheduler_tools::scheduleLoopDomainsLike(
post_repeat_tvs,
tvs_with_repeat_id,
ref_tv->getLoopDomain(),
/*update_loop_domain_only=*/true);
} else {
Expand Down
161 changes: 48 additions & 113 deletions csrc/scheduler/tools/static_repeat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,126 +8,40 @@

#include <ir/all_nodes.h>
#include <ir/utils.h>
#include <logical_domain_map.h>
#include <scheduler/tools/static_repeat.h>

namespace nvfuser {
namespace scheduler_tools {

std::optional<StaticRepeatInfo> getMaybeStaticRepeatInfo(
TensorView* maybe_repeat_out) {
// The pattern to detect:
//
// broadcast_out = broadcast(input)
// expand_out = expand(broadcast_out)
// repeat_out = reshape(expand_out)
//
// Additionally, since maybe_repeat_out is commonly a fusion
// output, it is likely there's a cache tv between expand_out and
// repeat_out, so the following pattern should also be detected.
//
// broadcast_out = broadcast(input)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These 4 lines were helpful.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure which 4 lines, but broadcast is no longer required.

// expand_out = expand(broadcast_out)
// cache_of_repeat_out = reshape(expand_out)
// repeat_out = set(cache_of_repeat_out)

std::unordered_set<TensorView*> repeat_tvs;
repeat_tvs.insert(maybe_repeat_out);

auto reshape_out = maybe_repeat_out;

// Check if there's a cache
if (auto ldst = dynamic_cast<LoadStoreOp*>(maybe_repeat_out->definition());
ldst != nullptr && ldst->opType() == LoadStoreOpType::Set) {
reshape_out = ldst->in()->as<TensorView>();
repeat_tvs.insert(reshape_out);
TensorView* maybe_repeat_out_tv) {
// Skip set ops if any (e.g., inserted by caching). Only Set
Copy link
Collaborator

@protonu protonu Apr 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought the deleted comment here was helpful - the bit about skipping caching ops.

// or SegmenterSet are considered.
while (auto ldst =
dynamic_cast<LoadStoreOp*>(maybe_repeat_out_tv->definition())) {
if (ldst->opType() != LoadStoreOpType::Set &&
ldst->opType() != LoadStoreOpType::SegmenterSet) {
break;
}
maybe_repeat_out_tv = ldst->in()->as<TensorView>();
}

// Detect reshape
auto reshape = dynamic_cast<ViewOp*>(reshape_out->definition());
auto reshape = dynamic_cast<ViewOp*>(maybe_repeat_out_tv->definition());
if (reshape == nullptr) {
return std::nullopt;
}

// Detect expand
auto expand_out = reshape->in();
repeat_tvs.insert(expand_out);
auto expand = dynamic_cast<ExpandOp*>(expand_out->definition());
if (expand == nullptr) {
return std::nullopt;
}

// Detect broadcast
auto broadcast_out = expand->in();
repeat_tvs.insert(broadcast_out);
auto broadcast = dynamic_cast<BroadcastOp*>(broadcast_out->definition());
if (broadcast == nullptr) {
return std::nullopt;
}

auto inp_tv = broadcast->in();

// Not sure if this is really necessary to check, but assume there's
// only single chain of the ops and tensors from inp_tv to
// maybe_reshape_out
if (inp_tv->uses().size() > 1 &&
std::any_of(repeat_tvs.begin(), repeat_tvs.end(), [](TensorView* tv) {
return tv->uses().size() > 1;
})) {
return std::nullopt;
}

// Check if the ops match with the repeat pattern. Currently only
// one iter domain can be repeated
IterDomain* broadcast_id = nullptr;
int64_t broadcast_pos = -1;
for (const auto i : arange(broadcast_out->getLogicalDomain().size())) {
if (broadcast->getBroadcastDimFlags().at(i)) {
if (broadcast_id != nullptr) {
// Multiple broadcast IDs not supported
return std::nullopt;
}
broadcast_id = broadcast_out->getLogicalDomain().at(i);
broadcast_pos = (int64_t)i;
}
}
auto reshape_in = reshape->input(0)->as<TensorView>();
auto reshape_out = reshape->output(0)->as<TensorView>();

if (broadcast_id == nullptr) {
return std::nullopt;
}

// Check if and only if the broadcast ID is expanded
IterDomain* expanded_id = nullptr;
for (const auto i : arange(broadcast_out->getLogicalDomain().size())) {
auto p_id = broadcast_out->getLogicalDomain().at(i);
auto c_id = expand_out->getLogicalDomain().at(i);
if (p_id == broadcast_id && c_id->isBroadcast() &&
c_id->hasExpandedExtent()) {
expanded_id = c_id;
} else if (
p_id->isBroadcast() && !p_id->hasExpandedExtent() &&
c_id->isBroadcast() && c_id->hasExpandedExtent()) {
// Expanded but this broadcast was not introduced by the
// preceding broadcast op
return std::nullopt;
}
}

if (expanded_id == nullptr) {
return std::nullopt;
}

// Only a static repeat factor is considered
if (!expanded_id->expandedExtent()->isConstInt()) {
return std::nullopt;
}

// The expanded ID should be merged with the iter domain next to it,
// and that should be the only reshape expr
auto reshape_exprs = DependencyCheck::getAllExprsBetween(
{reshape_out->getRootDomain().begin(),
reshape_out->getRootDomain().end()},
{reshape_out->getLogicalDomain().begin(),
reshape_out->getLogicalDomain().end()});

if (reshape_exprs.size() != 1) {
return std::nullopt;
}
Expand All @@ -137,14 +51,6 @@ std::optional<StaticRepeatInfo> getMaybeStaticRepeatInfo(
return std::nullopt;
}

// The corresponding root ID of the outout tv should be one of the
// inputs of the merge
auto reshape_root_broadcast = reshape_out->getRootDomain().at(broadcast_pos);
if (reshape_merge->outer() != reshape_root_broadcast &&
reshape_merge->inner() != reshape_root_broadcast) {
return std::nullopt;
}

// Reshape of an expanded broadcast always generates a concrete
// non-broadcast ID, so this check is not necessary, but just in
// case in the future that may change.
Expand All @@ -154,10 +60,39 @@ std::optional<StaticRepeatInfo> getMaybeStaticRepeatInfo(
}

StaticRepeatInfo info;
info.repeat_output_tv = maybe_repeat_out;
info.reshape_output_tv = reshape_out;
info.reshape_repeat_id = reshape_out->getRootDomain().at(broadcast_pos);
info.repeat_tvs = repeat_tvs;

info.output_id = reshape_merge->out();

const auto c2p =
PairwiseLogicalDomainMap(reshape_in, reshape_out).mapConsumerToProducer();

auto producer_merge_outer = c2p.at(reshape_merge->outer());
auto producer_merge_inner = c2p.at(reshape_merge->inner());
IterDomain* producer_factor_id = nullptr;

if (producer_merge_outer->isBroadcast() &&
producer_merge_outer->hasExpandedExtent() &&
!producer_merge_inner->isBroadcast()) {
// Inner ID is repeated by the factor of the outer extent
info.input_id = reshape_merge->inner();
info.factor_id = reshape_merge->outer();
producer_factor_id = producer_merge_outer;
} else if (
producer_merge_inner->isBroadcast() &&
producer_merge_inner->hasExpandedExtent() &&
!producer_merge_outer->isBroadcast()) {
// Outer ID is repeated by the factor of the inner extent
info.input_id = reshape_merge->outer();
info.factor_id = reshape_merge->inner();
producer_factor_id = producer_merge_inner;
} else {
return std::nullopt;
}

// Check if the expanded ID has a static expanded extent
if (!producer_factor_id->expandedExtent()->isConstInt()) {
return std::nullopt;
}

return info;
}
Expand Down
26 changes: 12 additions & 14 deletions csrc/scheduler/tools/static_repeat.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace scheduler_tools {
// repeated at the end of its computation.
//
// This can be problematic since the whole segment is scheduled based
// on the repeated tensor whose size is largere than the rest of the
// on the repeated tensor whose size is larger than the rest of the
tensors by the repetition factor. For example, if it is
// repeated twice, we would launch threads and blocks that are
// required for the twice-larger tensor but most of the actual
Expand All @@ -53,20 +53,18 @@ namespace scheduler_tools {
// TODO: Consider generalizing this heuristics to the other
// schedulers.

// Some of the relevant iter domains of the output tensor of the
// reshape that realizes a repetition.
struct StaticRepeatInfo {
// The final output tensor of the detected repeat pattern, e.g.,
// t3 in the above example case.
TensorView* repeat_output_tv = nullptr;
// The reshape output tensor, e.g., t3 in the above example case. It
// is not the same as repeat_output_tv when there's a cache.
TensorView* reshape_output_tv = nullptr;
// The ID of reshape output TV that corresponds to the
// expanded broadcast ID. In the above example case, this
// would be the root ID of t3 that corresponds to b2
IterDomain* reshape_repeat_id = nullptr;
// Output tensors of the detected broadcast, expand and reshape
// ops. In the above example case, this would consist of t1, t2 and t3.
std::unordered_set<TensorView*> repeat_tvs;
// Root ID that is repeated. In the above example, this corresponds
// to i1.
IterDomain* input_id = nullptr;
// Root ID that is originally an expanded broadcast. In the above example,
// this corresponds to b(2).
IterDomain* factor_id = nullptr;
// Logical repeated ID. In the above example, this corresponds
// to 2*i1.
IterDomain* output_id = nullptr;
};

// Check if the given tensor matches with the final reshape output
Expand Down
Loading