
[BugFix] Remove memory_order in atomic constexpr and fix NSA bwd#1260

Merged
LeiWang1999 merged 4 commits into tile-ai:main from KevinZeng08:fix_nsa_and_atomic
Nov 16, 2025

Conversation

@KevinZeng08 (Contributor) commented Nov 15, 2025

What does this PR do?

TODO for the future

  • atomic_ref does not support 1- or 2-byte types (link); consider adding inline PTX for non-vectorized atomics with different memory orders when the dtype is fp16/bf16
  • refactor the NSA example and benchmark, since they contain redundant code

Summary by CodeRabbit

  • Refactor

    • Clarified internal loop variable naming to reduce shadowing in parallel computations.
  • Performance

    • Standardized CAS-based handling for half/bfloat16 atomics and refined memory‑order paths to streamline specialized atomic operations.
  • Tests

    • Converted several JIT-wrapped checks to plain tests, expanded dtype coverage (float/float16/bfloat16), and adjusted memory‑order scenarios to improve test coverage.

@coderabbitai bot commented Nov 15, 2025

Walkthrough

Renamed loop indices in an NSA backward example to avoid shadowing, adjusted CUDA atomic selection to always use CAS/primitive paths for half and __nv_bfloat16 (regardless of memory_order), and promoted several atomic tests from JIT-decorated wrappers to plain pytest functions with expanded dtype coverage.

Changes

Cohort / File(s) Summary
Loop variable renaming
examples/deepseek_nsa/example_tilelang_nsa_bwd.py
Renamed internal loop iterators (e.g., ik, and inner i_i/_j) to avoid shadowing; algorithmic behavior and math unchanged.
CUDA atomic selection changes
src/tl_templates/cuda/atomic.h
Removed memory_order_relaxed gating for type-specialized CAS/primitive paths for half and __nv_bfloat16; AtomicMax/Min/… now use CAS-loop for those types regardless of memory_order. Non-special types still use cuda::atomic_ref with provided memory order. AtomicAdd/AtomicAddRet keep primitive fast-path for relaxed, and inline PTX/CAS for other orders.
Test promotion & adjustments
testing/python/language/test_tilelang_language_atomic_add.py
Converted several JIT-decorated test wrappers into plain test_* functions (test_atomic_add, test_atomic_max, test_atomic_min, test_atomic_load_store, test_atomic_memory_order, test_atomic_addx2), adjusted memory-order argument order in one helper, and expanded dtype coverage in test_atomic_different_memory_orders.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller
  participant atomic.h

  rect rgb(240,248,255)
  Note right of atomic.h: Old selection logic
  Caller->>atomic.h: atomic_op(type T, order M)
  atomic.h-->>atomic.h: if T is half/bf16 AND M == memory_order_relaxed
  atomic.h-->>atomic.h:   -> use primitive/CAS fast-path
  atomic.h-->>atomic.h: else
  atomic.h-->>Caller: use cuda::atomic_ref with order M
  end

  rect rgb(245,255,240)
  Note right of atomic.h: New selection logic
  Caller->>atomic.h: atomic_op(type T, order M)
  atomic.h-->>atomic.h: if T is half/bf16
  atomic.h-->>atomic.h:   -> use primitive/CAS fast-path (regardless of M)
  atomic.h-->>atomic.h: else
  atomic.h-->>Caller: use cuda::atomic_ref with order M
  end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Attention points:
    • src/tl_templates/cuda/atomic.h: verify all atomic variants (Max/Min/Add and their Ret variants) were updated consistently and that memory-order semantics are preserved for non-special types; check inline PTX sequences for ordered stores/loads.
    • testing/.../test_tilelang_language_atomic_add.py: ensure promoted tests run under CI test harness (no missing fixtures) and that expanded dtype cases are correct.
    • examples/deepseek_nsa/example_tilelang_nsa_bwd.py: confirm renamed iterators do not shadow other names or affect parallelization pragmas.

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

🐇 I hopped through loops and renamed with care,
_i became _k, inner names stripped bare.
Atomics now pick paths by type, not fate,
Tests awoke, more dtypes on their plate.
A rabbit's tidy tweak — concise and fair. ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — docstring coverage is 0.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ — check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ — the title '[BugFix] Remove memory_order in atomic constexpr and fix NSA bwd' accurately summarizes the main changes: removal of memory_order in atomic operations and fixes to the NSA backward example.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@LeiWang1999 (Member)

cc @tzj-fxz @Rachmanino

