Restrict fence.proxy.async to shared::cta #4804

Draft
jacobhinkle wants to merge 2 commits into main from jh/fence_proxy_shared

jacobhinkle (Collaborator) commented Jul 20, 2025

Stacked on #4820

We currently use the kir::FenceAsyncProxy expression type only to avoid shared memory WAR (write-after-read) hazards. Adding the .shared::cta modifier to the PTX instruction means we don't need to wait for gmem writes to become available to the epilogue threads, which saves time. Note that cutlass also uses fence.proxy.async only with this modifier: https://github.com/NVIDIA/cutlass/blob/9baa06dd57804ce8fb5efe9e471b3451341522c6/include/cutlass/arch/barrier.h#L717

In testing, this took our small LLM problem set from 84% to 90% of cuBLAS performance (ignoring split-K). Note that we should also predicate the fenceAsyncProxy calls to match the predicates in the consumer expressions, which currently means the TMA stores. This can also give a speedup that is measurable in some test problems, but I have not yet automated that measurement, so I don't know what to expect in terms of overall speedup.

PTX doc: https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar
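
As a rough illustration of the change described above, the sketch below builds the CUDA source text that a lowering pass could emit for the fence, optionally guarded by the same predicate as the consuming TMA store. The helper name `emitFenceAsyncProxy` and the example predicate are hypothetical and are not nvFuser's actual API; only the PTX mnemonic is taken from this PR.

```cpp
#include <string>

// Hypothetical helper (an assumption for illustration, not nvFuser code):
// returns CUDA source text for the fence restricted to shared::cta.
// If `predicate` is non-empty, the fence is wrapped in an `if` so that it
// runs under the same condition as the consumer (e.g. a TMA store).
std::string emitFenceAsyncProxy(const std::string& predicate) {
  // Generated statement: asm volatile("fence.proxy.async.shared::cta;\n");
  const std::string stmt =
      "asm volatile(\"fence.proxy.async.shared::cta;\\n\");";
  if (predicate.empty()) {
    return stmt + "\n";
  }
  return "if (" + predicate + ") {\n  " + stmt + "\n}\n";
}
```

Calling `emitFenceAsyncProxy("")` gives the bare `asm volatile` statement, while passing a predicate string (whatever condition guards the TMA store) wraps that statement in a matching `if` block.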

jacobhinkle (Collaborator, Author) commented:

!test

github-actions bot commented Jul 20, 2025

Review updated until commit 9a8ebe6

Description

  • Modify fence.proxy.async to fence.proxy.async.shared::cta

  • Update utility mapping and toString methods accordingly


Changes walkthrough 📝

Relevant files

Enhancement

  • inline_ptx.cpp (csrc/device_lower/pass/inline_ptx.cpp, +1/-1): Update PTX instruction for `FenceAsyncProxy`
      • Change fence.proxy.async to fence.proxy.async.shared::cta in LowerToInlinePtx::handle

  • kernel_ir.cpp (csrc/kernel_ir.cpp, +3/-3): Update utility mapping and `toString` methods
      • Update utility mapping for fence.proxy.async to fence.proxy.async.shared::cta
      • Modify toString method in FenceAsyncProxy and WgMmaFence

  • kernel_ir.h (csrc/kernel_ir.h, +1/-1): Update comment for `FenceAsyncProxy`
      • Update comment for FenceAsyncProxy class

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The change in the PTX instruction from fence.proxy.async to fence.proxy.async.shared::cta should be validated to ensure it does not introduce any unintended behavior or compatibility issues with existing code.

    "fence.proxy.async.shared::cta",
    Consistency Check

    The string representation of FenceAsyncProxy in toString method should be consistent with the PTX instruction used in inline_ptx.cpp. Ensure that the change from fence.proxy.async to fence.proxy.async.shared::cta is reflected correctly in all relevant places.

    return "fence.proxy.async.shared::cta\n";
    Consistency Check

    The string representation of WgMmaFence in toString method should be consistent with the PTX instruction used in inline_ptx.cpp. Ensure that the change from fence.proxy.async to wgmma.fence.sync.aligned is reflected correctly in all relevant places.

    return "wgmma.fence.sync.aligned\n";

    Comment on lines 653 to 655

      std::string WgMmaFence::toString(int indent_size) const {
    -   return "fence.proxy.async\n";
    +   return "wgmma.fence.sync.aligned\n";
      }
    jacobhinkle (Collaborator, Author) commented:

    Fixing previously incorrect printout
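
One way to keep the printed form and the emitted PTX from drifting apart again is to derive both from a single constant. The sketch below is a minimal illustration under that assumption; the `sketch` namespace, the constant names, and the free functions are hypothetical stand-ins and do not reflect how kernel_ir.cpp or inline_ptx.cpp are actually organized.

```cpp
#include <string>

namespace sketch {

// Hypothetical shared constants; the mnemonics are the ones shown in this PR.
constexpr const char* kFenceAsyncProxyPtx = "fence.proxy.async.shared::cta";
constexpr const char* kWgMmaFencePtx = "wgmma.fence.sync.aligned";

// Stand-in for FenceAsyncProxy::toString: mnemonic plus newline.
std::string fenceAsyncProxyToString() {
  return std::string(kFenceAsyncProxyPtx) + "\n";
}

// Stand-in for WgMmaFence::toString: uses its own mnemonic, not the
// async-proxy fence's.
std::string wgMmaFenceToString() {
  return std::string(kWgMmaFencePtx) + "\n";
}

// Stand-in for the inline-PTX lowering, which terminates the instruction
// with ';' but reuses the same mnemonic constant.
std::string fenceAsyncProxyInlinePtx() {
  return std::string(kFenceAsyncProxyPtx) + ";\n";
}

}  // namespace sketch
```

With this layout, changing the modifier in one place updates both the printout and the emitted instruction.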

    jacobhinkle (Collaborator, Author) commented Jul 21, 2025

    I am hitting correctness issues in the pingpong test unless the fenceAsyncProxy call is predicated along with the TMA store. UPDATE: the predicate will be inserted by #4820.
