
Insert fences in insertRawThreadSynchronization #4810

Draft
jacobhinkle wants to merge 8 commits into main from jh/insert_fences_with_syncs

Conversation

@jacobhinkle
Collaborator

jacobhinkle commented Jul 21, 2025

Stacked on #4820

Whenever we insert a sync, this PR runs a simple analysis to determine whether a memory fence is required and, if so, inserts a FenceAsyncProxy.

Before this PR:

      block_sync::sync<false>(dim3(128, 2, 1));
      #pragma unroll
      for(nvfuser_index_t i56 = 0; i56 < 4; ++i56) {
        if (b18) {
          fenceAsyncProxy();
        }
        if (b23) {
          Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr15, (Array<int, 2, 1>{__to_int32((i41 + (64 * i56))), i43}) }), (i14 + (8192 * i56)));
        }
      }

After this PR:

      block_sync::sync<false>(dim3(128, 2, 1));
      if (b18) {
        fenceAsyncProxy();
      }
      #pragma unroll
      for(nvfuser_index_t i56 = 0; i56 < 4; ++i56) {
        if (b23) {
          Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr15, (Array<int, 2, 1>{__to_int32((i41 + (64 * i56))), i43}) }), (i14 + (8192 * i56)));
        }
      }

This is not sufficient to completely address #4808 because:

  1. It only affects syncs inserted in this pass, while mbarrier syncs are also inserted in circular buffering.
  2. It does not affect the wgmma::fence, which is inserted in another part of this pass.
  3. It does not predicate the fence based on the predicates of the consumers, and even if it did, we do not yet use expr->predicate() for TMA stores. That predication happens in the unroll pass, and an exception is made for ElectSync of TMA stores (see #4332, "Skip ElectSync when creating predicate for TMA Store in PredicateCompute").

Fixes #4814

@github-actions
github-actions bot commented Jul 21, 2025

Review updated until commit 67285b3

Description

  • Insert FenceAsyncProxy before syncs when necessary

  • Add memory proxy analysis for async operations

  • Predicate FenceAsyncProxy for TMA stores

  • Update predicate handling for specific expressions


Changes walkthrough 📝

Relevant files (Enhancement):

insert_syncs.cpp (csrc/device_lower/pass/insert_syncs.cpp): Enhance sync insertion with memory proxy analysis (+57/-11)

  • Removed unnecessary FenceAsyncProxy insertion for CpAsyncBulkStore
  • Added logic to determine if a FenceAsyncProxy is needed based on memory proxies
  • Inserted FenceAsyncProxy with warp select predicate for TMA stores

unroll.cpp (csrc/device_lower/pass/unroll.cpp): Update predicate handling for specific expressions (+7/-5)

  • Updated predicate handling to include FenceAsyncProxy

utils.cpp (csrc/device_lower/utils.cpp): Add memory proxy analysis for async operations (+13/-0)

  • Added getMemoryProxy function to determine memory proxy for expressions

utils.h (csrc/device_lower/utils.h): Add memory proxy analysis for async operations (+5/-0)

  • Added MemoryProxy enum and getMemoryProxy function declaration

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 No relevant tests
⚡ Recommended focus areas for review

Possible Issue

The logic for determining whether a fence is needed seems complex and could have edge cases. Ensure that all scenarios are correctly handled, especially with different memory types and operations.

    MemoryProxy consumer_proxy =
        lower_utils::getMemoryProxy(insert_before_expr);
    std::unordered_set<MemoryType> guarded_memtypes;
    bool needs_proxy_fence = false;
    for (Expr* write_expr : last_writes) {
      // Determine whether an implicit fence is guaranteed or if we need to
      // insert one
      if (auto* mma = dynamic_cast<MmaOp*>(write_expr);
          mma && mma->isHopper()) {
        // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-async-proxy
        continue;
      }
      if (ir_utils::isCpAsyncOp(write_expr) ||
          ir_utils::isCpAsyncBulk(write_expr)) {
        // https://docs.nvidia.com/cuda/parallel-thread-execution/#async-proxy
        continue;
      }
      if (lower_utils::getMemoryProxy(write_expr) != consumer_proxy) {
        needs_proxy_fence = true;
        // If this expression requires a fence, determine which memory space(s)
        // need to be fenced.
        for (Val* out_val : write_expr->outputs()) {
          if (auto* tv = dynamic_cast<TensorView*>(out_val)) {
            guarded_memtypes.insert(tv->getMemoryType());
          }
        }
      }
    }
Code Clarity

The new logic for inserting fences could be made more readable by breaking it into smaller functions or adding more comments to explain the decision-making process.

    MemoryProxy consumer_proxy =
        lower_utils::getMemoryProxy(insert_before_expr);
    std::unordered_set<MemoryType> guarded_memtypes;
    bool needs_proxy_fence = false;
    for (Expr* write_expr : last_writes) {
      // Determine whether an implicit fence is guaranteed or if we need to
      // insert one
      if (auto* mma = dynamic_cast<MmaOp*>(write_expr);
          mma && mma->isHopper()) {
        // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-async-proxy
        continue;
      }
      if (ir_utils::isCpAsyncOp(write_expr) ||
          ir_utils::isCpAsyncBulk(write_expr)) {
        // https://docs.nvidia.com/cuda/parallel-thread-execution/#async-proxy
        continue;
      }
      if (lower_utils::getMemoryProxy(write_expr) != consumer_proxy) {
        needs_proxy_fence = true;
        // If this expression requires a fence, determine which memory space(s)
        // need to be fenced.
        for (Val* out_val : write_expr->outputs()) {
          if (auto* tv = dynamic_cast<TensorView*>(out_val)) {
            guarded_memtypes.insert(tv->getMemoryType());
          }
        }
      }
    }
    
    if (needs_proxy_fence) {
      NVF_ERROR(
          guarded_memtypes.size() == 1 &&
              guarded_memtypes.count(MemoryType::Shared) == 1,
          "We currently only support fence.proxy.async.shared::cta, but other "
          "memory types were detected.");
      Expr* fence_async = IrBuilder::create<kir::FenceAsyncProxy>();
      // Predicate the fence to select the first warp. TMA store is warp
      // collective so ElectSync is not needed.
      Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);
      Val* select_first_warp = IrBuilder::ltExpr(
          NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size);
      auto* select_warp_pred =
          IrBuilder::create<kir::Predicate>(select_first_warp);
      fence_async = fence_async->withPredicate(select_warp_pred);
      registerInsertBefore(place_before, fence_async, &placed_in_fl->body());
    }
    
    return placed_in_fl;