@tzj-fxz (Contributor) commented Nov 15, 2025

@KevinZeng08 Thank you for your valuable suggestion. It would be more reliable to directly inline PTX for atomic operations with different memory orders, as has already been done for vectorized dtypes. I will check and fix it later.

@KevinZeng08 (Contributor, Author)

Thanks for your kind reply. Looking forward to your implementation.

- New implementation for atomicMax and atomicMin using atomicCAS
- PTX version atomicAdd for single 16-byte data
- Modify the test cases
@tzj-fxz (Contributor) commented Nov 16, 2025

@KevinZeng08 This new PR (#1264) was created based on your modifications and suggestions. For the AtomicAdd operator, we inline PTX to handle different memory orders. For the AtomicMin/Max operators, we implement an atomicCAS loop to simulate the operation, because CUDA does not natively support a single 16-byte AtomicMin/Max. Could you take a look if you are interested? :)

@KevinZeng08 (Contributor, Author)

Sure, let me have a check

@coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
src/tl_templates/cuda/atomic.h (1)

162-189: Clarify comment and verify unused variable handling.

Line 162 states "Since atomic ref do not support memory order," but this is misleading. Looking at lines 66-67 and 220, cuda::atomic_ref does support memory_order for other types. The actual reason for inline PTX is that atomic_ref doesn't support 1- or 2-byte types (half/bf16), as noted in the PR objectives.

Additionally, ret_val is declared at line 166 but never used in the void-returning AtomicAdd function. While the PTX assembly likely doesn't generate unnecessary code, this may be confusing for readers.

Consider:

  1. Updating the comment to clarify that PTX is needed because atomic_ref doesn't support half/bf16 types
  2. Removing the unused ret_val variable or adding a comment explaining why it's present (if required by PTX syntax)
testing/python/language/test_tilelang_language_atomic_add.py (1)

275-277: Consider expanding memory order coverage.

The memory orders were changed from relaxed/acquire/release to release/relaxed/relaxed. While this exercises the new PTX paths for atomic_add with release ordering, it reduces coverage for atomic_max and atomic_min, which now only test relaxed ordering.

Consider adding additional test cases to cover non-relaxed memory orders for atomic_max and atomic_min to ensure the CAS loop works correctly across all memory ordering semantics, even though the current CAS implementation doesn't distinguish between memory orders.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2062cef and d0da460.

📒 Files selected for processing (2)
  • src/tl_templates/cuda/atomic.h (6 hunks)
  • testing/python/language/test_tilelang_language_atomic_add.py (3 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/language/test_tilelang_language_atomic_add.py
🔇 Additional comments (4)
src/tl_templates/cuda/atomic.h (1)

49-64: Verify NaN handling in float comparison within CAS loop.

The CAS loop compares floating-point values using val > *reinterpret_cast<T1 *>(&old_val_ushort). For half and bfloat16 types, this comparison may not handle NaN values correctly, as NaN comparisons always return false. This could lead to incorrect behavior when NaN values are involved in atomic max operations.

Consider the semantics for NaN handling in your atomic max operations. If IEEE 754 NaN propagation is expected, you may need to add explicit NaN checks.

testing/python/language/test_tilelang_language_atomic_add.py (3)

239-261: LGTM: Test function refactoring.

The explicit test functions improve test discoverability and remove the need for JIT-decorated wrappers. This is a cleaner approach.


263-263: LGTM: Debug path configuration.

Adding debug_root_path to the decorator enables custom debug output location for this specific test program.


364-366: Excellent dtype coverage expansion.

Testing across float, float16, and bfloat16 directly validates the new CAS-loop and PTX inline implementations for half-precision types. This comprehensive coverage is essential for the changes in this PR.

Note: bfloat16 testing requires CUDA compute capability > 7.5 (Turing or newer) as per the preprocessor guards in atomic.h lines 24-28.

@LeiWang1999 LeiWang1999 merged commit 2de566e into tile-ai:main Nov 16, 2025
7 checks passed
@KevinZeng08 KevinZeng08 deleted the fix_nsa_and_atomic branch November 16, 2025 08:30
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
…e-ai#1260)

* fix nsa bwd and atomic

* [Lint]

* [BugFix]
- New implementation for atomicMax and atomicMin using atomicCAS
- PTX version atomicAdd for single 16-byte data
- Modify the test cases

* [Lint]

---------

Co-authored-by: tzj-fxz <tzjfxz@gmail.com>
