diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp index 92a02fb59ac..d449e3267ed 100644 --- a/csrc/scheduler/resize.cpp +++ b/csrc/scheduler/resize.cpp @@ -315,6 +315,47 @@ void prepareForBackwardTransformPropagation(TensorView* ref_tv) { tvs_with_extra_transforms, ref_tv->getLogicalDomain()); } +// Partition a given set of tensors into two disjoint sets based on a +// given iter domain and reachability from the iter domain. Returns two +// vectors of tensors, the first of which contains all tensors that have an +// iter domain that is reachable from the given iter domain, whereas +// the rest of the tensors are all grouped into the second +// list. Reachability is determined by using the permissive BFS +// traversal on a given graph. +std::pair, std::vector> partitionTvsById( + const std::vector tvs, + IterDomain* id, + const ValGraph& graph) { + ValGroups target_groups; + for (auto tv : tvs) { + target_groups.pushBack(graph.toGroups(tv->getLogicalDomain())); + } + + const auto reachable_groups = getReachableValsFrom( + {graph.toGroup(id)}, + target_groups.vector(), + /*allowed_direction=*/Direction::Undefined, + graph); + const std::unordered_set reachable_group_set{ + reachable_groups.begin(), reachable_groups.end()}; + + std::vector reachable_tvs; + std::vector unreachable_tvs; + + for (auto tv : tvs) { + if (std::ranges::any_of( + tv->getLogicalDomain(), [&](IterDomain* logical_id) { + return reachable_group_set.contains(graph.toGroup(logical_id)); + })) { + reachable_tvs.push_back(tv); + } else { + unreachable_tvs.push_back(tv); + } + } + + return std::make_pair(reachable_tvs, unreachable_tvs); +} + } // namespace void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { @@ -384,7 +425,7 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { id_model->buildExactGraph(); // Detect an ending repeat - auto static_repeat_info = scheduler_tools::getMaybeStaticRepeatInfo(ref_tv); + auto repeat_info = 
scheduler_tools::getMaybeStaticRepeatInfo(ref_tv); // Just simple scheduling for now. // TODO: Do something smarter. Can just use the pointwise scheduler? @@ -425,35 +466,33 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { // detected. The repeat ID then just remains there with no // scheduling. bool repeat_id_moved_to_outermost = false; - if (static_repeat_info.has_value()) { - NVF_ERROR(ref_tv == static_repeat_info->repeat_output_tv); - auto ref_repeat_id_it = std::find_if( + if (repeat_info.has_value()) { + auto ref_factor_id_it = std::find_if( ref_tv->getLoopDomain().begin(), ref_tv->getLoopDomain().end(), [&](IterDomain* loop_id) { return id_model->idGraph(IdMappingMode::EXACT) .disjointValSets() - .strictAreMapped(loop_id, static_repeat_info->reshape_repeat_id); + .strictAreMapped(loop_id, repeat_info->factor_id); }); - // Gives up if the repeat ID is not found. Unclear if this could - // actually happen, though. - if (ref_repeat_id_it != ref_tv->getLoopDomain().end()) { - auto repeat_id_pos = - std::distance(ref_tv->getLoopDomain().begin(), ref_repeat_id_it); + // The factor ID should be found in the loop domain as the + // reshape should be cancelled at this point. Gives up if not. + if (ref_factor_id_it != ref_tv->getLoopDomain().end()) { + auto factor_id_pos = + std::distance(ref_tv->getLoopDomain().begin(), ref_factor_id_it); NVF_ERROR( - repeat_id_pos >= outermost_pos, - "Unexpected to have DID-parallelized repeat axis: ", - static_repeat_info->reshape_repeat_id->toString()); + factor_id_pos >= outermost_pos, + "Unexpected to have DID-parallelized repeat factor axis: ", + repeat_info->factor_id->toString()); - // [DID, ..., repeat_id, ...] + // [DID, ..., repeat_factor_id, ...] // ^ // +--- outermost_pos - ref_tv->reorder(std::unordered_map{{repeat_id_pos, 0}}); + ref_tv->reorder(std::unordered_map{{factor_id_pos, 0}}); ++outermost_pos; - // [repeat_id, DID, ...] 
- // ^ - // +--- outermost_pos - + // [repeat_factor_id, DID, ...] + // ^ + // +--- outermost_pos repeat_id_moved_to_outermost = true; } } @@ -514,34 +553,25 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { // post-repeat group, where only the latter group has the repeat // IDs. When propagating the loop domain of the reference tensor, // which has the repeat ID, the full loop domain is propagated only - // to the post-repeat group. For the pre-repeat group, the repeat ID - // is dropped and only the remaining loop domain is propagated. + // to the tensors that have IDs that are mapped with the repeat + // ID. For the rest of the tensors, the repeat ID is dropped and + // only the remaining loop domain is propagated. if (repeat_id_moved_to_outermost) { - // Divide all tvs to the pre and posgt repeat groups - auto all_tvs = fusion->allTvs(); - std::vector post_repeat_tvs; - post_repeat_tvs.reserve(static_repeat_info->repeat_tvs.size()); - std::vector pre_repeat_tvs; - pre_repeat_tvs.reserve( - all_tvs.size() - static_repeat_info->repeat_tvs.size()); - for (auto tv : all_tvs) { - if (static_repeat_info->repeat_tvs.count(tv)) { - post_repeat_tvs.push_back(tv); - } else { - pre_repeat_tvs.push_back(tv); - } - } + const auto& [tvs_with_repeat_id, tvs_without_repeat_id] = partitionTvsById( + fusion->allTvs(), + repeat_info->factor_id, + id_model->maybeBuildGraph(IdMappingMode::BROADCAST)); // The repeat ID should be located at the outermost position std::vector non_repeated_loop{ ref_tv->getLoopDomain().begin() + 1, ref_tv->getLoopDomain().end()}; scheduler_tools::scheduleLoopDomainsLike( - pre_repeat_tvs, + tvs_without_repeat_id, non_repeated_loop, /*update_loop_domain_only=*/true); scheduler_tools::scheduleLoopDomainsLike( - post_repeat_tvs, + tvs_with_repeat_id, ref_tv->getLoopDomain(), /*update_loop_domain_only=*/true); } else { diff --git a/csrc/scheduler/tools/static_repeat.cpp b/csrc/scheduler/tools/static_repeat.cpp index 
331481aac80..a0f53276956 100644 --- a/csrc/scheduler/tools/static_repeat.cpp +++ b/csrc/scheduler/tools/static_repeat.cpp @@ -8,126 +8,40 @@ #include #include +#include #include namespace nvfuser { namespace scheduler_tools { std::optional getMaybeStaticRepeatInfo( - TensorView* maybe_repeat_out) { - // The pattern to detect: - // - // broadcast_out = broadcast(input) - // expand_out = expand(broadcast_out) - // repeat_out = reshape(expand_out) - // - // Additionally, since maybe_repeat_out is commonly a fusion - // output, it is likely there's a cache tv between expand_out and - // repeat_out, so the following pattern should also be detected. - // - // broadcast_out = broadcast(input) - // expand_out = expand(broadcast_out) - // cache_of_repeat_out = reshape(expand_out) - // repeat_out = set(cache_of_repeat_out) - - std::unordered_set repeat_tvs; - repeat_tvs.insert(maybe_repeat_out); - - auto reshape_out = maybe_repeat_out; - - // Check if there's a cache - if (auto ldst = dynamic_cast(maybe_repeat_out->definition()); - ldst != nullptr && ldst->opType() == LoadStoreOpType::Set) { - reshape_out = ldst->in()->as(); - repeat_tvs.insert(reshape_out); + TensorView* maybe_repeat_out_tv) { + // Skip set ops if any (e.g., inserted by caching). Only Set + // or SegmenterSet are considered. 
+ while (auto ldst = + dynamic_cast(maybe_repeat_out_tv->definition())) { + if (ldst->opType() != LoadStoreOpType::Set && + ldst->opType() != LoadStoreOpType::SegmenterSet) { + break; + } + maybe_repeat_out_tv = ldst->in()->as(); } // Detect reshape - auto reshape = dynamic_cast(reshape_out->definition()); + auto reshape = dynamic_cast(maybe_repeat_out_tv->definition()); if (reshape == nullptr) { return std::nullopt; } - // Detect expand - auto expand_out = reshape->in(); - repeat_tvs.insert(expand_out); - auto expand = dynamic_cast(expand_out->definition()); - if (expand == nullptr) { - return std::nullopt; - } - - // Detect broadcast - auto broadcast_out = expand->in(); - repeat_tvs.insert(broadcast_out); - auto broadcast = dynamic_cast(broadcast_out->definition()); - if (broadcast == nullptr) { - return std::nullopt; - } - - auto inp_tv = broadcast->in(); - - // Not sure if this is really necessary to check, but assume there's - // only single chain of the ops and tensors from inp_tv to - // maybe_reshape_out - if (inp_tv->uses().size() > 1 && - std::any_of(repeat_tvs.begin(), repeat_tvs.end(), [](TensorView* tv) { - return tv->uses().size() > 1; - })) { - return std::nullopt; - } - - // Check if the ops match with the repeat pattern. 
Currently only - // one iter domain can be repeated - IterDomain* broadcast_id = nullptr; - int64_t broadcast_pos = -1; - for (const auto i : arange(broadcast_out->getLogicalDomain().size())) { - if (broadcast->getBroadcastDimFlags().at(i)) { - if (broadcast_id != nullptr) { - // Multiple broadcast IDs not supported - return std::nullopt; - } - broadcast_id = broadcast_out->getLogicalDomain().at(i); - broadcast_pos = (int64_t)i; - } - } + auto reshape_in = reshape->input(0)->as(); + auto reshape_out = reshape->output(0)->as(); - if (broadcast_id == nullptr) { - return std::nullopt; - } - - // Check if and only if the broadcast ID is expanded - IterDomain* expanded_id = nullptr; - for (const auto i : arange(broadcast_out->getLogicalDomain().size())) { - auto p_id = broadcast_out->getLogicalDomain().at(i); - auto c_id = expand_out->getLogicalDomain().at(i); - if (p_id == broadcast_id && c_id->isBroadcast() && - c_id->hasExpandedExtent()) { - expanded_id = c_id; - } else if ( - p_id->isBroadcast() && !p_id->hasExpandedExtent() && - c_id->isBroadcast() && c_id->hasExpandedExtent()) { - // Expanded but this broadcast was not introduced by the - // preceding broadcast op - return std::nullopt; - } - } - - if (expanded_id == nullptr) { - return std::nullopt; - } - - // Only a static repeat factor is considered - if (!expanded_id->expandedExtent()->isConstInt()) { - return std::nullopt; - } - - // The expanded ID should be merged with the iter domain next to it, - // and that should be the only reshape expr auto reshape_exprs = DependencyCheck::getAllExprsBetween( {reshape_out->getRootDomain().begin(), reshape_out->getRootDomain().end()}, {reshape_out->getLogicalDomain().begin(), reshape_out->getLogicalDomain().end()}); + if (reshape_exprs.size() != 1) { return std::nullopt; } @@ -137,14 +51,6 @@ std::optional getMaybeStaticRepeatInfo( return std::nullopt; } - // The corresponding root ID of the outout tv should be one of the - // inputs of the merge - auto 
reshape_root_broadcast = reshape_out->getRootDomain().at(broadcast_pos); - if (reshape_merge->outer() != reshape_root_broadcast && - reshape_merge->inner() != reshape_root_broadcast) { - return std::nullopt; - } - // Reshape of an expanded broadcast always generates a concrete // non-broadcast ID, so this check is not necessary, but just in // case in the future that may change. @@ -154,10 +60,39 @@ std::optional getMaybeStaticRepeatInfo( } StaticRepeatInfo info; - info.repeat_output_tv = maybe_repeat_out; - info.reshape_output_tv = reshape_out; - info.reshape_repeat_id = reshape_out->getRootDomain().at(broadcast_pos); - info.repeat_tvs = repeat_tvs; + + info.output_id = reshape_merge->out(); + + const auto c2p = + PairwiseLogicalDomainMap(reshape_in, reshape_out).mapConsumerToProducer(); + + auto producer_merge_outer = c2p.at(reshape_merge->outer()); + auto producer_merge_inner = c2p.at(reshape_merge->inner()); + IterDomain* producer_factor_id = nullptr; + + if (producer_merge_outer->isBroadcast() && + producer_merge_outer->hasExpandedExtent() && + !producer_merge_inner->isBroadcast()) { + // Inner ID is repeated by the factor of the outer extent + info.input_id = reshape_merge->inner(); + info.factor_id = reshape_merge->outer(); + producer_factor_id = producer_merge_outer; + } else if ( + producer_merge_inner->isBroadcast() && + producer_merge_inner->hasExpandedExtent() && + !producer_merge_outer->isBroadcast()) { + // Outer ID is repeated by the factor of the inner extent + info.input_id = reshape_merge->outer(); + info.factor_id = reshape_merge->inner(); + producer_factor_id = producer_merge_inner; + } else { + return std::nullopt; + } + + // Check if the expanded ID has a static expanded extent + if (!producer_factor_id->expandedExtent()->isConstInt()) { + return std::nullopt; + } return info; } diff --git a/csrc/scheduler/tools/static_repeat.h b/csrc/scheduler/tools/static_repeat.h index bfdd8d4f346..f20c02c78fd 100644 --- 
a/csrc/scheduler/tools/static_repeat.h +++ b/csrc/scheduler/tools/static_repeat.h @@ -34,7 +34,7 @@ namespace scheduler_tools { // repeated at the end of its computation. // // This can be problematic since the whole segment is scheduled based -// on the repeated tensor whose size is largere than the rest of the +// on the repeated tensor whose size is larger than the rest of the // tensors by the repetition factor. For example, if the it is // repeated twice, we would launch threads and blocks that are // required for the twice-larger tensor but most of the actual @@ -53,20 +53,18 @@ namespace scheduler_tools { // TODO: Consider generalizing this heuristics to the other // schedulers. +// Some of the relevant iter domains of the output tensor of the +// reshape that realizes a repetition. struct StaticRepeatInfo { - // The final output tensor of the detected repeat pattern, e.g., - // t3 in the above example case. - TensorView* repeat_output_tv = nullptr; - // The reshape output tensor, e.g., t3 in the above example case. It - // is not the same as repeat_output_tv when there's a cache. - TensorView* reshape_output_tv = nullptr; - // The ID of reshape output TV that corresponds to the - // expanded broadcast ID. In the above example case, this - // would be the root ID of t3 that corresponds to b2 - IterDomain* reshape_repeat_id = nullptr; - // Output tensors of the detected broadcast, expand and reshape - // ops. In the above example case, this would consist of t1, t2 and t3. - std::unordered_set repeat_tvs; + // Root ID that is repeated. In the above example, this corresponds + // to i1. + IterDomain* input_id = nullptr; + // Root ID that is originally an expanded broadcast. In the above example, + // this corresponds to b(2). + IterDomain* factor_id = nullptr; + // Logical repeated ID. In the above example, this corresponds + // to 2*i1. 
+ IterDomain* output_id = nullptr; }; // Check if the given tensor matches with the final reshape output diff --git a/tests/cpp/test_rope.cpp b/tests/cpp/test_rope.cpp index d4449c38560..0c37a2213ab 100644 --- a/tests/cpp/test_rope.cpp +++ b/tests/cpp/test_rope.cpp @@ -1631,7 +1631,8 @@ TEST_F(RopeTest, EndingRepeat) { auto tv1 = pad(tv0, {fusion.oneVal(), fusion.oneVal()}); auto tv2 = repeat(tv1, {2, 1}); - fusion.addOutput(tv2); + auto tv3 = segment_set(tv2); + fusion.addOutput(tv3); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn(shape1, options); @@ -1696,4 +1697,71 @@ TEST_F(RopeTest, EndingRepeat) { } } +// Similar to EndingRepeat but with a broadcast ID already found in an +// input tensor. A similar pattern appears in the LitGPT Llama RoPE +// module. +TEST_F(RopeTest, EndingRepeatWithNoBroadcastOp) { + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + std::vector shape1{3, 1, 200}; + + auto tv0 = makeContigConcreteTensor(shape1); + fusion.addInput(tv0); + + auto tv1 = pad(tv0, {fusion.oneVal(), fusion.oneVal()}); + auto tv2 = expand( + tv1, + {IrBuilder::create(-1), + IrBuilder::create(2), + IrBuilder::create(-1)}); + auto tv3 = + reshape(tv2, {IrBuilder::create(6), IrBuilder::create(-1)}); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape1, options); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0}); + testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__); + + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize); + Fusion* scheduled_fusion = runtime->executors() + .at(0) + 
->as() ->compiledKernel() ->kernel(); + + // Similar to the EndingRepeat test, the repeat factor ID should + // be placed at the outermost position. + auto ref_tv = scheduled_fusion->outputs().at(0)->as(); + // The outermost loop ID should be a Serial ID with an extent of 2. + EXPECT_EQ( + ref_tv->getLoopDomain().at(0)->getParallelType(), ParallelType::Serial); + EXPECT_TRUE(ref_tv->getLoopDomain().at(0)->extent()->isConstInt()); + EXPECT_EQ( + ref_tv->getLoopDomain().at(0)->extent()->evaluate().as(), 2L); + + IdModel id_model(scheduled_fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + + const auto ref_loop = exact_graph.toGroups(ref_tv->getLoopDomain()); + + // All of the tensors have a mapped ID as the factor ID, so they + // should all have the same loop ID groups. + for (auto tv : scheduled_fusion->allTvs()) { + if (tv->isFusionInput()) { + continue; + } + EXPECT_EQ(exact_graph.toGroups(tv->getLoopDomain()), ref_loop); + EXPECT_EQ(tv->getLoopDomain().size(), tv->getComputeAtPosition()); + } +} + } // namespace nvfuser