Skip ElectSync when creating predicate for TMA Store in PredicateCompute #4332

Merged
rdspring1 merged 3 commits into main from select_warp_predicate on Apr 30, 2025
Conversation

@rdspring1 (Collaborator)

This PR changes createElectSyncPredicate to skip adding ElectSync to TMA Store expressions.

  • TMA Store is a warp-collective, so it is issued by a single warp. Using ElectSync to pick a single thread is unnecessary.

Review of ElectSync Predicate Handling

ElectSync Predicate with Expression

  1. In the Unroll pass, all non-circular buffered Async operations are assigned the ElectSync Predicate with their expression. See https://github.com/NVIDIA/Fuser/blob/main/csrc/device_lower/pass/unroll.cpp#L160-L171.
  2. Any ElectSync Predicate with its expression is handled by PredicateCompute::createSingleExpressionElectSync.
  3. createSingleExpressionElectSync uses the predicate's expression to determine whether it is a TMA Store.
  4. createElectSyncPredicate then skips the ElectSync if the expression is a TMA Store; the warp-selection logic is unchanged.
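The control flow in steps 1-4 can be sketched as a small host-side model. The `ExprKind` enum and string predicates below are illustrative only, not the real nvFuser API:

```cpp
#include <cassert>
#include <string>

// Illustrative kinds of async expressions reaching the ElectSync path.
enum class ExprKind { TmaLoad, TmaStore, OtherAsync };

// Models createElectSyncPredicate: a warp-collective (TMA Store) gets only
// the warp-selection predicate; everything else also gets electSync to pick
// a single thread within that warp.
std::string electSyncPredicate(ExprKind kind, bool use_first_warp = true) {
  const std::string select_warp = use_first_warp
      ? "threadIdx.x < 32"
      : "threadIdx.x >= blockDim.x - 32";
  if (kind == ExprKind::TmaStore) {
    return select_warp;  // short-circuit: warp-collective
  }
  return "Hopper::electSync(0xFFFFFFFF) && " + select_warp;
}
```

A TMA Load thus keeps the `electSync(...) && select_warp` form, while a TMA Store keeps only `select_warp`.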

Expression-Less ElectSync Predicate

Test Example

  • HopperMatmulTest/MLPGemmPersistentBroadcastInputs.NumWarpGroups/2

Nsys NvProf

| PR      | nvjet | nvfuser | %     |
|---------|-------|---------|-------|
| This PR | 1.44  | 1.49    | 96.6  |
| ToT     | 1.45  | 1.49    | 97.33 |
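The % column is consistent with the nvjet time divided by the nvfuser time (an assumption about how the table was computed; small rounding differences remain):

```cpp
// Relative performance: nvjet runtime as a percentage of nvfuser runtime.
// Units cancel, so this holds whether the times are in ms or s.
double relPerf(double nvjet, double nvfuser) {
  return nvjet / nvfuser * 100.0;
}
```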

CUDA Kernel without ElectSync

      bool b17 = ((nvfuser_index_t)threadIdx.x) < 32ULL;
      bool b19 = ((nvfuser_index_t)threadIdx.y) < 2;
      bool b22 = b17 && b19;
      #pragma unroll
      for(nvfuser_index_t i62 = 0; i62 < 4; ++i62) {
        fenceAsyncProxy();
        if (b22) {
          Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr15, (Array<int, 2, 1>{(int32_t)((i42 + (64 * i62))), i45}) }), (i14 + (8192 * i62)));
        }
      }
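In this kernel, b22 admits the 32 threads with threadIdx.x < 32 in each of the two threadIdx.y rows. A quick host-side count (the 128×4 CTA shape is an assumption for illustration) shows the store is issued by two warp-sized groups rather than a single elected thread:

```cpp
// Count threads satisfying b22 = (threadIdx.x < 32) && (threadIdx.y < 2)
// for a given CTA shape.
int countStoreThreads(int bdimx, int bdimy) {
  int n = 0;
  for (int y = 0; y < bdimy; ++y) {
    for (int x = 0; x < bdimx; ++x) {
      if (x < 32 && y < 2) {
        ++n;
      }
    }
  }
  return n;
}
```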

CUDA Kernel with ElectSync

      bool b17 = ((nvfuser_index_t)threadIdx.x) < 32ULL;
      bool b19 = ((nvfuser_index_t)threadIdx.y) < 2;
      for(nvfuser_index_t i61 = 0; i61 < 4; ++i61) {
        fenceAsyncProxy();
        if (((Hopper::electSync(4294967295U) && b17) && b19)) {
          Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr15, (Array<int, 2, 1>{(int32_t)((i41 + (64 * i61))), i44}) }), (i14 + (8192 * i61)));
        }
      }
@rdspring1 (Collaborator, Author)

!test

@github-actions

github-actions bot commented Apr 29, 2025

Review updated until commit ee8f34c

Description

  • Skip ElectSync for TMA Store expressions

  • Update createElectSyncPredicate to accept is_warp_collective parameter

  • Modify TMAPredicateChecker to handle TMA Store predicates


Changes walkthrough 📝

Relevant files

Enhancement

csrc/predicate_compute.cpp — Enhance ElectSync handling for TMA Store (+37/-19)

  • Added is_warp_collective parameter to createElectSyncPredicate
  • Short-circuit ElectSync creation for TMA Store expressions
  • Updated createSingleExpressionElectSync to check for TMA Store

tests/cpp/test_memory.cpp — Update TMA Store predicate checks (+47/-7)

  • Added is_tma_store parameter to TMAPredicateChecker constructor
  • Updated checkPredicate to accept is_tma_store parameter
  • Modified TMA Store tests to use is_tma_store flag

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Logic Consistency

    Ensure that the logic for skipping ElectSync for TMA Store is consistent across all relevant functions and that it does not inadvertently affect other operations.

    Val* createElectSyncPredicate(
        bool use_first_warp = true,
        bool is_warp_collective = false) {
      // If TIDx is used for both computation and TMA load, we should select a
      // thread from the last warp along TIDx.
      Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);
      auto select_warp = use_first_warp
          ? IrBuilder::ltExpr(
                NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size)
          : IrBuilder::geExpr(
                NamedScalar::getParallelIndex(ParallelType::TIDx),
                IrBuilder::addExpr(
                    NamedScalar::getParallelDim(ParallelType::TIDx),
                    IrBuilder::create<Val>(-32L, PrimDataType::Index)));
    
      // Short-Circuit: TMA Store is a warp-collective, so ElectSync is not
      // necessary.
      if (is_warp_collective) {
        return select_warp;
      }
    
      // Create ElectSync Predicate to pick any thread in the warp
      Val* full_mask_val = IrBuilder::create<Val>(0xFFFFFFFF, PrimDataType::UInt32);
      Val* elect_sync_val = IrBuilder::create<Val>(PrimDataType::Bool);
      IrBuilder::create<UnaryOp>(
          UnaryOpType::ElectSync, elect_sync_val, full_mask_val);
      return SimplifyingIrBuilder::logicalAndExpr(elect_sync_val, select_warp);
    }
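The select_warp expression above has two modes; evaluated per-thread on the host, they look like this (bdimx stands in for blockDim.x; an illustrative model, not nvFuser code):

```cpp
// True iff this thread belongs to the selected warp along TIDx:
// either the first warp (tidx < 32) or the last warp (tidx >= bdimx - 32).
bool inSelectedWarp(int tidx, int bdimx, bool use_first_warp) {
  return use_first_warp ? (tidx < 32) : (tidx >= bdimx - 32);
}
```

The use_first_warp = false mode matters when TIDx is shared between computation and TMA, in which case a thread from the last warp along TIDx is selected instead.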
    Test Coverage

    Verify that the added test cases cover all scenarios, including edge cases, for TMA Store operations to ensure the correctness of the changes.

    class TMAPredicateChecker : private kir::IrVisitor {
      int64_t num_threads_;
      int64_t cta_threads_;
      bool is_tma_store_;
      TMAPredicateChecker(
          int64_t num_threads,
          int64_t cta_threads,
          bool is_tma_store)
          : num_threads_(num_threads),
            cta_threads_(cta_threads),
            is_tma_store_(is_tma_store) {}
    
      kir::Predicate* pred_ = nullptr;
    
      using kir::IrVisitor::dispatch;
    
      void dispatch(Expr* expr) final {
        if (expr->isA<ForLoop>() || expr->isA<kir::IfThenElse>()) {
          kir::Predicate* prev_pred = nullptr;
          if (expr->isA<kir::IfThenElse>()) {
            auto ite = expr->as<kir::IfThenElse>();
            prev_pred = pred_;
            pred_ = ite->predicate();
          }
          kir::IrVisitor::dispatch(expr);
          if (expr->isA<kir::IfThenElse>()) {
            pred_ = prev_pred;
          }
          return;
        }
        if (!ir_utils::isCpAsyncBulk(expr)) {
          return;
        }
    
        if (num_threads_ == 0) {
          if (pred_ == nullptr) {
            return;
          }
        }
    
        ASSERT_NE(pred_, nullptr);
        auto cond = pred_->value();
        ASSERT_NE(cond, nullptr);
    
        // Handle TMA Store first
        if (is_tma_store_) {
          if (cta_threads_ <= 32) {
            EXPECT_TRUE(cond->isTrue());
          } else {
            auto def = dynamic_cast<BinaryOp*>(cond->definition());
            ASSERT_TRUE(def != nullptr);
            EXPECT_TRUE(def->getBinaryOpType() == BinaryOpType::LT);
            auto lhs = dynamic_cast<NamedScalar*>(def->lhs());
            auto rhs = def->rhs();
            ASSERT_TRUE(lhs != nullptr);
            ASSERT_TRUE(rhs != nullptr);
            EXPECT_TRUE(lhs->isThreadIdx());
            EXPECT_TRUE(rhs->isConstInt());
            EXPECT_EQ(rhs->value(), 32);
          }
          return;
        }
    
        // Then, handle TMA Load
        if (num_threads_ == 0) {
          EXPECT_TRUE(cond->isTrue());
        } else if (is_tma_store_ && cta_threads_ <= 32) {
          EXPECT_TRUE(cond->isTrue());
        } else if (num_threads_ == 1 && cta_threads_ > 32) {
          auto def = dynamic_cast<BinaryOp*>(cond->definition());
          ASSERT_TRUE(def != nullptr);
          EXPECT_TRUE(def->getBinaryOpType() == BinaryOpType::LogicalAnd);
          auto lhs = def->lhs();
          auto rhs = def->rhs();
          ASSERT_TRUE(lhs != nullptr);
          auto lhs_def = dynamic_cast<UnaryOp*>(lhs->definition());
          EXPECT_TRUE(lhs_def->getUnaryOpType() == UnaryOpType::ElectSync);
          ASSERT_TRUE(rhs != nullptr);
          auto rhs_def = dynamic_cast<BinaryOp*>(rhs->definition());
          EXPECT_TRUE(rhs_def->getBinaryOpType() == BinaryOpType::LT);
          auto lhs_rhs = dynamic_cast<NamedScalar*>(rhs_def->lhs());
          auto rhs_rhs = rhs_def->rhs();
          ASSERT_TRUE(lhs_rhs != nullptr);
          ASSERT_TRUE(rhs_rhs != nullptr);
          EXPECT_TRUE(lhs_rhs->isThreadIdx());
          EXPECT_TRUE(rhs_rhs->isConstInt());
          EXPECT_EQ(rhs_rhs->value(), 32);
        } else if (num_threads_ == 1 && cta_threads_ == 32) {
          auto def = dynamic_cast<UnaryOp*>(cond->definition());
          ASSERT_TRUE(def != nullptr);
          EXPECT_TRUE(def->getUnaryOpType() == UnaryOpType::ElectSync);
        } else if (num_threads_ == 1 && cta_threads_ < 32) {
          auto def = dynamic_cast<BinaryOp*>(cond->definition());
          ASSERT_TRUE(def != nullptr);
          EXPECT_TRUE(def->getBinaryOpType() == BinaryOpType::Eq);
          auto lhs = dynamic_cast<NamedScalar*>(def->lhs());
          auto rhs = def->rhs();
          ASSERT_TRUE(lhs != nullptr);
          ASSERT_TRUE(rhs != nullptr);
          EXPECT_TRUE(lhs->isThreadIdx());
          EXPECT_TRUE(rhs->isZeroInt());
        } else {
          auto def = dynamic_cast<BinaryOp*>(cond->definition());
          ASSERT_TRUE(def != nullptr);
          EXPECT_TRUE(def->getBinaryOpType() == BinaryOpType::LT);
          auto lhs = dynamic_cast<NamedScalar*>(def->lhs());
          auto rhs = def->rhs();
          ASSERT_TRUE(lhs != nullptr);
          ASSERT_TRUE(rhs != nullptr);
          EXPECT_TRUE(lhs->isThreadIdx());
          EXPECT_TRUE(rhs->isConstInt());
          EXPECT_EQ(rhs->value(), num_threads_);
        }
      }
    
     public:
      // Check that TMA is predicated with things like "tidx < num_threads".
      // num_threads == 0 is reserved for no predication.
      static void checkPredicate(
          kir::Kernel* kernel,
          int64_t num_threads,
          int64_t cta_threads = -1,
          bool is_tma_store = false) {
        TMAPredicateChecker checker(num_threads, cta_threads, is_tma_store);
        checker.handle(kernel->topLevelExprs());
      }
    };
    Redundant Checks

    Check for any redundant checks or conditions in the test cases that might be simplified or removed to improve clarity and maintainability.

    The quoted dispatch body repeats the code shown above in full; the notable redundancy is the second is_tma_store_ check in the TMA Load branch, which is unreachable after the early return in the TMA Store branch:

       // Then, handle TMA Load
       if (num_threads_ == 0) {
         EXPECT_TRUE(cond->isTrue());
       } else if (is_tma_store_ && cta_threads_ <= 32) {
         EXPECT_TRUE(cond->isTrue());
       } else if (num_threads_ == 1 && cta_threads_ > 32) {

    @rdspring1 (Collaborator, Author)
    For future refactor, perhaps we should create a separate predicate type such as PredicateType::SingleWarp.

    @rdspring1 (Collaborator, Author)

    !test

    @rdspring1 rdspring1 requested a review from zasdfgbnm April 30, 2025 17:36
    @zasdfgbnm (Collaborator)

    > For future refactor, perhaps we should create a separate predicate type such as PredicateType::SingleWarp.

    Does it make sense to have two predicate types, or just one? If we only need one, then probably we can name it into something like PredicateType::PickThreads.

    @rdspring1 (Collaborator, Author)

    rdspring1 commented Apr 30, 2025

    > Does it make sense to have two predicate types, or just one? If we only need one, then probably we can name it into something like PredicateType::PickThreads.

    PickWarp selects a single warp. PickThread selects a single thread from a warp. They would share the same logic, so I'd prefer only maintaining one.

    Yes, PredicateType::ElectSync has become that in my head already.

    @rdspring1 rdspring1 merged commit f800edb into main Apr 30, 2025
    52 of 53 checks passed
    @rdspring1 rdspring1 deleted the select_warp_predicate branch April 30, 2025 18:30