
[BugFix] Robust atomic operations using PTX and atomicCAS loop #1264

Closed
tzj-fxz wants to merge 4 commits into tile-ai:main from tzj-fxz:pr1260

Conversation

@tzj-fxz
Contributor

@tzj-fxz tzj-fxz commented Nov 16, 2025

  • Add the missing CUDA implementations of atomicMax/atomicMin for half and bf16 using an atomicCAS loop.
  • Add a PTX-based atomicAdd for single 16-bit values.
  • Update the test script.

Reference to PR (#1260)
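
To make the approach concrete, here is a minimal, self-contained sketch of the atomicCAS-loop idea for a 16-bit max. It is illustrative only: the function name is made up, and the real code in src/tl_templates/cuda/atomic.h also covers bf16, min, and the value-returning variants.

```cuda
#include <cuda_fp16.h>

// Sketch: emulate atomicMax on half by CAS-looping over the 16-bit payload.
// atomicCAS on unsigned short requires sm_70+; half comparisons need sm_53+.
__device__ void atomic_max_half_sketch(half *address, half val) {
  unsigned short *addr_us = reinterpret_cast<unsigned short *>(address);
  unsigned short val_bits = *reinterpret_cast<unsigned short *>(&val);
  unsigned short old_bits = *addr_us;
  // Keep trying while our value is still larger than the currently stored one.
  while (val > *reinterpret_cast<half *>(&old_bits)) {
    unsigned short assumed = old_bits;
    old_bits = atomicCAS(addr_us, assumed, val_bits);
    if (old_bits == assumed) {
      break;  // our bits were installed
    }
  }
}
```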

Summary by CodeRabbit

  • New Features

    • Improved atomic operation support for half-precision and bfloat16 with broader memory-order semantics and more reliable behavior.
  • Tests

    • Expanded atomic operation test coverage to include float16 and bfloat16, added explicit test entry points for atomic ops and memory-order scenarios.

KevinZeng08 and others added 3 commits November 14, 2025 14:53
- New implementation for atomicMax and atomicMin using atomicCAS
- PTX version atomicAdd for single 16-bit data
- Modify the test cases
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Nov 16, 2025

Walkthrough

This PR renames loop indices in a backward-pass example, implements CAS- and PTX-based atomic routines for half/bf16 with memory-order handling in CUDA templates, and adds/rewrites Python tests to expand atomic-operation coverage across dtypes and memory orders.

Changes

  • Example backward pass: examples/deepseek_nsa/example_tilelang_nsa_bwd.py
    Renamed outer loop variable(s) from i to k and adjusted nested indices (_i, _j) and corresponding index usages throughout parallel loops; computation logic preserved.
  • CUDA atomic implementations: src/tl_templates/cuda/atomic.h
    Replaced relaxed-only and missing-path handling for half/__nv_bfloat16 with CAS-loop implementations for AtomicMax/Min variants and added memory-order-aware inline PTX (and CAS fallbacks) for AtomicAdd/AtomicAddRet; retained the cuda::atomic_ref fallback for other types.
  • Atomic tests (Python): testing/python/language/test_tilelang_language_atomic_add.py
    Added explicit no-arg test wrappers (test_atomic_add, test_atomic_max, test_atomic_min, test_atomic_load_store, test_atomic_memory_order, test_atomic_addx2); expanded dtype coverage to include float16 and bfloat16; adjusted memory_order arguments in atomic_different_memory_orders_program.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Kernel
    participant atomic_h as atomic.h
    participant CAS as CAS loop (16-bit)
    participant PTX as PTX inline atomics
    participant atomic_ref as cuda::atomic_ref

    Kernel->>atomic_h: call AtomicAdd/Max/Min(dtype, mem_order)
    alt dtype is half or bf16
        atomic_h->>atomic_h: check memory_order
        alt mem_order is relaxed
            atomic_h->>atomic_ref: use relaxed atomicAdd via atomicRef
            atomic_ref-->>Kernel: return old/new value
        else mem_order non-relaxed
            atomic_h->>PTX: use inline PTX for add (if available)
            PTX-->>Kernel: return value
            note right of PTX: fallback to CAS if needed
            atomic_h->>CAS: perform 16-bit CAS loop for max/min or fallback
            CAS-->>Kernel: return updated value
        end
    else
        atomic_h->>atomic_ref: use cuda::atomic_ref path
        atomic_ref-->>Kernel: return value
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Attention points:
    • src/tl_templates/cuda/atomic.h: correctness of CAS loop bitcasting, ordering semantics, and PTX inline assembly across memory orders.
    • Tests in testing/python/..._atomic_add.py: ensure memory_order permutations and added dtype coverage exercise intended branches and do not introduce flaky behavior.
    • Verify reinterpret-cast/return-value extraction for 16-bit types matches caller expectations.

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • Rachmanino

Poem

🐰 I hopped through code at break of day,
Swapped i for k and kept order sway,
CAS and PTX stitched tiny bits tight,
Tests now hop in float16 light,
A rabbit cheers — atoms take flight! ✨

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check: Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: Passed. The title clearly and specifically describes the main change: adding robust atomic operations using PTX and an atomicCAS loop, which aligns with the PR objectives of fixing the missing atomicMax/Min implementations and the PTX-based atomicAdd.
  • Docstring Coverage: Passed. No functions found in the changed files to evaluate docstring coverage; check skipped.


@tzj-fxz tzj-fxz changed the title from "[BugFix] Atomic operations" to "[BugFix] Robust atomic operations using PTX and atomicCAS loop" on Nov 16, 2025
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/deepseek_nsa/example_tilelang_nsa_bwd.py (1)

206-221: H is undefined in both backward kernels and will fail at JIT time

H is used in:

  • Line 206: print("NV", NV, "NS", NS, "B", B, "H", H)
  • Line 220: with T.Kernel(NV, NS, B * H, ...)
  • Line 239: i_b, i_h = i_bh // H, i_bh % H
  • Line 387: with T.Kernel(NV, NS, B * H, ...)
  • Line 407: i_b, i_h = i_bh // H, i_bh % H

but never defined in tilelang_kernel_bwd_dkv or tilelang_kernel_bwd_dqkv. This will raise a NameError when the JIT builds these kernels.

You likely meant to set H = heads_kv (or derive it from heads) before using it, e.g.:

-    heads_kv = heads // groups
+    heads_kv = heads // groups
+    H = heads_kv

and then rely consistently on H in the kernel launch and index unpacking.

Also applies to: 387-408

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 729e66c and f6280bb.

📒 Files selected for processing (3)
  • examples/deepseek_nsa/example_tilelang_nsa_bwd.py (3 hunks)
  • src/tl_templates/cuda/atomic.h (6 hunks)
  • testing/python/language/test_tilelang_language_atomic_add.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/deepseek_nsa/example_tilelang_nsa_bwd.py (1)
tilelang/language/loop.py (1)
  • Parallel (12-32)
testing/python/language/test_tilelang_language_atomic_add.py (1)
tilelang/language/atomic.py (3)
  • atomic_add (116-235)
  • atomic_max (22-65)
  • atomic_min (68-113)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (3)
examples/deepseek_nsa/example_tilelang_nsa_bwd.py (1)

108-112: Index renames and T.Parallel usage look correct

The switch from i to k / _i, _j in these T.Parallel loops is consistent with the fragment shapes ([G, BS], [G], [G, BV]) and avoids shadowing the outer i loop variable. Semantics of the softmax/rescale and backward steps remain unchanged.

Also applies to: 127-133, 137-138, 468-469

testing/python/language/test_tilelang_language_atomic_add.py (2)

239-261: New top-level atomic tests are straightforward and parameter choices look reasonable

The new test_atomic_add, test_atomic_max, test_atomic_min, test_atomic_load_store, test_atomic_memory_order, and test_atomic_addx2 wrappers just delegate to the existing run_* helpers with modest matrix sizes. This keeps the pytest surface clean while fully exercising the kernels without changing behavior.


263-279: Review comment is based on an incorrect premise; atomic_min code is correct

The reviewer claims atomic_min has a duplicated func_name argument in the call_extern signature:

T.call_extern(return_type, func_name, func_name, dst, value, ...)

However, the actual code shows:

T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order])

This is identical to atomic_max, atomic_add (scalar path), atomic_load, and atomic_store—all of which follow the same pattern. There is no duplicated argument. The memory-order path for atomic_min is correct and consistent with all other atomic operations in the module.

The test expansion (float/float16/bfloat16) will exercise atomic_min with memory ordering, but no pre-existing bug exists in the language layer.

Likely an incorrect or invalid review comment.

Comment on lines 49 to 66
if constexpr (std::is_same_v<NT1, half> ||
              std::is_same_v<NT1, __nv_bfloat16>) {
  // There is no implementation of atomicMax for half and bf16 in cuda.
  // We simulate this process by atomicCAS loop.
  unsigned short *address_as_ushort =
      reinterpret_cast<unsigned short *>(address);
  unsigned short val_as_ushort =
      *reinterpret_cast<unsigned short *>(&val);
  unsigned short old_val_ushort = *address_as_ushort;
  while (val > *reinterpret_cast<T1 *>(&old_val_ushort)) {
    unsigned short assumed_val_ushort = old_val_ushort;
    old_val_ushort =
        atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort);
    if (assumed_val_ushort == old_val_ushort) {
      break;
    }
  }
} else {
Contributor

🛠️ Refactor suggestion | 🟠 Major

Avoid reinterpreting arbitrary T2 as unsigned short in half/bf16 CAS loops

In the half/bf16 branches of AtomicMax/AtomicMaxRet/AtomicMin/AtomicMinRet you do:

unsigned short val_as_ushort =
    *reinterpret_cast<unsigned short *>(&val);
while (val > *reinterpret_cast<T1 *>(&old_val_ushort)) {
  ...
}

This assumes T2 is a 16‑bit type with the same layout as T1/NT1. The templates do not enforce that, so any future instantiation like AtomicMax<half>(..., float) would result in reading the low 16 bits of a float, which is both incorrect and undefined‑behavior territory.

A safer pattern is to normalize to NT1 first and only then bit‑cast:

-  unsigned short val_as_ushort =
-      *reinterpret_cast<unsigned short *>(&val);
+  NT1 val_nt = cuda_cast<NT1>(val);
+  unsigned short val_as_ushort =
+      *reinterpret_cast<unsigned short *>(&val_nt);

-  while (val > *reinterpret_cast<T1 *>(&old_val_ushort)) {
+  while (val_nt > *reinterpret_cast<NT1 *>(&old_val_ushort)) {

and similarly in the *_Ret and Min variants.

This keeps the templates robust if the call sites ever pass a different arithmetic type while still matching the current use where T1 and T2 have the same dtype.

Also applies to: 77-92, 105-122, 133-149

🤖 Prompt for AI Agents
In src/tl_templates/cuda/atomic.h around lines 49-66 (and likewise at 77-92,
105-122, 133-149) the code unsafely reinterprets the incoming template value
`val` (T2) as `unsigned short` and compares by reinterpreting old ushort bits
directly to T1; instead, first convert/normalize `val` to the native 16-bit CUDA
type NT1 (e.g., half or __nv_bfloat16) and only then bit-cast that NT1 value to
`unsigned short` for the CAS loop, and when comparing use a cast of the old
ushort to NT1 for the numeric comparison; update all four branches
(Max/MaxRet/Min/MinRet) to follow this pattern so the templates remain correct
if T2 differs from T1/NT1.
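
For the *_Ret variants mentioned above, here is a hedged sketch of how the suggested normalization extends to returning the previous value. It assumes NT1 is half, uses static_cast where the project uses its cuda_cast helper, and the function name is hypothetical.

```cuda
#include <cuda_fp16.h>

// Sketch: CAS-based atomic max that returns the old value, normalizing the
// incoming T2 to half before bit-casting and comparing (avoids reinterpreting
// arbitrary T2 bits as a 16-bit payload).
template <typename T2>
__device__ half atomic_max_ret_half_sketch(half *address, T2 val) {
  half val_nt = static_cast<half>(val);  // normalize T2 -> half first
  unsigned short *addr_us = reinterpret_cast<unsigned short *>(address);
  unsigned short val_bits = *reinterpret_cast<unsigned short *>(&val_nt);
  unsigned short old_bits = *addr_us;
  while (val_nt > *reinterpret_cast<half *>(&old_bits)) {
    unsigned short assumed = old_bits;
    old_bits = atomicCAS(addr_us, assumed, val_bits);
    if (old_bits == assumed) {
      break;
    }
  }
  // Extract the previous value through the native 16-bit type.
  return *reinterpret_cast<half *>(&old_bits);
}
```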

Comment on lines 161 to 222
if constexpr (std::is_same_v<NT1, half> ||
              std::is_same_v<NT1, __nv_bfloat16>) {
  if (memory_order == int(cuda::memory_order_relaxed)) {
    atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
  } else {
    // Since atomic ref do not support memory order, we need to inline ptx code here for each situation
    if constexpr (std::is_same_v<NT1, half>) {
      // fp16
      __half ret_val;
      unsigned short ret_val_cast = *reinterpret_cast<unsigned short *>(&ret_val);
      unsigned long long ref_address = reinterpret_cast<unsigned long long>(address);
      unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
      if (memory_order == int(cuda::memory_order_release) ||
          memory_order == int(cuda::memory_order_consume)) {
        asm volatile(
            "atom.release.gpu.global.add.noftz.f16 %0, [%1], %2;"
            : "=h"(ret_val_cast)
            : "l"(ref_address), "h"(val_cast)
            : "memory");
      } else if (memory_order == int(cuda::memory_order_acquire)) {
        asm volatile(
            "atom.acquire.gpu.global.add.noftz.f16 %0, [%1], %2;"
            : "=h"(ret_val_cast)
            : "l"(ref_address), "h"(val_cast)
            : "memory");
      } else if (memory_order == int(cuda::memory_order_acq_rel) ||
                 memory_order == int(cuda::memory_order_seq_cst)) {
        asm volatile(
            "atom.acq_rel.gpu.global.add.noftz.f16 %0, [%1], %2;"
            : "=h"(ret_val_cast)
            : "l"(ref_address), "h"(val_cast)
            : "memory");
      }
    } else if constexpr (std::is_same_v<NT1, __nv_bfloat16>) {
      // bf16
      __nv_bfloat16 ret_val;
      unsigned short ret_val_cast = *reinterpret_cast<unsigned short *>(&ret_val);
      unsigned long long ref_address = reinterpret_cast<unsigned long long>(address);
      unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
      if (memory_order == int(cuda::memory_order_release) ||
          memory_order == int(cuda::memory_order_consume)) {
        asm volatile(
            "atom.release.gpu.global.add.noftz.bf16 %0, [%1], %2;"
            : "=h"(ret_val_cast)
            : "l"(ref_address), "h"(val_cast)
            : "memory");
      } else if (memory_order == int(cuda::memory_order_acquire)) {
        asm volatile(
            "atom.acquire.gpu.global.add.noftz.bf16 %0, [%1], %2;"
            : "=h"(ret_val_cast)
            : "l"(ref_address), "h"(val_cast)
            : "memory");
      } else if (memory_order == int(cuda::memory_order_acq_rel) ||
                 memory_order == int(cuda::memory_order_seq_cst)) {
        asm volatile(
            "atom.acq_rel.gpu.global.add.noftz.bf16 %0, [%1], %2;"
            : "=h"(ret_val_cast)
            : "l"(ref_address), "h"(val_cast)
            : "memory");
      }
    }
  }
Contributor

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Half/bf16 AtomicAdd PTX path: semantics are reasonable, but please verify PTX mnemonics and document memory-order behavior

The new half/bf16 AtomicAdd/AtomicAddRet logic:

  • Uses native atomicAdd for memory_order_relaxed.
  • For non‑relaxed orders, emits inline PTX (atom.*.gpu.global.add.noftz.f16 / .bf16) to implement release/consume, acquire, and acq_rel/seq_cst variants.

Two points worth tightening up:

  1. PTX opcode/operand verification

    The instruction strings and constraints ("l" for the 64‑bit address, "h" for 16‑bit values) look consistent, but they are tightly coupled to the targeted PTX version and SM architecture. It would be good to double‑check them against the CUDA/PTX docs you target (and maybe add a brief comment with the minimum SM/PTX version).

  2. Memory-order consistency across operations

    • AtomicAdd/AtomicAddRet for half/bf16 now respect the memory_order argument (via the PTX variants).
    • AtomicMax/AtomicMin for half/bf16 use an atomicCAS loop and ignore memory_order. That effectively gives you a strong (seq_cst‑like) ordering even when the caller requests a weaker order.

    That’s functionally safe but slightly asymmetric with the float path and the new Add behavior. It might be worth documenting this (or later extending Max/Min to use the PTX memory‑order variants as well if you need the weaker semantics for performance).

Also applies to: 234-297


🌐 Web query:

Check the current PTX ISA documentation for the correct syntax and availability of `atom.release.gpu.global.add.noftz.f16` and `.bf16` (and their acquire/acq_rel variants) on Hopper/Lovelace-class GPUs.

💡 Result:

Short answer

  • The PTX grammar supports the full form atom{.sem}{.scope}{.space}.add.noftz. (e.g. atom.release.gpu.global.add.noftz.f16) and the .sem qualifiers .acquire, .release and .acq_rel. Syntax is the red/atom form: atom{.sem}{.scope}{.global}.add.noftz.f16 / .bf16 (and vector variants) with .scope ∈ { .cta, .gpu, .sys } and .sem ∈ { .relaxed, .acquire, .release, .acq_rel }. (docs.nvidia.com)

Availability on Hopper vs Lovelace

  • f16 (atom.add.noftz.f16): supported in PTX and enabled on modern targets (requires earlier SM levels for .sem support); .sem qualifiers are available on SM70+. Thus atom.release.gpu.global.add.noftz.f16 (and its .acquire/.acq_rel variants) are valid and usable on both Hopper (sm_90) and Lovelace (sm_89). (docs.nvidia.com)
  • bf16 (atom.add.noftz.bf16): PTX introduced atom.add.noftz.bf16 / bf16x2 in PTX 7.8 and the target notes require SM90 for bf16 atom support. That means bf16 atom instructions (and their .sem/.scope variants) are available on Hopper (Compute Capability 9.0 / sm_90) but are not guaranteed on Ada/Lovelace class GPUs that report sm_89 (they lack the sm_90 bf16 atom requirement). (docs.nvidia.com)

Practical takeaway

  • Use atom.release.gpu.global.add.noftz.f16 (and .acquire/.acq_rel) on both Lovelace and Hopper.
  • Use atom.release.gpu.global.add.noftz.bf16 (and .acquire/.acq_rel) only when targeting sm_90/Hopper (or later); do not expect it on sm_89/Lovelace. (docs.nvidia.com)

If you want, I can quote the exact PTX grammar lines and point to the specific PTX manual page versions (PTX 8.x / PTX 7.8) you should target.


I'll now verify the codebase to determine if bf16 atomic operations are properly guarded for SM capability, since the web search revealed a critical availability constraint.

Now let me inspect the codebase to check for any SM capability guards:

Based on the web search confirmation that bf16 global atomic adds are supported on Hopper, but for non-Hopper targets an atomic compare-and-swap pattern is needed as fallback, let me inspect the actual code file to verify if these guards are in place:

Now, generate the final rewritten review comment:


Critical: bf16 atomic operations require SM90 guard; not guaranteed on Lovelace (sm_89)

BF16 global atomic adds are supported on Hopper, but require fallback handling (atomic CAS loop) for non-Hopper targets. The current code at lines 161–222 and 234–297 uses inline PTX atom.add.noftz.bf16 instructions unconditionally, which will fail or produce incorrect results on Ada Lovelace GPUs (sm_89).

Required fix:

  • Wrap bf16 atomic operations in #if __CUDA_ARCH__ >= 900 guards (or equivalent SM version check).
  • Provide a fallback implementation using atomicCAS loop for sm_89 and earlier, or document that bf16 atomics are Hopper-only.

PTX opcode verification: The syntax (atom.release/acquire/acq_rel.gpu.global.add.noftz.f16/.bf16) is correct per PTX ISA docs, and constraint strings ("l" for 64-bit address, "h" for 16-bit values) are sound. However, f16 atomics with memory-order support are universally available (SM70+), while bf16 atomics are SM90-exclusive.


🏁 Script executed:

#!/bin/bash
# Check for __CUDA_ARCH__ or SM capability guards around bf16 atomic code
head -n 300 src/tl_templates/cuda/atomic.h | tail -n 150

Length of output: 6794


Critical: bf16 atomic PTX instructions require SM90 guard; will fail on Lovelace (sm_89)

BF16 atomic operations with memory-order support are only available on Hopper (sm_90) and later, but the code at lines 161–222 (AtomicAdd/AtomicAddRet) uses atom.release/acquire/acq_rel.gpu.global.add.noftz.bf16 without SM capability guards. The code will fail at runtime or produce incorrect PTX assembly when targeting Ada/Lovelace GPUs (sm_89) with non-relaxed memory orders.

Required fixes:

  1. Lines 161–222: Wrap the bf16 non-relaxed memory-order branches (the inline PTX atom.* instructions) in #if __CUDA_ARCH__ >= 900 guards, or add a fallback using atomicCAS-loop for sm_89.

  2. Lines 234–297: Apply the same SM90 guard to bf16 atomic code in AtomicMax/AtomicMin (mentioned in "also applies to").

Note: F16 atomics with memory-order support are available on SM70+, so the atom.*.f16 code paths are safe across all targeted architectures and require no guards.

🤖 Prompt for AI Agents
In src/tl_templates/cuda/atomic.h around lines 161–222, the bf16 inline PTX
atom.*.bf16 instructions used for non-relaxed memory orders are emitted
unconditionally and will fail on architectures older than SM90 (e.g.
Ada/Lovelace sm_89); wrap those bf16 non-relaxed memory-order branches in a CUDA
arch guard (#if __CUDA_ARCH__ >= 900) or replace them with an atomicCAS-based
fallback loop for <=899; apply the same fix to the bf16 atomic code paths
referenced around lines 234–297 (AtomicMax/AtomicMin) so bf16 memory-ordered
atomics are only emitted on SM90+ or use the CAS fallback on older SMs.
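
A minimal sketch of the guard-plus-fallback shape described above, assuming the fallback simply degrades to a relaxed native atomicAdd (which itself needs sm_80+ for bf16); the function name is illustrative and not the actual atomic.h code.

```cuda
#include <cuda_bf16.h>

// Sketch: memory-ordered bf16 PTX atomics only on SM90+, relaxed fallback
// otherwise. Targets older than sm_80 would need a CAS loop instead.
__device__ void atomic_add_bf16_release_sketch(__nv_bfloat16 *address,
                                               __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
  unsigned short ret_val_cast = 0;
  unsigned long long ref_address =
      reinterpret_cast<unsigned long long>(address);
  asm volatile("atom.release.gpu.global.add.noftz.bf16 %0, [%1], %2;"
               : "=h"(ret_val_cast)
               : "l"(ref_address), "h"(val_cast)
               : "memory");
#else
  // No memory-ordered bf16 atom.* before SM90; degrade to relaxed semantics.
  atomicAdd(address, val);
#endif
}
```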

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
src/tl_templates/cuda/atomic.h (2)

49-64: Fix T2 bit-casting and compare in half/bf16 CAS-based Max/Min

The half/bf16 CAS loops still:

  • Bit-cast arbitrary T2 val to unsigned short (reinterpret_cast<unsigned short *>(&val)), which is only valid if T2 is a 16‑bit type with the same layout as NT1. Any instantiation like AtomicMax<half_t, float> is undefined and numerically wrong.
  • Compare val (T2) against *reinterpret_cast<T1 *>(&old_val_ushort), mixing types and not using the normalized 16‑bit CUDA type NT1. This can make Max/Min semantics differ from the relaxed path that goes through cuda_cast<NT1>.

You can keep the CAS approach but normalize through NT1 first and compare in NT1:

-  if constexpr (std::is_same_v<NT1, half> ||
-                std::is_same_v<NT1, __nv_bfloat16>) {
-    // There is no implementation of atomicMax for half and bf16 in cuda.
-    // We simulate this process by atomicCAS loop.
-    unsigned short *address_as_ushort =
-        reinterpret_cast<unsigned short *>(address);
-    unsigned short val_as_ushort = *reinterpret_cast<unsigned short *>(&val);
-    unsigned short old_val_ushort = *address_as_ushort;
-    while (val > *reinterpret_cast<T1 *>(&old_val_ushort)) {
+  if constexpr (std::is_same_v<NT1, half> ||
+                std::is_same_v<NT1, __nv_bfloat16>) {
+    NT1 val_nt = cuda_cast<NT1>(val);
+    unsigned short *address_as_ushort =
+        reinterpret_cast<unsigned short *>(address);
+    unsigned short val_as_ushort =
+        *reinterpret_cast<unsigned short *>(&val_nt);
+    unsigned short old_val_ushort = *address_as_ushort;
+    while (val_nt > *reinterpret_cast<NT1 *>(&old_val_ushort)) {
       unsigned short assumed_val_ushort = old_val_ushort;
       old_val_ushort =
           atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort);
       if (assumed_val_ushort == old_val_ushort) {
         break;
       }
     }
   }

and in the *_Ret variants return via the normalized type:

-    return static_cast<T1>(*reinterpret_cast<T1 *>(&old_val_ushort));
+    return cuda_cast<T1>(*reinterpret_cast<NT1 *>(&old_val_ushort));

Apply the same pattern to AtomicMaxRet, AtomicMin, and AtomicMinRet to keep all four half/bf16 CAS loops consistent.

Also applies to: 76-90, 103-118, 130-144


190-215: Guard bf16 atom.*.bf16 PTX paths for SM90+ and add a fallback

The bf16 non‑relaxed branches in AtomicAdd/AtomicAddRet:

asm volatile("atom.release.gpu.global.add.noftz.bf16 ...");
...
asm volatile("atom.acquire.gpu.global.add.noftz.bf16 ...");
...
asm volatile("atom.acq_rel.gpu.global.add.noftz.bf16 ...");

are emitted unconditionally whenever NT1 == __nv_bfloat16. PTX bf16 global atomics are only guaranteed on SM90+; targeting e.g. Ada/Lovelace (sm_89) will fail at assembly or at runtime.

You should:

  • Guard these bf16 PTX code paths with an SM90‑level check, and
  • Provide a fallback (e.g. relaxed atomicAdd on __nv_bfloat16) on older architectures.

For example:

-      } else if constexpr (std::is_same_v<NT1, __nv_bfloat16>) {
-        // bf16
-        ...
-        if (memory_order == int(cuda::memory_order_release) ||
-            memory_order == int(cuda::memory_order_consume)) {
-          asm volatile("atom.release.gpu.global.add.noftz.bf16 %0, [%1], %2;"
-                       : "=h"(ret_val_cast)
-                       : "l"(ref_address), "h"(val_cast)
-                       : "memory");
-        } else if (memory_order == int(cuda::memory_order_acquire)) {
-          ...
-        } else if (memory_order == int(cuda::memory_order_acq_rel) ||
-                   memory_order == int(cuda::memory_order_seq_cst)) {
-          ...
-        }
-      }
+      } else if constexpr (std::is_same_v<NT1, __nv_bfloat16>) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+        // bf16 atomics with memory-order support (SM90+)
+        ...
+        if (memory_order == int(cuda::memory_order_release) ||
+            memory_order == int(cuda::memory_order_consume)) {
+          asm volatile("atom.release.gpu.global.add.noftz.bf16 %0, [%1], %2;"
+                       : "=h"(ret_val_cast)
+                       : "l"(ref_address), "h"(val_cast)
+                       : "memory");
+        } else if (memory_order == int(cuda::memory_order_acquire)) {
+          ...
+        } else if (memory_order == int(cuda::memory_order_acq_rel) ||
+                   memory_order == int(cuda::memory_order_seq_cst)) {
+          ...
+        }
+#else
+        // Fallback: no bf16 memory-ordered atomics before SM90; degrade to relaxed
+        atomicAdd(reinterpret_cast<NT1 *>(address), val_nt);
+#endif
+      }

Apply the same guard/fallback pattern to the bf16 block in AtomicAddRet, and consider mirroring it in the vectorized bf16 AtomicAddx2/AtomicAddx2Ret further down, which also use atom.*.v2.bf16.

Also applies to: 262-289

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f6280bb and d0da460.

📒 Files selected for processing (1)
  • src/tl_templates/cuda/atomic.h (6 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/tl_templates/cuda/atomic.h
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • src/tl_templates/cuda/atomic.h
🧬 Code graph analysis (1)
src/tl_templates/cuda/atomic.h (1)
src/tl_templates/cpp/half.hpp (22)
  • half (2476-2479)
  • half (2476-2476)
  • half (2557-2557)
  • half (2562-2562)
  • half (3001-3001)
  • half (3006-3008)
  • half (3235-3237)
  • half (3243-3243)
  • half (3417-3425)
  • half (3432-3440)
  • half (5249-5251)
  • int (760-765)
  • int (772-779)
  • int (787-797)
  • int (804-811)
  • int (816-821)
  • int (827-832)
  • int (838-843)
  • int (855-864)
  • int (872-879)
  • int (892-914)
  • int (5267-5273)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)

Comment on lines +157 to +217
if constexpr (std::is_same_v<NT1, half> ||
              std::is_same_v<NT1, __nv_bfloat16>) {
  if (memory_order == int(cuda::memory_order_relaxed)) {
    atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
  } else {
    // Since atomic ref do not support memory order, we need to inline ptx
    // code here for each situation
    if constexpr (std::is_same_v<NT1, half>) {
      // fp16
      __half ret_val;
      unsigned short ret_val_cast =
          *reinterpret_cast<unsigned short *>(&ret_val);
      unsigned long long ref_address =
          reinterpret_cast<unsigned long long>(address);
      unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
      if (memory_order == int(cuda::memory_order_release) ||
          memory_order == int(cuda::memory_order_consume)) {
        asm volatile("atom.release.gpu.global.add.noftz.f16 %0, [%1], %2;"
                     : "=h"(ret_val_cast)
                     : "l"(ref_address), "h"(val_cast)
                     : "memory");
      } else if (memory_order == int(cuda::memory_order_acquire)) {
        asm volatile("atom.acquire.gpu.global.add.noftz.f16 %0, [%1], %2;"
                     : "=h"(ret_val_cast)
                     : "l"(ref_address), "h"(val_cast)
                     : "memory");
      } else if (memory_order == int(cuda::memory_order_acq_rel) ||
                 memory_order == int(cuda::memory_order_seq_cst)) {
        asm volatile("atom.acq_rel.gpu.global.add.noftz.f16 %0, [%1], %2;"
                     : "=h"(ret_val_cast)
                     : "l"(ref_address), "h"(val_cast)
                     : "memory");
      }
    } else if constexpr (std::is_same_v<NT1, __nv_bfloat16>) {
      // bf16
      __nv_bfloat16 ret_val;
      unsigned short ret_val_cast =
          *reinterpret_cast<unsigned short *>(&ret_val);
      unsigned long long ref_address =
          reinterpret_cast<unsigned long long>(address);
      unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
      if (memory_order == int(cuda::memory_order_release) ||
          memory_order == int(cuda::memory_order_consume)) {
        asm volatile("atom.release.gpu.global.add.noftz.bf16 %0, [%1], %2;"
                     : "=h"(ret_val_cast)
                     : "l"(ref_address), "h"(val_cast)
                     : "memory");
      } else if (memory_order == int(cuda::memory_order_acquire)) {
        asm volatile("atom.acquire.gpu.global.add.noftz.bf16 %0, [%1], %2;"
                     : "=h"(ret_val_cast)
                     : "l"(ref_address), "h"(val_cast)
                     : "memory");
      } else if (memory_order == int(cuda::memory_order_acq_rel) ||
                 memory_order == int(cuda::memory_order_seq_cst)) {
        asm volatile("atom.acq_rel.gpu.global.add.noftz.bf16 %0, [%1], %2;"
                     : "=h"(ret_val_cast)
                     : "l"(ref_address), "h"(val_cast)
                     : "memory");
      }
    }
  }
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

Normalize to NT1 before bit-casting in half/bf16 AtomicAdd / AtomicAddRet

In the half/bf16 branches of AtomicAdd/AtomicAddRet:

  • The relaxed path correctly does atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));.
  • The non‑relaxed PTX paths instead do unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);, reinterpreting arbitrary T2 as a 16‑bit payload.

For T2 != NT1 (e.g. T2=float), relaxed vs non‑relaxed orders will produce different numerical results and the PTX path is UB.

Normalize through NT1 once and reuse it:

-  if constexpr (std::is_same_v<NT1, half> ||
-                std::is_same_v<NT1, __nv_bfloat16>) {
-    if (memory_order == int(cuda::memory_order_relaxed)) {
-      atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
+  if constexpr (std::is_same_v<NT1, half> ||
+                std::is_same_v<NT1, __nv_bfloat16>) {
+    NT1 val_nt = cuda_cast<NT1>(val);
+    if (memory_order == int(cuda::memory_order_relaxed)) {
+      atomicAdd(reinterpret_cast<NT1 *>(address), val_nt);
     } else {
       ...
-        unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
+        unsigned short val_cast =
+            *reinterpret_cast<unsigned short *>(&val_nt);
       ...
-      } else if constexpr (std::is_same_v<NT1, __nv_bfloat16>) {
+      } else if constexpr (std::is_same_v<NT1, __nv_bfloat16>) {
         ...
-        unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
+        unsigned short val_cast =
+            *reinterpret_cast<unsigned short *>(&val_nt);
         ...
       }
     }
   }

and similarly in AtomicAddRet:

-  if constexpr (std::is_same_v<NT1, half> ||
-                std::is_same_v<NT1, __nv_bfloat16>) {
-    if (memory_order == int(cuda::memory_order_relaxed)) {
-      return static_cast<T1>(
-          atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
+  if constexpr (std::is_same_v<NT1, half> ||
+                std::is_same_v<NT1, __nv_bfloat16>) {
+    NT1 val_nt = cuda_cast<NT1>(val);
+    if (memory_order == int(cuda::memory_order_relaxed)) {
+      return static_cast<T1>(
+          atomicAdd(reinterpret_cast<NT1 *>(address), val_nt));
     } else {
       ...
-        unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
+        unsigned short val_cast =
+            *reinterpret_cast<unsigned short *>(&val_nt);
       ...
-      } else if constexpr (std::is_same_v<NT1, __nv_bfloat16>) {
+      } else if constexpr (std::is_same_v<NT1, __nv_bfloat16>) {
         ...
-        unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
+        unsigned short val_cast =
+            *reinterpret_cast<unsigned short *>(&val_nt);
         ...
       }
     }
   }

This keeps all memory-order variants using the same numeric value and avoids UB from reinterpreting T2.

Also applies to: 229-291


I'll verify the review comment by inspecting the actual code and checking key claims.


---

🏁 Script executed:

#!/bin/bash
# Locate the atomic.h file
fd -t f "atomic\.h" --path "*cuda*"

Length of output: 233


🏁 Script executed:

#!/bin/bash
# Inspect the exact code at lines 157-217 in atomic.h
find . -name "atomic.h" -path "*cuda*" | head -1 | xargs -I {} sed -n '157,217p' {}

Length of output: 3101


🏁 Script executed:

#!/bin/bash
# Search for cuda_cast function definition
rg -n "cuda_cast" --type cpp

Length of output: 1093


🏁 Script executed:

#!/bin/bash
# Inspect lines 229-291 as well
find . -name "atomic.h" -path "*cuda*" | head -1 | xargs -I {} sed -n '229,291p' {}

Length of output: 3193


🏁 Script executed:

#!/bin/bash
# Get context around the AtomicAdd and AtomicAddRet function signatures
find . -name "atomic.h" -path "*cuda*" | head -1 | xargs -I {} sed -n '140,160p' {}

Length of output: 810


Normalize to NT1 before bit-casting in half/bf16 AtomicAdd / AtomicAddRet

In the half/bf16 branches of AtomicAdd/AtomicAddRet:

  • The relaxed path correctly does static_cast<NT1>(val) to normalize the value.
  • The non‑relaxed PTX paths instead do unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val), reinterpreting arbitrary T2 directly as 16‑bit bytes.

For T2 ≠ NT1 (e.g., T2=float, NT1=half), this causes:

  • Relaxed path: static_cast<half>(float) correctly converts the value
  • PTX paths: reinterprets raw float bytes as unsigned short, producing UB and inconsistent results

Normalize through NT1 once and reuse across all memory orders. Use cuda_cast<NT1>(val) to align with the pattern already applied elsewhere in the file (lines 67, 94, 121, 148, 220, 295):

   if constexpr (std::is_same_v<NT1, half> ||
                 std::is_same_v<NT1, __nv_bfloat16>) {
+    NT1 val_nt = cuda_cast<NT1>(val);
     if (memory_order == int(cuda::memory_order_relaxed)) {
-      atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
+      atomicAdd(reinterpret_cast<NT1 *>(address), val_nt);
       ...
-        unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
+        unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val_nt);

Apply the same fix to both AtomicAdd (lines 157–217) and AtomicAddRet (lines 229–291), in both the half and __nv_bfloat16 branches.

🤖 Prompt for AI Agents
In src/tl_templates/cuda/atomic.h around lines 157–217 (and likewise apply to
AtomicAddRet around 229–291), the PTX non‑relaxed paths are reinterpreting the
caller T2 value bytes directly as a 16‑bit half/bf16, which is UB when T2 !=
NT1; instead, normalize the incoming value to NT1 first (use cuda_cast<NT1>(val)
as used elsewhere) and then bit‑cast that normalized NT1 to unsigned short for
ret_val_cast and val_cast before emitting PTX. Replace the direct
reinterpret_cast of &val with a reinterpret_cast of the normalized
cuda_cast<NT1>(val) (store into a temporary NT1 variable), and reuse that
normalized value across all memory_order branches in both half and __nv_bfloat16
sections.
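
As a concrete illustration of the fix described in the prompt above, here is a hedged sketch of the release-order f16 branch with the value normalized first. The function name and the float-typed parameter are illustrative, static_cast stands in for the project's cuda_cast helper, and the .release qualifier needs sm_70+.

```cuda
#include <cuda_fp16.h>

// Sketch: normalize the caller's value to half before bit-casting it for the
// inline-PTX path, so relaxed and ordered paths add the same numeric value.
__device__ void atomic_add_half_release_sketch(half *address, float val) {
  half val_nt = static_cast<half>(val);  // T2 (float) -> NT1 (half) first
  unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val_nt);
  unsigned short ret_val_cast = 0;
  unsigned long long ref_address =
      reinterpret_cast<unsigned long long>(address);
  asm volatile("atom.release.gpu.global.add.noftz.f16 %0, [%1], %2;"
               : "=h"(ret_val_cast)
               : "l"(ref_address), "h"(val_cast)
               : "memory");
}
```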

@tzj-fxz tzj-fxz closed this Nov 16, 2025