
Select by warp id in AsyncWarp if Register Sharing is enabled#4334

Merged
rdspring1 merged 13 commits into main from generalize_select_warp
May 8, 2025
Conversation


@rdspring1 rdspring1 commented Apr 29, 2025

Background: During lowering, a thread-block ParallelType is padded for the mbarrier async operations when the fusion uses WarpSpecialized circular buffering.

Problem: Picking Warps

  • Picking either the first or last warp along the WarpSpecialized axis is arbitrary.
  • We need to select individual warps for Blackwell.

Solution - How to pick warp and threads?

  • Get linear index for AsyncWarp
def getLinearIndex():
  index = 0
  extent = 1
  for pt in [ParallelType.TIDx, ParallelType.TIDy, ParallelType.TIDz]:
    if pt is not active:
      continue
    if pt is trivial:
      continue
    pt_index = getIndex(pt)
    if pt is WarpSpecialized.on:
      pt_index -= original_cta_axis_size
    # Multiply by the running extent of the axes already visited,
    # not by this axis's own extent.
    index += pt_index * extent
    extent *= getExtent(pt)
  return index
  • Then, select the warp:
warp_id = (async_linear_index / 32)
if warp_id == 0 and elect-sync:
  # tma load
elif warp_id == 1 and elect-sync:
  # utcmma
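The two steps above can be sketched as a small, runnable example; the function name, argument layout, and the example CTA shape are illustrative assumptions, not nvFuser identifiers:

```python
# Minimal Python sketch of the AsyncWarp linear-index computation.
def async_linear_index(tid, extents, ws_axis, ws_compute_extent):
    """Map an AsyncWarp thread id (tidx, tidy, tidz) to a linear index.

    extents           -- padded CTA extents per axis
    ws_axis           -- which axis (0, 1, 2) is warp-specialized
    ws_compute_extent -- that axis's extent before padding
    """
    index, extent = 0, 1
    for axis in range(3):
        if extents[axis] == 1:  # skip trivial axes
            continue
        pt_index = tid[axis]
        if axis == ws_axis:
            # Shift the padded parallel index so the AsyncWarp starts at 0.
            pt_index -= ws_compute_extent
        index += pt_index * extent
        extent *= extents[axis]
    return index

# Assumed CTA (32, 20, 1): TIDy warp-specialized with compute extent 16,
# so the AsyncWarp is tidy in [16, 20) -> 4 x 32 = 128 threads.
assert async_linear_index((0, 16, 0), (32, 20, 1), 1, 16) == 0
assert async_linear_index((31, 19, 0), (32, 20, 1), 1, 16) == 127
# warp_id = linear index / 32
assert async_linear_index((0, 17, 0), (32, 20, 1), 1, 16) // 32 == 1
```

This matches the predicate emitted in the TIDy example below, where the generated index is `threadIdx.x + 32 * threadIdx.y - 512`.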

Problem: ElectSync doesn't work if blockDim.x < 32

Solution: Replace ElectSync with thread_id == 0

warp_id = (async_linear_index / 32)
thread_id = (async_linear_index % 32)
if warp_id == 0 and thread_id == 0:
  # tma load
elif warp_id == 1 and thread_id == 0:
  # utcmma
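To see why the `thread_id == 0` fallback still selects exactly one leader per warp when blockDim.x < 32, here is a hedged sketch; the CTA shape (16, 24, 1) with TIDy warp-specialized (compute extent 16) is an assumed example, not taken from the PR:

```python
# Hypothetical check of the fallback predicate on a linearized AsyncWarp.
def leader_threads(bdimx, tidy_lo, tidy_hi, target_warp):
    """Return the (tidx, tidy) pairs satisfying
    warp_id == target_warp and thread_id == 0."""
    leaders = []
    for tidy in range(tidy_lo, tidy_hi):
        for tidx in range(bdimx):
            linear = tidx + (tidy - tidy_lo) * bdimx
            warp_id, thread_id = divmod(linear, 32)
            if warp_id == target_warp and thread_id == 0:
                leaders.append((tidx, tidy))
    return leaders

# With bdimx = 16, each 32-thread slice spans two tidy rows, yet the
# predicate still selects exactly one leader per slice:
assert leader_threads(16, 16, 24, 0) == [(0, 16)]  # e.g. tma-load warp
assert leader_threads(16, 16, 24, 1) == [(0, 18)]  # e.g. utcmma warp
```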

Code snippet from HopperMatmulTest/MLPGemmPersistentBroadcastInputs.NumWarpGroups/2

  • CTA (128, 3, 1)
      for(nvfuser_index_t i32 = 0; i32 < i6; ++i32) {
        nvfuser_index_t i33;
        i33 = (i30 + i32) % 3;
        if ((((((nvfuser_index_t)threadIdx.x) / 32ULL) == 0ULL) && Hopper::electSync(4294967295U))) {
          mbarrier::waitParity(toSmem((&T9[((((i6 * i25) + i32) % 3) + 3LL)])), (uint32_t)(((((i6 * i25) + i32) / 3) % 2)));
          mbarrier::arriveExpectTX(toSmem((&T9[(((i6 * i25) + i32) % 3)])), 32768U);
          Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr7, (Array<int, 2, 1>{(int32_t)((64 * i32)), i29}), toSmem((&T9[(((i6 * i25) + i32) % 3)])) }), (i8 + (32768 * i33)));
          mbarrier::arriveExpectTX(toSmem((&T9[(((i6 * i25) + i32) % 3)])), 16384U);
          Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr9, (Array<int, 2, 1>{(int32_t)((64 * i32)), i31}), toSmem((&T9[(((i6 * i25) + i32) % 3)])) }), (i10 + (16384 * i33)));
        }

Code snippet from Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_neg2_M_500_N_2048_WarpSpecializedOnTIDyRegisterSharing_64_168_CpAsyncBulkTensorTile

  • CTA (32, 16, 1)
      if ((((((((nvfuser_index_t)threadIdx.x) + (32 * ((nvfuser_index_t)threadIdx.y))) + -512) / 32ULL) == 0ULL) && Hopper::electSync(4294967295U))) {
        mbarrier::waitParity(toSmem((&T12[((i26 % 2) + 2LL)])), (uint32_t)(((i26 / 2) % 2)));
        mbarrier::arriveExpectTX(toSmem((&T12[(i26 % 2)])), 16384U);
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr6, (Array<int, 2, 1>{i8, (int32_t)((64 * i26))}), toSmem((&T12[(i26 % 2)])) }), (i9 + i27));
        mbarrier::arriveExpectTX(toSmem((&T12[(i26 % 2)])), 16384U);
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr10, (Array<int, 2, 1>{(int32_t)((64 * i26)), i12}), toSmem((&T12[(i26 % 2)])) }), (i13 + i27));
      }


github-actions bot commented Apr 29, 2025

Review updated until commit 56b6c0b

Description

  • Added linear thread index calculation for AsyncWarp

  • Enhanced warp selection logic for AsyncWarp

  • Introduced condition to use ElectSync based on warp specialization

  • Added exception handling for invalid CTA shapes


Changes walkthrough 📝

Relevant files (Enhancement):

csrc/parallel_dimension_map.cpp — Add linear index and ElectSync condition for AsyncWarp (+54/-0)
  • Added getLinearThreadIndexAsync to calculate the linear index for AsyncWarp threads
  • Added canUseElectSyncInAsyncWarp to determine if ElectSync can be used

csrc/predicate_compute.cpp — Enhance predicate creation for AsyncWarp (+64/-57)
  • Created createElectSyncExpr to generate an ElectSync expression
  • Modified selectFirstWarpElectSyncPredicate to use ElectSync if not a warp collective
  • Added createElectSyncPredicateAsync for AsyncWarp-specific predicate creation
  • Updated createElectSyncPredicate to handle the AsyncWarp case
  • Modified createMultipleExpressionElectSync to handle the AsyncWarp case

csrc/parallel_dimension_map.h — Add declarations for AsyncWarp enhancements (+7/-0)
  • Added getLinearThreadIndexAsync declaration
  • Added canUseElectSyncInAsyncWarp declaration

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 No relevant tests
⚡ Recommended focus areas for review

Performance Goal

The PR description mentions performance improvements but does not provide specific performance metrics or a clear performance goal. Ensure that a clear performance goal is set and that feedback was sought early.

    // For warp-specialization, the CTA is padded so the AsyncWarp contains 128
    // threads. This function maps the AsyncWarp CTA to a linear index from
    // [0, 128). It is used to divide AsyncWarp into four independent warps.
    Val* ParallelDimensionMap::getLinearThreadIndexAsync() const {
      Val* index = GpuLower::current()->kernel()->zeroVal();
      Val* extent = GpuLower::current()->kernel()->oneVal();
    
      for (auto pt : kParallelTypeTIDs) {
        // For warp-specialization, an axis is padded so the AsyncWarp contains
        // 128 threads.
        Val* extent_for_pdim = getRawAsync(pt);
        // short-circuit: extent_for_pdim is not used in kernel.
        if (extent_for_pdim == nullptr) {
          continue;
        }
        // short-circuit: extent_for_pdim is trivial.
        if (extent_for_pdim->isConstScalar() &&
            extent_for_pdim->evaluate().as<int64_t>() == 1) {
          continue;
        }
        Val* pt_index = NamedScalar::getParallelIndex(pt);
        // Map the padded parallel index to [0, padded_value] range, so the linear
        // index will be in range of [0, 128).
        if (warp_specialized_types_.count(pt)) {
          pt_index = SimplifyingIrBuilder::subExpr(pt_index, getRawCompute(pt));
        }
        index = SimplifyingIrBuilder::addExpr(
            index, SimplifyingIrBuilder::mulExpr(pt_index, extent));
        extent = SimplifyingIrBuilder::mulExpr(extent, extent_for_pdim);
      }
      return index;
    }
ElectSync Usage

The PR introduces changes to the usage of ElectSync in createElectSyncPredicateAsync. Ensure that the new approach is thoroughly tested and that the performance impact is evaluated, especially when ElectSync is not used.

      Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt64);
      Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);
    
      const ParallelDimensionMap& pdim_map =
          GpuLower::current()->parallelDimensionMap();
      Val* async_warp_thread_index = pdim_map.getLinearThreadIndexAsync();
      Val* warp_id =
          SimplifyingIrBuilder::divExpr(async_warp_thread_index, warp_size);
      // TODO Only select first warp now
      Val* select_warp = SimplifyingIrBuilder::eqExpr(warp_id, zero);
    
      // Use elect-sync if available
      if (pdim_map.canUseElectSyncInAsyncWarp()) {
        return SimplifyingIrBuilder::logicalAndExpr(
            select_warp, createElectSyncExpr());
      }
    
      // Warp Specialized ParallelType is ThreadIdx.x and it contains less than 32
      // threads, so manually select first thread in warp.
      Val* thread_id =
          SimplifyingIrBuilder::modExpr(async_warp_thread_index, warp_size);
      Val* select_thread = SimplifyingIrBuilder::eqExpr(thread_id, zero);
      return SimplifyingIrBuilder::logicalAndExpr(select_warp, select_thread);
    }
Code Duplication

The function selectFirstWarpElectSyncPredicate is similar to the logic in createElectSyncPredicateAsync. Consider refactoring to avoid code duplication and improve maintainability.

    Val* selectFirstWarpElectSyncPredicate(bool is_warp_collective) {
      Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);
      Val* select_first_warp = IrBuilder::ltExpr(
          NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size);
    
      // Short-Circuit: TMA Store is a warp-collective, so ElectSync is not
      // necessary.
      if (is_warp_collective) {
        return select_first_warp;
      }

rdspring1 force-pushed the rename_load_to_async branch from af66969 to e937d1f (April 30, 2025 01:31)
Base automatically changed from rename_load_to_async to main (April 30, 2025 17:14)
rdspring1 force-pushed the generalize_select_warp branch 2 times, most recently from 2403eba to 78acce4 (April 30, 2025 18:55)
@rdspring1 (Collaborator Author)

!test

rdspring1 force-pushed the generalize_select_warp branch from 78acce4 to 05d6bbb (May 1, 2025 03:34)
@rdspring1 (Collaborator Author)

!test

rdspring1 force-pushed the generalize_select_warp branch from 76f7e97 to 60e5ba4 (May 3, 2025 01:34)
rdspring1 changed the title from "Select by warp id in AsyncWarp" to "Select by warp id in AsyncWarp if Register Sharing is enabled" (May 7, 2025)
rdspring1 marked this pull request as ready for review (May 7, 2025 19:13)
rdspring1 requested a review from zasdfgbnm (May 7, 2025 19:53)
@rdspring1 (Collaborator Author)

!test

@rdspring1 (Collaborator Author)

@zasdfgbnm I had to push 56b6c0b because I ran into the "ElectSync doesn't work if blockDim.x < 32" issue with non-register-sharing warp specialization.

e.g., bdimx = 33 and bdimy = 16 yields only 1 x 16 active threads in the AsyncWarp when WarpSpecialized.on == ParallelType::TIDx. There are fewer than 32 active threads, causing elect-sync to stall.

I can clean things up with #4395, which keeps the same padding rules even when register sharing is disabled.

TL;DR: Having different padding rules for register sharing is annoying.
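The stall condition can be illustrated with a small sketch; the function name mirrors canUseElectSyncInAsyncWarp from this PR, but the signature and string-based axis argument are illustrative assumptions:

```python
# elect.sync elects a leader within a full 32-thread hardware warp, so it
# stalls when warp specialization is on TIDx and the AsyncWarp's TIDx
# extent leaves hardware warps only partially active.
def can_use_elect_sync_in_async_warp(ws_axis, async_tidx_extent):
    if ws_axis == "TIDx" and async_tidx_extent < 32:
        return False
    return True

# Padding bdimx 32 -> 33 with WS on TIDx leaves a 1 x 16 AsyncWarp:
assert not can_use_elect_sync_in_async_warp("TIDx", 33 - 32)  # fall back to thread_id == 0
assert can_use_elect_sync_in_async_warp("TIDy", 32)           # keep elect.sync
```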

rdspring1 merged commit 1f4eae1 into main May 8, 2025
53 checks passed
rdspring1 deleted the generalize_select_warp branch May 8, 2025 19:20