
Cleaning up scheduling of static repeat #4325

Merged
naoyam merged 10 commits into main from
static_repeat_cleanup
Apr 30, 2025
Conversation

@naoyam
Collaborator

@naoyam naoyam commented Apr 26, 2025

The resize scheduler automatically detects a sequence of ops to repeat a tensor at a certain ID and slightly modifies the scheduling to reduce redundant computations. This PR makes the analysis a little more flexible so that it also works with a pattern appearing in a Llama forward RoPE module.

Specifically, the special scheduling was previously applied only when a sequence of BroadcastOp, ExpandOp, and ViewOp was detected, in that order, simply because that is how a repetition of a tensor is commonly represented. However, the only op that is absolutely necessary is the final reshape; as long as it matches the repetition pattern, that should be sufficient to apply the scheduling. In fact, in a segment of the Llama RoPE forward, there is a segment input that has a broadcast ID, which is then expanded inside the segment and merged to realize a repetition. This case is not detectable on the current main branch because the segment lacks a BroadcastOp, but it is detected with this PR.

@naoyam
Collaborator Author

naoyam commented Apr 26, 2025

!test --diff

@github-actions

github-actions bot commented Apr 26, 2025

Review updated until commit e803fa0

Description

  • Enhanced static repeat detection to handle cases without BroadcastOp

  • Added partitionTvsById function for tensor partitioning

  • Updated getMaybeStaticRepeatInfo to detect reshape patterns more flexibly

  • Modified test cases to include scenarios with no BroadcastOp


Changes walkthrough 📝

Relevant files

Enhancement

resize.cpp — Add tensor partitioning and update repeat ID handling
csrc/scheduler/resize.cpp

  • Added partitionTvsById function for tensor partitioning based on iter
    domain reachability
  • Updated ResizeScheduler::schedule to use getMaybeStaticRepeatInfo and
    handle repeat ID reordering
  • Modified tensor partitioning logic to group tensors with and without
    repeat IDs
  • +67/-37

static_repeat.cpp — Simplify static repeat detection logic
csrc/scheduler/tools/static_repeat.cpp

  • Simplified getMaybeStaticRepeatInfo to detect reshape patterns without
    requiring BroadcastOp
  • Removed unnecessary checks for BroadcastOp, ExpandOp, and specific op
    sequences
  • Updated logic to identify factor and input IDs using
    PairwiseLogicalDomainMap
  • +48/-113

static_repeat.h — Update StaticRepeatInfo structure
csrc/scheduler/tools/static_repeat.h

  • Updated StaticRepeatInfo structure to include input_id, factor_id, and
    output_id
  • Removed redundant fields repeat_output_tv, reshape_output_tv, and
    repeat_tvs
  • +12/-14

Tests

