@LeiWang1999 (Member) commented on Dec 13, 2025

This pull request updates the atomic add infrastructure to consistently pass destination arguments by pointer rather than by reference or value. This change ensures correct semantics and compatibility across CUDA, HIP, and the internal IR, and unifies the calling conventions in both C++ and Python code.

Device-side atomic add API changes:

  • Updated all device-side AtomicAdd and AtomicAddRet functions in the CUDA and HIP templates to take a pointer (T1 *address) instead of a reference (T1 &ref), and updated all internal calls accordingly.

IR and transformation logic updates:

  • Modified the construction of atomic add calls in AtomicAddNode::MakeSIMTLoop and AtomicAddVectorizeRewriter to use address_of on the destination, ensuring the first argument is always a pointer.

Python frontend alignment:

  • Updated the Python frontend (tilelang/language/atomic.py) to always pass the destination as a pointer using T.address_of(dst) in calls to AtomicAdd and AtomicAddRet, matching the new device signature.
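The argument layout described above can be sketched in plain Python. This is a toy model only: `BufferLoad`, `AddressOf`, and `make_atomic_add_args` are hypothetical stand-ins invented for illustration, not the TVM or TileLang API, and the integer memory-order code is likewise an assumption.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class BufferLoad:
    """Toy stand-in for an IR node that reads buffer[indices]."""
    buffer: str
    indices: Tuple[int, ...]


@dataclass(frozen=True)
class AddressOf:
    """Toy stand-in for address_of applied to a buffer element."""
    load: BufferLoad


def make_atomic_add_args(dst: BufferLoad, value, memory_order: Optional[int] = None):
    """Build the extern-call arguments: name, dst pointer, value, then optional memory order."""
    args = ["AtomicAdd", AddressOf(dst), value]
    if memory_order is not None:
        # Hypothetical integer code for the memory order; appended last.
        args.append(memory_order)
    return tuple(args)


# Destination is wrapped as an address, never passed as a loaded value.
args = make_atomic_add_args(BufferLoad("A", (3,)), 1.0)
```

The point of the sketch is only the ordering: the destination address comes first, the value second, and the memory order (when present) last.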

Summary by CodeRabbit

  • Refactor
    • Atomic add operations now use memory addresses (pointers) for the destination parameter across CUDA, HIP, language, and vectorization paths; argument ordering for atomic calls updated to pass the address first.
  • Tests
    • Updated memory-access legalization test to exercise the address-based atomic call.

…o destination

* Modified AtomicAdd in CUDA to take a pointer instead of a reference for the destination argument.
* Updated related code in atomicadd_vectorize.cc to ensure compatibility with the new signature.
* Adjusted Python interface in atomic.py to pass the destination by pointer, aligning with device function requirements.
* Updated AtomicAddRet in both CUDA and HIP to take a pointer instead of a reference for the address argument, improving consistency with the AtomicAdd function.
* Adjusted the implementation to ensure proper reinterpretation of the address type for atomic operations.
… pointer

* Updated the MakeSIMTLoop function to build a pointer to the destination element using tvm_access_ptr instead of loading the destination value directly.
* Simplified the handling of source and destination predicates, improving clarity and maintainability of the code.
* Ensured compatibility with the new pointer-based approach for atomic operations.
@coderabbitai bot (Contributor) commented on Dec 13, 2025

Walkthrough

Change atomic-add call sites and device templates to pass destination addresses (via address_of) and update atomic function signatures to accept pointer parameters (T*) instead of references (T&), adjusting argument ordering for extern atomic calls.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Device atomic signature updates: src/tl_templates/cuda/atomic.h, src/tl_templates/hip/common.h | Updated AtomicAdd / AtomicAddRet signatures to take T1 *address instead of T1 &ref; revised internal casts and address usage to operate on the provided pointer parameter. |
| Extern atomic call construction: src/op/atomic_add.cc | Reworked construction of extern atomic add arguments: compute dst_ptr via address_of and push dst_ptr, then src_value, then memory_order; removed conditional masking of src/dst values. |
| Vectorize transform adjustments: src/transform/atomicadd_vectorize.cc | Scalar/vectorization paths updated to supply the destination as an address: push address_of_dst for scalar lanes; in the fallback, convert a first-arg BufferLoad to address_of(...) or forward the existing address for a CallNode, preserving the remaining args. |
| TileLang callsites: tilelang/language/atomic.py, testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py | Changed simple atomic_add extern calls to pass T.address_of(dst) (pointer) in both memory-order and non-memory-order variants; updated the test to call extern AtomicAdd with address_of(...) instead of a value. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Special attention areas:
    • Consistent application of address_of at all call sites (op/atomic_add.cc, atomicadd_vectorize.cc, language/atomic.py, tests)
    • Correct argument ordering for extern atomic calls (dst_ptr, src_value, memory_order)
    • Pointer vs. reference semantics in CUDA/HIP templates and proper reinterpret_cast usage
    • Ensure removal of masking logic doesn't change semantics for out-of-bounds or predicated paths

Suggested reviewers

  • Rachmanino

Poem

🐇 I hopped from ref to pointer bright,
Swapped values for addresses overnight,
I nudged args in tidy rows,
Now atomics add where pointer goes,
A tiny hop, a confident bite.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00%, which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |
✅ Passed checks (2 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately captures the main change: converting atomic add destination arguments from references to pointers across the codebase. |

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0e60cd6 and 1c1a551.

📒 Files selected for processing (1)
  • testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py (2)
src/tl_templates/cuda/reduce.h (1)
  • T (178-250)
tilelang/language/tir/op.py (1)
  • call_extern (173-195)
🔇 Additional comments (1)
testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py (1)

102-104: AtomicAdd test expectation updated correctly (dst passed by pointer).

Switching to T.address_of(A[...]) on Line 104 matches the project's AtomicAdd convention. The extern signature (confirmed in atomicadd_vectorize.cc and atomic operation templates) expects (pointer, value) as arguments, and "handle" is the correct return dtype for the non-returning AtomicAdd variant (as opposed to AtomicAddRet which returns the previous value). The bounds guards on Lines 102–103 are properly preserved.



@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (4)
src/op/atomic_add.cc (1)

281-288: Nit: comment says tvm_access_ptr but you’re using builtin::address_of().

The code comment (Line 281) references tvm_access_ptr, but the implementation uses builtin::address_of() (Lines 282-284). Either update the comment or, if downstream code requires it, switch to the intended builtin.

src/tl_templates/hip/common.h (1)

119-122: HIP AtomicAddRet pointer migration looks consistent with the new call convention.

Only small nit: address is already T1*, so reinterpret_cast<T1 *>(address) is redundant unless you expect “pointer-as-integer” call patterns.

src/transform/atomicadd_vectorize.cc (1)

261-284: Fallback path may still pass a non-address first arg (e.g., IfThenElse) under some producers.

The fallback only wraps BufferLoad into address_of(...). If node->args[0] can be an IfThenElse (or other wrappers that evaluate to a BufferLoad/value), the new pointer-first contract could be violated.

If those forms are still possible, consider handling IfThenElseNode by producing an IfThenElse(cond, address_of(then_load), address_of(else_load)) (when both branches are BufferLoad) or forcing earlier canonicalization.
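The suggested handling can be sketched as a small recursive rewrite over a toy expression tree. Everything here is hypothetical: `BufferLoad`, `AddressOf`, `IfThenElse`, and `to_address` are plain-Python stand-ins for the corresponding IR concepts, not the actual C++ rewriter in atomicadd_vectorize.cc.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class BufferLoad:
    """Toy stand-in for an IR node that reads buffer[indices]."""
    buffer: str
    indices: Tuple[int, ...]


@dataclass(frozen=True)
class AddressOf:
    """Toy stand-in for address_of applied to a buffer element."""
    load: BufferLoad


@dataclass(frozen=True)
class IfThenElse:
    """Toy stand-in for a select-style expression with two branches."""
    cond: str
    then_case: object
    else_case: object


def to_address(expr):
    """Rewrite the first atomic argument so that it always evaluates to an address."""
    if isinstance(expr, AddressOf):
        return expr  # already a pointer: forward unchanged
    if isinstance(expr, BufferLoad):
        return AddressOf(expr)  # the case the fallback already handles
    if isinstance(expr, IfThenElse):
        # Push address_of into both branches; fails below if either is not a load.
        return IfThenElse(expr.cond, to_address(expr.then_case), to_address(expr.else_case))
    raise TypeError(f"cannot take the address of {type(expr).__name__}")
```

Raising on anything else is a deliberate choice in this sketch: a branch that is a constant (for example a masked-out zero) has no address, so silently forwarding it would reintroduce the contract violation the comment describes.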

src/tl_templates/cuda/atomic.h (1)

172-250: AtomicAdd pointer migration is good; consider removing uninitialized reads before inline-asm outputs.

In the half/bf16 inline-PTX paths, ret_val is declared and then used to initialize ret_val_cast via a read (e.g., *reinterpret_cast<unsigned short*>(&ret_val)). That’s an uninitialized read in C++ even if the asm output overwrites ret_val_cast. Prefer initializing the scalar directly:

-        __half ret_val;
-        unsigned short ret_val_cast =
-            *reinterpret_cast<unsigned short *>(&ret_val);
+        unsigned short ret_val_cast = 0;

(and similarly for bf16).

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 2905143 and 0e60cd6.

📒 Files selected for processing (5)
  • src/op/atomic_add.cc (1 hunks)
  • src/tl_templates/cuda/atomic.h (2 hunks)
  • src/tl_templates/hip/common.h (1 hunks)
  • src/transform/atomicadd_vectorize.cc (2 hunks)
  • tilelang/language/atomic.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/op/atomic_add.cc (1)
src/op/copy.cc (2)
  • MakePredicate (325-354)
  • MakePredicate (325-327)
tilelang/language/atomic.py (1)
tilelang/language/tir/op.py (2)
  • call_extern (173-195)
  • address_of (464-480)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (4)
src/transform/atomicadd_vectorize.cc (1)

249-255: Scalar vectorize path: AtomicAdd(address_of(dst), value, memory_order) is the right layout.

This aligns with the new pointer-based AtomicAdd signature.

src/tl_templates/cuda/atomic.h (1)

252-328: AtomicAddRet(T1* address, ...) update looks consistent with callsites passing pointers.

The pointer-first API matches the updated IR/frontend; the remaining implementation still uses *address for cuda::atomic_ref and raw pointer for atomicAdd, which is coherent.

src/op/atomic_add.cc (1)

270-293: [Rewritten review comment]
[Classification tag]

tilelang/language/atomic.py (1)

175-193: The proposed fix uses non-existent TVM API functions and misunderstands how let-bindings work in TVM expressions; however, there is a real underlying issue: the scalar path attempts to call address_of() on non-buffer types where it cannot work.

The scalar branch (line 178) is taken when dst_extent is None and src_extent is None, meaning neither dst nor value is a BufferLoad or BufferRegion. However, tvm.tir.address_of() only works on specific expression types (like BufferLoad), not plain Var objects. Testing confirms that address_of(Var) raises an exception.

The review's proposed solution is flawed:

  • T.has_let_value() and T.get_let_value() do not exist in tvm.tir
  • In TVM, let-bindings are statement-level constructs (LetStmt), not properties of Var objects
  • A Var doesn't "contain" a let-value that can be resolved

The actual bug: if dst is a plain Var (when get_extent returns None), the subsequent T.address_of(dst) call will fail. Either this scalar path should not be reachable with non-buffer types, or the approach needs to be fundamentally different from using address_of.
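One way to make the first option concrete (rejecting non-buffer destinations before taking an address) is sketched below. This is a toy model under stated assumptions: `Var`, `BufferLoad`, `AddressOf`, and `checked_address_of` are hypothetical plain-Python stand-ins, and nothing in the thread confirms that the project adopted such a guard.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Var:
    """Toy stand-in for a plain scalar variable with no backing buffer element."""
    name: str


@dataclass(frozen=True)
class BufferLoad:
    """Toy stand-in for an IR node that reads buffer[indices]."""
    buffer: str
    indices: Tuple[int, ...]


@dataclass(frozen=True)
class AddressOf:
    """Toy stand-in for address_of applied to a buffer element."""
    load: BufferLoad


def checked_address_of(dst):
    """Return the address of dst, failing early with a clear message for non-buffer types."""
    if not isinstance(dst, BufferLoad):
        raise TypeError(
            f"atomic_add destination must be a buffer element, got {type(dst).__name__}"
        )
    return AddressOf(dst)
```

Failing at the frontend with an explicit message, rather than inside the address-of call, would at least make the unreachable-or-unsupported scalar case visible to the caller.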

@LeiWang1999 (Member, Author) commented:

Local tests pass, merged :)
