Skip to content

Enable same padding for non-register sharing warp specialization#4395

Merged
rdspring1 merged 2 commits intomainfrom
generalize_select_warp_p2
May 13, 2025
Merged

Enable same padding for non-register sharing warp specialization#4395
rdspring1 merged 2 commits intomainfrom
generalize_select_warp_p2

Conversation

@rdspring1
Copy link
Collaborator

@rdspring1 rdspring1 commented May 8, 2025

This PR enforces same padding rules for non-register sharing warp specialization.

  • Replaced std::unordered_set<ParallelType> warp_specialized_types_ with std::optional<ParallelType> warp_specialized_parallel_type_ because we only support a single ParallelType.

@github-actions
Copy link

github-actions bot commented May 8, 2025

Review updated until commit bc4e084

Description

  • Replaced std::unordered_set<ParallelType> with std::optional<ParallelType> for warp specialization.

  • Added checks for multiple warp specialized axes and updated padding logic.

  • Updated test cases to reflect changes in warp specialization handling.


Changes walkthrough 📝

Relevant files
Enhancement
parallel_dimension_map.cpp
Refactor warp specialization handling                                       

csrc/parallel_dimension_map.cpp

  • Replaced warp_specialized_types_ with warp_specialized_parallel_type_.
  • Added error handling for multiple warp specialized axes.
  • Updated padding logic and variable names.
  • +19/-37 
    parallel_dimension_map.h
    Update header file for warp specialization                             

    csrc/parallel_dimension_map.h

  • Updated hasWarpSpecialization and added isWarpSpecialized method.
  • Renamed variables related to warp specialization padding.
  • +9/-8     
    Tests
    test_circular_buffering.cpp
    Update test cases for warp specialization                               

    tests/cpp/test_circular_buffering.cpp

  • Added testEnablesWarpSpecialization and testEnablesTIDx methods.
  • Updated test skips to check for warp specialization.
  • +16/-6   
    test_combined_inner_outer_reduction.cpp
    Update test cases for warp specialization                               

    tests/cpp/test_combined_inner_outer_reduction.cpp

    • Updated test skips to check for warp specialization.
    +13/-4   

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The check if (!warp_specialized_parallel_type_.has_value()) in getWarpSpecializationPaddedVal might return 1, which could be incorrect if no warp specialization is intended. Ensure this behavior is correct or handle the case appropriately.

      return 1;
    }
    Test Coverage

    The tests skip scenarios where warp specialization is enabled. Verify that these scenarios are covered elsewhere or that the skips are justified.

    TEST_P(TmaCircularBufferingTest, Persistent) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
      if (testEnablesWarpSpecialization()) {
        GTEST_SKIP() << "Bdimx is dynamic, Warp Specialization is disabled.";
        return;
    Test Coverage

    The tests skip scenarios where warp specialization is enabled. Verify that these scenarios are covered elsewhere or that the skips are justified.

    TEST_P(TmaWarpSpecializedTest, SimpleFusion) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
      auto [contig, ws_enabled, dtype, dim0, dim1] = GetParam();
    
      if (ws_enabled) {
        GTEST_SKIP() << "Bdimx is dynamic, Warp Specialization is disabled.";
        return;
      }

    Base automatically changed from generalize_select_warp to main May 8, 2025 19:20
    @rdspring1 rdspring1 force-pushed the generalize_select_warp_p2 branch from a636d75 to b8d2f61 Compare May 8, 2025 19:23
    @rdspring1 rdspring1 marked this pull request as ready for review May 8, 2025 20:26
    @rdspring1 rdspring1 requested a review from zasdfgbnm May 8, 2025 20:26
    @rdspring1 rdspring1 force-pushed the generalize_select_warp_p2 branch from b8d2f61 to ffd0946 Compare May 8, 2025 21:03
    @rdspring1
    Copy link
    Collaborator Author

    !test

    Copy link
    Collaborator

    @zasdfgbnm zasdfgbnm left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    Nice cleanup!

    @rdspring1
    Copy link
    Collaborator Author

    !test

    @rdspring1 rdspring1 merged commit 1493be4 into main May 13, 2025
    53 checks passed
    @rdspring1 rdspring1 deleted the generalize_select_warp_p2 branch May 13, 2025 23:53
    samnordmann pushed a commit that referenced this pull request May 15, 2025
    This PR enforces same padding rules for non-register sharing warp
    specialization.
    
    * Replaced `std::unordered_set<ParallelType> warp_specialized_types_`
    with `std::optional<ParallelType> warp_specialized_parallel_type_`
    because we only support a single ParallelType.
    liqiangxl added a commit that referenced this pull request May 15, 2025
    New restriction on warp specialization was added in
    #4395
    Needs to temporarily skip `ThunderRMSNormBwd`
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    2 participants