test_rope.cpp — Add test cases for static repeat without BroadcastOp
tests/cpp/test_rope.cpp

  • Added new test case EndingRepeatWithNoBroadcastOp to verify static
    repeat detection without BroadcastOp
  • Modified existing test case EndingRepeat to include a segment_set
    operation
  • +69/-1

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Logic Change

    The logic for detecting static repeats has been significantly changed. Ensure that the new logic correctly identifies all valid static repeat patterns and does not introduce false positives or negatives.

    // Partition a given set of tensors into two disjoint sets based on a
    // given iter domain and reachability from that iter domain. Returns
    // two vectors of tensors, the first of which contains all tensors
    // that have an iter domain reachable from the given iter domain; the
    // remaining tensors are all grouped into the second list.
    // Reachability is determined using the permissive BFS traversal on
    // the given graph.
    std::pair<std::vector<TensorView*>, std::vector<TensorView*>> partitionTvsById(
        const std::vector<TensorView*> tvs,
        IterDomain* id,
        const ValGraph& graph) {
      ValGroups target_groups;
      for (auto tv : tvs) {
        target_groups.pushBack(graph.toGroups(tv->getLogicalDomain()));
      }
    
      const auto reachable_groups = getReachableValsFrom<ValGraphPermissiveBFS>(
          {graph.toGroup(id)},
          target_groups.vector(),
          /*allowed_direction=*/Direction::Undefined,
          graph);
      const std::unordered_set<ValGroup> reachable_group_set{
          reachable_groups.begin(), reachable_groups.end()};
    
      std::vector<TensorView*> reachable_tvs;
      std::vector<TensorView*> unreachable_tvs;
    
      for (auto tv : tvs) {
        if (std::ranges::any_of(
                tv->getLogicalDomain(), [&](IterDomain* logical_id) {
                  return reachable_group_set.contains(graph.toGroup(logical_id));
                })) {
          reachable_tvs.push_back(tv);
        } else {
          unreachable_tvs.push_back(tv);
        }
      }
    
      return std::make_pair(reachable_tvs, unreachable_tvs);
    }
    
    } // namespace
    
    void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
      FUSER_PERF_SCOPE("ResizeScheduler::schedule");
    
      FusionGuard fg(fusion);
    Simplification

    The detection of static repeats has been simplified. Verify that the new detection logic is still robust and covers all necessary cases, especially those involving complex patterns.

    // clang-format off
    /*
     * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
     * All rights reserved.
     * SPDX-License-Identifier: BSD-3-Clause
     */
    // clang-format on
    
    #include <ir/all_nodes.h>
    #include <ir/utils.h>
    #include <logical_domain_map.h>
    #include <scheduler/tools/static_repeat.h>
    
    namespace nvfuser {
    namespace scheduler_tools {
    
    std::optional<StaticRepeatInfo> getMaybeStaticRepeatInfo(
        TensorView* maybe_repeat_out_tv) {
      // Skip set ops if any (e.g., inserted by caching). Only Set
      // or SegmenterSet are considered.
      while (auto ldst =
                 dynamic_cast<LoadStoreOp*>(maybe_repeat_out_tv->definition())) {
        if (ldst->opType() != LoadStoreOpType::Set &&
            ldst->opType() != LoadStoreOpType::SegmenterSet) {
          break;
        }
        maybe_repeat_out_tv = ldst->in()->as<TensorView>();
      }
    
      // Detect reshape
      auto reshape = dynamic_cast<ViewOp*>(maybe_repeat_out_tv->definition());
      if (reshape == nullptr) {
        return std::nullopt;
      }
    
      auto reshape_in = reshape->input(0)->as<TensorView>();
      auto reshape_out = reshape->output(0)->as<TensorView>();
    
      auto reshape_exprs = DependencyCheck::getAllExprsBetween(
          {reshape_out->getRootDomain().begin(),
           reshape_out->getRootDomain().end()},
          {reshape_out->getLogicalDomain().begin(),
           reshape_out->getLogicalDomain().end()});
    
      if (reshape_exprs.size() != 1) {
        return std::nullopt;
      }
    
      auto reshape_merge = dynamic_cast<Merge*>(reshape_exprs.at(0));
      if (reshape_merge == nullptr) {
        return std::nullopt;
      }
    
      // Reshape of an expanded broadcast always generates a concrete
      // non-broadcast ID, so this check is not strictly necessary, but
      // it is kept in case that changes in the future.
      if (reshape_merge->out()->isBroadcast() ||
          reshape_merge->out()->hasExpandedExtent()) {
        return std::nullopt;
      }
    
      StaticRepeatInfo info;
    
      info.output_id = reshape_merge->out();
    
      const auto c2p =
          PairwiseLogicalDomainMap(reshape_in, reshape_out).mapConsumerToProducer();
    
      auto producer_merge_outer = c2p.at(reshape_merge->outer());
      auto producer_merge_inner = c2p.at(reshape_merge->inner());
      IterDomain* producer_factor_id = nullptr;
    
      if (producer_merge_outer->isBroadcast() &&
          producer_merge_outer->hasExpandedExtent() &&
          !producer_merge_inner->isBroadcast()) {
        // Inner ID is repeated by the factor of the outer extent
        info.input_id = reshape_merge->inner();
        info.factor_id = reshape_merge->outer();
        producer_factor_id = producer_merge_outer;
      } else if (
          producer_merge_inner->isBroadcast() &&
          producer_merge_inner->hasExpandedExtent() &&
          !producer_merge_outer->isBroadcast()) {
        // Outer ID is repeated by the factor of the inner extent
        info.input_id = reshape_merge->outer();
        info.factor_id = reshape_merge->inner();
        producer_factor_id = producer_merge_inner;
      } else {
        return std::nullopt;
      }
    
      // Check if the expanded ID has a static expanded extent
      if (!producer_factor_id->expandedExtent()->isConstInt()) {
        return std::nullopt;
      }
    
    New Test

    A new test case has been added for a scenario without a broadcast op. Ensure that this test case adequately covers the new logic and that the test is comprehensive.

      auto tv0 = makeContigConcreteTensor(shape1);
      fusion.addInput(tv0);
    
      auto tv1 = pad(tv0, {fusion.oneVal(), fusion.oneVal()});
      auto tv2 = repeat(tv1, {2, 1});
      auto tv3 = segment_set(tv2);
      fusion.addOutput(tv3);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      auto t0 = at::randn(shape1, options);
    
      FusionExecutorCache executor_cache(std::move(fusion_ptr));
      auto outputs = executor_cache.runFusionWithInputs({t0});
      testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__);
    
      FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
      EXPECT_FALSE(runtime->isSegmented());
      const auto& heuristic_param =
          runtime->schedulerHeuristics()->heuristicsList().front();
      EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize);
      Fusion* scheduled_fusion = runtime->executors()
                                     .at(0)
                                     ->as<KernelExecutor>()
                                     ->compiledKernel()
                                     ->kernel();
    
      // Check the loop domain of the reference. It should look like:
      //
      // T4_g_float[iS19{2 ex 2}, iblockIdx.x22{8}, ithreadIdx.x23{128}] ca_pos( 3 )
      // produce_pos( 3 )
      //  logical domain : (iS17{( 2 * 8 )}, iS18{128})
      //  contiguity: t t
      //   Merge: iS20{8} and iS18{128} -> iS21{1024}
      //   Split: iS21{1024} by factor 128 -> iblockIdx.x22{8}, ithreadIdx.x23{128}
      //  loop domain : (iS19{2 ex 2}, iblockIdx.x22{8}, ithreadIdx.x23{128})
      //
      // iS19 is the repeat ID, which should be just a Serial ID with an
      // extent of 2.
      auto ref_tv = scheduled_fusion->outputs().at(0)->as<TensorView>();
      // The outermost loop ID should be a Serial ID with an extent of 2.
      EXPECT_EQ(
          ref_tv->getLoopDomain().at(0)->getParallelType(), ParallelType::Serial);
      EXPECT_TRUE(ref_tv->getLoopDomain().at(0)->extent()->isConstInt());
      EXPECT_EQ(
          ref_tv->getLoopDomain().at(0)->extent()->evaluate().as<int64_t>(), 2L);
    
      IdModel id_model(scheduled_fusion, /*build_graphs=*/false);
      const auto& exact_graph = id_model.buildExactGraph();
    
      const auto ref_loop = exact_graph.toGroups(ref_tv->getLoopDomain());
    
      // The other tensors, except for the pad output, should be fully inlined into
      // the reference tensor.
      for (auto tv : scheduled_fusion->allTvs()) {
        if (tv->isFusionInput()) {
          continue;
        }
        auto tv_loop = exact_graph.toGroups(tv->getLoopDomain());
        if (tv->definition() != nullptr && tv->definition()->isA<PadOp>()) {
          ValGroups ref_groups{ref_loop.begin() + 1, ref_loop.end()};
          // In the case of pad, the loop domain of the output tensor
          // should be mapped with the loop domain of the reference
          // without the outermost ID.
          EXPECT_EQ(tv_loop, ref_groups);
        } else {
          EXPECT_EQ(tv_loop, ref_loop);
          EXPECT_EQ(tv->getLoopDomain().size(), tv->getComputeAtPosition());
        }
      }
    }
    
    // Similar to EndingRepeat but with a broadcast ID already present in
    // an input tensor. A similar pattern appears in the LitGPT Llama
    // RoPE module.
    TEST_F(RopeTest, EndingRepeatWithNoBroadcastOp) {
      auto fusion_ptr = std::make_unique<Fusion>();
      FusionGuard fg(fusion_ptr.get());
      Fusion& fusion = *fusion_ptr;
    
      std::vector<int64_t> shape1{3, 1, 200};
    
      auto tv0 = makeContigConcreteTensor(shape1);
      fusion.addInput(tv0);
    
      auto tv1 = pad(tv0, {fusion.oneVal(), fusion.oneVal()});
      auto tv2 = expand(
          tv1,
          {IrBuilder::create<Val>(-1),
           IrBuilder::create<Val>(2),
           IrBuilder::create<Val>(-1)});
      auto tv3 =
          reshape(tv2, {IrBuilder::create<Val>(6), IrBuilder::create<Val>(-1)});
      fusion.addOutput(tv3);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      auto t0 = at::randn(shape1, options);
    
      FusionExecutorCache executor_cache(std::move(fusion_ptr));
      auto outputs = executor_cache.runFusionWithInputs({t0});
      testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__);
    
      FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
      EXPECT_FALSE(runtime->isSegmented());
      const auto& heuristic_param =
          runtime->schedulerHeuristics()->heuristicsList().front();
      EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Resize);
      Fusion* scheduled_fusion = runtime->executors()
                                     .at(0)
                                     ->as<KernelExecutor>()
                                     ->compiledKernel()
                                     ->kernel();
    
      // Similar to the EndingRepeat tensor, the repeat factor ID should
      // be placed at the outermost position.
      auto ref_tv = scheduled_fusion->outputs().at(0)->as<TensorView>();
      // The outermost loop ID should be a Serial ID with an extent of 2.
      EXPECT_EQ(
          ref_tv->getLoopDomain().at(0)->getParallelType(), ParallelType::Serial);
      EXPECT_TRUE(ref_tv->getLoopDomain().at(0)->extent()->isConstInt());
      EXPECT_EQ(
          ref_tv->getLoopDomain().at(0)->extent()->evaluate().as<int64_t>(), 2L);
    
      IdModel id_model(scheduled_fusion, /*build_graphs=*/false);
      const auto& exact_graph = id_model.buildExactGraph();
    
      const auto ref_loop = exact_graph.toGroups(ref_tv->getLoopDomain());
    
      // All of the tensors have a mapped ID as the factor ID, so they
      // should all have the same loop ID groups.
      for (auto tv : scheduled_fusion->allTvs()) {
        if (tv->isFusionInput()) {
          continue;
        }
        EXPECT_EQ(exact_graph.toGroups(tv->getLoopDomain()), ref_loop);
        EXPECT_EQ(tv->getLoopDomain().size(), tv->getComputeAtPosition());
      }
    }
    
    } // namespace nvfuser

    @naoyam
    Collaborator Author

    naoyam commented Apr 29, 2025

    !test --diff

    @naoyam naoyam marked this pull request as ready for review April 29, 2025 03:08
    @naoyam naoyam changed the title from "[WIP] Cleaning up scheduling of static repeat" to "Cleaning up scheduling of static repeat" on Apr 29, 2025
    @naoyam
    Collaborator Author

    naoyam commented Apr 29, 2025

    !test --diff

    @naoyam naoyam requested a review from jjsjann123 April 29, 2025 03:16
    @naoyam
    Collaborator Author

    naoyam commented Apr 29, 2025

    !test --diff

    @naoyam naoyam mentioned this pull request Apr 29, 2025
    @naoyam naoyam added the rope label Apr 29, 2025
    @naoyam naoyam requested a review from protonu April 29, 2025 21:13
    auto tv1 = pad(tv0, {fusion.oneVal(), fusion.oneVal()});
    auto tv2 = repeat(tv1, {2, 1});
    fusion.addOutput(tv2);
    auto tv3 = segment_set(tv2);
    Collaborator Author

    Just to make sure the transformation is applied by ignoring plain set ops.

    Collaborator

    Out of curiosity, is there any specific reason we are using a segment_set, instead of a plain set?

    Collaborator

    Is it to ensure this:

    // It is especially important to recognize this pattern when it
    // appears at the end of a pointwise fusion segment, where an output
    // is used as the reference tensor of scheduling the segment.

    If so, should the caching ops be mixed inside the repeat?

    Collaborator

    Though I guess this was the end of the fusion anyway - so that may not make much sense.

    Collaborator Author

    Just because I saw a segment ending with segment_set, preceded by a repetition. I think that's one segment of a LitGPT Llama forward.

    @naoyam
    Collaborator Author

    naoyam commented Apr 30, 2025

    !test --diff

    Collaborator

    @jjsjann123 jjsjann123 left a comment


    LGTM in general, I have quite some questions, but those are mostly just for my own curiosity.


    const auto& [tvs_with_repeat_id, tvs_without_repeat_id] = partitionTvsById(
        all_tvs,
        repeat_info->factor_id,
        id_model->maybeBuildGraph(IdMappingMode::BROADCAST));
    Collaborator

    naive question for my own curiosity, do we need to map BROADCAST?

    In the added example

    std::vector<int64_t> shape1{3, 1, 200};

    auto tv0 = makeContigConcreteTensor(shape1);
    fusion.addInput(tv0);

    auto tv1 = pad(tv0, {fusion.oneVal(), fusion.oneVal()});
    auto tv2 = expand(
        tv1,
        {IrBuilder::create<Val>(-1),
         IrBuilder::create<Val>(2),
         IrBuilder::create<Val>(-1)});
    auto tv3 =
        reshape(tv2, {IrBuilder::create<Val>(6), IrBuilder::create<Val>(-1)});
    fusion.addOutput(tv3);
    

    Say for tv1 [i0, b(1), i2], after the expand we would have tv2 [i0, b(2), i2].
    The two broadcast IDs in tv1 and tv2 would have different extents.

    Q1. IIUC, mapping with BROADCAST would allow us to map those two together?
    Q2. Does it matter that we group tv1 with the tvs_with_repeat_id, even though it only contains the non-expanded factor_id?

    Collaborator Author

    EXACT should work too. Previously, we scheduled tensors like tv1 and tv2 together with tv3, so using BROADCAST keeps the same behavior. I don't think there should be any actual difference in final performance.

    reshape_out = ldst->in()->as<TensorView>();
    repeat_tvs.insert(reshape_out);
    TensorView* maybe_repeat_out_tv) {
    // Skip set ops if any (e.g., inserted by caching). Only Set
    @protonu protonu Apr 30, 2025
    Collaborator

    I thought the deleted comment here was helpful - the bit about skipping caching ops.

    // output, it is likely there's a cache tv between expand_out and
    // repeat_out, so the following pattern should also be detected.
    //
    // broadcast_out = broadcast(input)
    Collaborator

    These 4 lines were helpful.

    Collaborator Author

    Not sure which 4 lines, but broadcast is no longer required.

    @@ -352,35 +397,33 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
    // detected. The repeat ID then just remains there with no
    Collaborator

    nit: move the def of repeat_info here - near the use.

    Collaborator Author

    I'd keep it there, as ref_tv is going to be transformed after that. It shouldn't affect the analysis, but there's no need to introduce additional complexity.

    @naoyam
    Collaborator Author

    naoyam commented Apr 30, 2025

    !build

    @naoyam
    Collaborator Author

    naoyam commented Apr 30, 2025

    !build

    @naoyam naoyam merged commit b80c443 into main Apr 30, 2025
    12 of 13 checks passed
    @naoyam naoyam deleted the static_repeat_cleanup branch April 30, 2025 20:06
    naoyam added a commit that referenced this pull request May 2, 2025
    Stacked on top of #4325 
    
    If a repeat is moved to the end of a segment, the resize scheduler will
    take advantage of it.
