Skip to content

Add warp select predicate to TMA store's FenceAsyncProxy#4820

Draft
jacobhinkle wants to merge 1 commit intomainfrom
jh/predicate_proxy_fence
Draft

Add warp select predicate to TMA store's FenceAsyncProxy#4820
jacobhinkle wants to merge 1 commit intomainfrom
jh/predicate_proxy_fence

Conversation

@jacobhinkle
Copy link
Collaborator

@jacobhinkle jacobhinkle commented Jul 22, 2025

Before this PR:

  bool b18;
  b18 = ((nvfuser_index_t)threadIdx.x) < 32ULL;
  bool b20;
  b20 = ((nvfuser_index_t)threadIdx.y) < 2;
  bool b23;
  b23 = b18 && b20;
  // ...
      block_sync::sync<false>(dim3(128, 2, 1));
      #pragma unroll
      for(nvfuser_index_t i56 = 0; i56 < 4; ++i56) {
        fenceAsyncProxy();
        if (b23) {
          Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr15, (Array<int, 2, 1>{__to_int32((i41 + (64 * i56))), i43}) }), (i14 + (8192 * i56)));
        }
      }

After this PR:

  bool b18;
  b18 = ((nvfuser_index_t)threadIdx.x) < 32ULL;
  bool b20;
  b20 = ((nvfuser_index_t)threadIdx.y) < 2;
  bool b23;
  b23 = b18 && b20;
  // ...
      block_sync::sync<false>(dim3(128, 2, 1));
      #pragma unroll
      for(nvfuser_index_t i56 = 0; i56 < 4; ++i56) {
        if (b18) {
          fenceAsyncProxy();
        }
        if (b23) {
          Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr15, (Array<int, 2, 1>{__to_int32((i41 + (64 * i56))), i43}) }), (i14 + (8192 * i56)));
        }
      }

This predicate suffices for warp specialized TMA store in Hopper matmul. I have not yet checked the TMA store use cases in normalization.

Timing of MLPBenchmarkTest.FwdGEMM/persistent_warpspec:

  • main: 944 us
  • this PR: 942 us
  • cublas: 736 us

That example has a large main loop (80 stages)

Timing of HopperMatmulTest.PingPongPersistent:

  • main: 44.3 us
  • this PR: 41.2 us
  • cublas: 21.5 us

Here we see a measurable difference. The main advantage of this PR is that it enables #4804 and #4810 which will bring larger speedups.

Fixes #4813

@github-actions
Copy link

Description

  • Add warp select predicate to FenceAsyncProxy for TMA store

  • Update UnrollPass to handle FenceAsyncProxy predicates


Changes walkthrough 📝

Relevant files
Enhancement
insert_syncs.cpp
Predicate FenceAsyncProxy for first warp                                 

csrc/device_lower/pass/insert_syncs.cpp

  • Create FenceAsyncProxy with a predicate to select the first warp
  • Define warp_size and select_first_warp conditions
  • Apply predicate to FenceAsyncProxy
  • +9/-1     
    unroll.cpp
    Handle FenceAsyncProxy predicates in UnrollPass                   

    csrc/device_lower/pass/unroll.cpp

  • Update condition to handle FenceAsyncProxy predicates
  • Include FenceAsyncProxy in the list of expressions with predicates
  • +7/-5     

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The predicate for FenceAsyncProxy is set to only the first warp. This might not be sufficient for all TMA store use cases, especially if the store is not warp specialized or if other warps need to perform the fence.

    // Predicate the fence to select the first warp. TMA store is warp
    // collective so ElectSync is not needed.
    Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);
    Val* select_first_warp = IrBuilder::ltExpr(
        NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size);
    auto* select_warp_pred =
        IrBuilder::create<kir::Predicate>(select_first_warp);
    fence_async = fence_async->withPredicate(select_warp_pred);
    Code Clarity

    The condition for adding predicates to kir::MBarrierInit, kir::MBarrierInvalidate, and kir::FenceAsyncProxy is now combined. Ensure that this change does not inadvertently skip adding predicates to other expressions that require them.

    if (expr->predicate() != nullptr &&
        expr->isOneOf<
            kir::MBarrierInit,
            kir::MBarrierInvalidate,
            kir::FenceAsyncProxy>()) {
      kir::IfThenElse* inline_ite =

    @jacobhinkle
    Copy link
    Collaborator Author

    !test

    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    Predicate FenceAsyncProxy to match TMA store predicate

    1 participant