Incomplete Implementation

The getMemoryProxy function has a TODO comment indicating that operations accessing TensorMap should return MemoryProxy::TensorMap, but this is not implemented. Ensure that this is addressed or documented.

    MemoryProxy getMemoryProxy(Expr* expr) {
      if (ir_utils::isCpAsyncOp(expr) || ir_utils::isCpAsyncBulk(expr) ||
          expr->isOneOf<kir::AsyncCommit, kir::AsyncWait>()) {
        return MemoryProxy::Async;
      }

      // TODO: Any operation that accesses a TensorMap should return
      // MemoryProxy::TensorMap. I don't think we ever explicitly access these
      // currently.

      return MemoryProxy::Generic;
    }
@jacobhinkle
Collaborator Author

!test --diff

@jacobhinkle
Collaborator Author

Like #4804, this PR also causes a failure in HopperMatmulTest.PingPongPersistent unless the warp selection predicate is added. It seems we need to address predication of the FenceAsyncProxy first.

@jacobhinkle jacobhinkle changed the title from "[WIP] Insert fences in insertRawThreadSynchronization" to "Insert fences in insertRawThreadSynchronization" on Jul 22, 2025

@jacobhinkle
Collaborator Author

I observed some slowdowns that I'm trying to understand. In particular, looking at 3648-480-8632-NT-bf16, we use the following params:

    ===== Matmul Parameters ========
    
    MMA macro: Hopper_64_128_16
    CircularBufferOptions:
      circular_buffer_smem_write: true
      circular_buffer_smem_read: false
      smem_circular_buffer_stage: 6
      smem_circular_buffer_prefetch_gap: 1
    SupportedVectorization:
      a: 8
      b: 8
      epilogue: 8
    MatMulTileOptions: warp tile [64, 128, 64], CTA tile [128, 128, 64]
    Async global mem load: true
    Indexing mode: int32_t
    Tile rasterization order: row-major
    Grid swizzle factor: (1, 1)
    Tiling strategy: DistributeTilesAcrossSMs
    Buffering loop level: CTATiles
    Circular buffering strategy: WarpSpecialized
    __cluster_dims__(1, 2)
    Use shared memory epilogue: 1
    Promote re-use of prologue shared memory: 1
    Use ldmatrix/stmatrix in epilogue: 1
    Split-K factor: 1
    ====================================
    

On TOT we have this epilogue section:

            block_sync::sync<false>(dim3(128, 2, 1));
            #pragma unroll
            for(nvfuser_index_t i56 = 0; i56 < 2; ++i56) {
              fenceAsyncProxy();
              if (b21) {
                Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr14, (Array<int, 2, 1>{__to_int32((i41 + (64 * i56))), i43}) })  , (i13 + (8192 * i56)));
              }
            }
            block_sync::sync<false>(dim3(128, 2, 1));
            cpAsyncBulkCommitGroup();
            cpAsyncBulkWaitGroup<0LL>();

When I switch this to the following, I get a slowdown from 48.8 us to 61.8 us, i.e. a drop to 79% perf:

            block_sync::sync<false>(dim3(128, 2, 1));
            if (b21) {
              fenceAsyncProxy();
            }
            #pragma unroll
            for(nvfuser_index_t i56 = 0; i56 < 2; ++i56) {
              if (b21) {
                Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr14, (Array<int, 2, 1>{__to_int32((i41 + (64 * i56))), i43}) })  , (i13 + (8192 * i56)));
              }
            }
            block_sync::sync<false>(dim3(128, 2, 1));
            cpAsyncBulkCommitGroup();
            cpAsyncBulkWaitGroup<0LL>();

This looks like the ideal fence placement and predication to me, so this is surprising, and perf is not recovered by adding .shared::cta to the instruction like in #4804.

What is even more interesting is to compare the ncu profiles for these two situations: TOT vs. with the move and predication (profile screenshots omitted).

Relevant section of PTX diff:

     $L__BB0_47:
    -       mov.u32         %r363, 1;
    +       mov.u32         %r364, 1;
    -       mov.u32         %r364, 256;
    +       mov.u32         %r365, 256;
            // begin inline asm
    -       bar.sync %r363, %r364;
    +       bar.sync %r364, %r365;
            // end inline asm
    +       or.pred         %p98, %p4, %p64;
    +       @%p98 bra       $L__BB0_49;
    +
    +       shl.b32         %r747, %r265, 7;
    +       add.s32         %r746, %r747, %r203;
            // begin inline asm
            fence.proxy.async;
     
            // end inline asm
    -       or.pred         %p98, %p4, %p64;
    -       @%p98 bra       $L__BB0_49;
    -
    -       shl.b32         %r750, %r262, 7;
    -       add.s32         %r749, %r750, %r199;
            // begin inline asm
    +       cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%rd122, {%r35, %r746}], [%r14];
    +       // end inline asm
    +       add.s32         %r371, %r14, 8192;
    +       add.s32         %r372, %r35, 64;
    +       // begin inline asm
    -       cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%rd121, {%r26, %r749}], [%r9];
    +       cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%rd122, {%r372, %r746}], [%r371];
            // end inline asm
     
     $L__BB0_49:
            // begin inline asm
    -       fence.proxy.async;
    -
    +       bar.sync %r364, %r365;
            // end inline asm
    -       @%p98 bra       $L__BB0_51;
    +       // begin inline asm
    +       cp.async.bulk.commit_group;
     
    -       shl.b32         %r748, %r262, 7;
    -       add.s32         %r747, %r748, %r199;
    -       add.s32         %r372, %r9, 8192;
    -       add.s32         %r373, %r26, 64;
    +       // end inline asm
            // begin inline asm
    -       cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%rd121, {%r373, %r747}], [%r372];
    +       cp.async.bulk.wait_group.read 0;
    +
            // end inline asm
     
    +$L__BB0_50:
    +       add.s32         %r749, %r749, 1;
    +       setp.lt.s32     %p99, %r749, %r11;
    +       @%p99 bra       $L__BB0_27;
    +

@jacobhinkle
Collaborator Author

I believe the issue is that these changes sometimes result in thread divergence during UTMASTG, which is a warp-collective instruction. Originally there is no such case and the instruction has a uniform predicate. After this change, in the bad-perf cases there is no such predicate and the branch is predicated per-thread, i.e. potential thread divergence (though none actually occurs, since we know the predicate is constant across the warp). Predicating the fence in place leads to this slowdown and the observed non-uniform predication, but a shfl_sync warp-broadcast of its predicate (early in the kernel preamble) recovers perf and the uniform register predicate. I have not yet been able to both move the fence up one level and maintain uniformity/convergence in this pathological example. Note that this might not be so bad if we can use an epilogue tile so that we always have only a single TMA store instruction...



Development

Successfully merging this pull request may close these issues:

Insert FenceAsyncProxy just after RAW sync for TMA store
