Skip to content

Conversation

@LeiWang1999
Copy link
Member

@LeiWang1999 LeiWang1999 commented Jan 2, 2026

This pull request refactors the ptx_cp_async intrinsic throughout the codebase to use access pointer arguments (tvm_access_ptr) instead of raw buffer pointers and offsets. This change modernizes the interface, improves safety, and aligns with best practices for memory access in TVM's TIR. The update affects the intrinsic's Python API, its injection and transformation passes, and related code generation logic.

Key changes include:

API and Usage Changes

  • The ptx_cp_async Python API in tilelang/language/tir/op.py now takes dst_access_ptr, src_access_ptr, and bytes (with an optional predicate), replacing the previous signature that used raw pointers and offsets. The docstring and all usages have been updated accordingly. [1] [2]

Transformation and Injection Passes

  • The PTXAsyncCopyInjector in src/transform/inject_ptx_async_copy.cc now constructs and passes tvm_access_ptr access pointers for both source and destination buffers, instead of passing buffer data and offsets separately. The logic for predicated and non-predicated copies is updated to match the new signature. Vectorized predicated copies are now explicitly unsupported with the new API, and a warning is logged if attempted. [1] [2]

Shared Memory Merging

  • The SharedMemoryRewriter in src/transform/merge_shared_memory_allocations.cc now rewrites ptx_cp_async calls to adjust the dst_access_ptr for merged shared memory buffers, ensuring correct offset calculation and buffer usage. The function checks and rewrites the access pointer instead of manipulating buffer and offset arguments directly.

Code Generation

  • The CUDA code generator (src/target/codegen_cuda.cc) is updated to expect and handle the new ptx_cp_async argument structure, generating the appropriate code for both predicated and non-predicated async copies. Argument validation is added to ensure correct usage. [1] [2]

Summary by CodeRabbit

  • Refactor
    • Standardized async GPU memory-copy calls to use typed access-pointer arguments and simplified 3/4-argument semantics, unifying predicated and non-predicated paths.
  • New Features
    • Added a TileLang intrinsic for PTX async copy with optional predicate support.
  • Bug Fixes / Stability
    • Tightened assembly emission to avoid unnecessary address arithmetic when offsets are absent.
  • Style
    • Removed a few cache-disable calls and minor formatting cleanups in examples.

✏️ Tip: You can customize this high-level summary in your review settings.

@github-actions
Copy link

github-actions bot commented Jan 2, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 2, 2026

📝 Walkthrough

Walkthrough

This PR refactors ptx_cp_async from a 5–6 argument offset-based API to a 3-argument access-pointer-based API with an optional 4th predicate, updating codegen, transforms, public TileLang API, tests, and PTX emission helpers.

Changes

Cohort / File(s) Summary
CUDA Code Generation
src/target/codegen_cuda.cc
Emission changed to use 3-argument non-predicated (dst, src, bytes) and 4-argument predicated (dst, src, bytes, predicate) cp.async intrinsics (tl::cp_async_gs / conditional). Removed previous dst/src offset handling and adjusted both call sites to new argument ordering.
Transform Passes
src/transform/inject_ptx_async_copy.cc, src/transform/merge_shared_memory_allocations.cc
inject_ptx_async_copy: construct tvm_access_ptr wrappers for dst/src and pass 3 or 4 args (bytes[, predicate]); warns/falls back for unsupported predicated vectorized cases. merge_shared_memory_allocations: now expects 3-arg cp.async with a tvm_access_ptr dst, reconstructs dst access_ptr to merged shared buffer and updates offset multiplication by element size.
Public API / TileLang
tilelang/language/tir/op.py, src/op/builtin.h, src/op/builtin.cc
Added/updated TileLang intrinsic ptx_cp_async to accept (dst_access_ptr, src_access_ptr, bytes, predicate=None) and dispatch to 3-arg or 4-arg intrinsic; new TL builtin op registered (ptx_cp_async).
PTX Assembly Helpers
src/target/ptx.cc, src/target/ptx.h
PTX assembly templates updated to conditionally include offsets only when present; added overloads for PrintCpAsyncAssembly / PrintPredicatedCpAsyncAssembly that accept "no offset" variants and forward to existing implementations.
Tests
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py
Updated test call sites to use T.tvm_access_ptr(...) wrappers and 3-arg T.ptx_cp_async(dst_access_ptr, src_access_ptr, bytes) form instead of 6-arg usage.
Examples / Misc
examples/*, testing/python/tilelibrary/*
Removed some tilelang.disable_cache() calls and minor naming/formatting tweaks in example scripts; no functional changes.

Sequence Diagram(s)

sequenceDiagram
    participant TL as TileLang IR / Transform
    participant Inject as inject_ptx_async_copy
    participant Merge as merge_shared_memory_allocations
    participant Codegen as codegen_cuda
    participant PTX as PTX Emitter
    participant GPU as GPU/PTX Runtime

    Note over TL,Inject: IR constructs cp.async using access_ptrs
    TL->>Inject: cp.async(dst_access_ptr, src_access_ptr, bytes[, predicate])
    Inject->>Merge: pass/adjust access_ptrs (handle element counts, predicates)
    Merge->>Codegen: emit cp.async call with dst/src access_ptrs (merged dst offset adjusted)
    Codegen->>PTX: select cp.async intrinsic form (3-arg or 4-arg predicated)
    PTX->>GPU: generate CP.ASYNC assembly (with/without offset) and schedule
    Note right of GPU: GPU executes cp.async, honors predicate if present
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Poem

🐰 Tiny paws on bytes and streams,
Access pointers stitch the seams,
Three arguments hop in stride,
Predicates tag along with pride.
Async copies hum, compact and spry — hooray! 🥕✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 41.18% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main refactoring: replacing buffer and offset parameters with access_ptr for ptx_cp_async, which is the primary change across multiple files.
✨ Finishing touches
  • 📝 Generate docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@LeiWang1999 LeiWang1999 marked this pull request as draft January 2, 2026 16:35
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
src/target/codegen_cuda.cc (1)

1455-1475: Avoid duplicated ptx_cp_async lowering branches

The new top-level ptx_cp_async case (lines 1455–1475) fully handles the 3/4‑argument, access‑ptr based API via tl::cp_async_gs[_conditional], but there is a second else if (op->op.same_as(builtin::ptx_cp_async())) later (lines 2283–2301) that will never be reached. Keeping two branches for the same intrinsic—one via tl::cp_async_gs*, one via PrintCpAsyncAssembly—is confusing and makes it unclear which path is authoritative.

Consider either removing or gating the later branch (or changing its condition to a different op) so there is a single, obvious lowering for builtin::ptx_cp_async.

Also applies to: 2283-2301

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d6eb5d3 and e7a5b15.

📒 Files selected for processing (5)
  • src/target/codegen_cuda.cc
  • src/transform/inject_ptx_async_copy.cc
  • src/transform/merge_shared_memory_allocations.cc
  • testing/python/transform/test_tilelang_transform_inject_fence_proxy.py
  • tilelang/language/tir/op.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/transform/test_tilelang_transform_inject_fence_proxy.py
📚 Learning: 2025-12-15T07:23:50.065Z
Learnt from: cherichy
Repo: tile-ai/tilelang PR: 1421
File: tilelang/contrib/cutedsl/cpasync.py:45-55
Timestamp: 2025-12-15T07:23:50.065Z
Learning: In tilelang/contrib/cutedsl/cpasync.py, using AddressSpace.generic for TMA descriptor pointers (tensormap_ptr) in the extract_tensormap_ptr function is correct. When creating ptr_type with _cute_ir.PtrType.get for TMA descriptors in CuTeDSL, AddressSpace.generic should be used, not a device-specific or constant address space.

Applied to files:

  • testing/python/transform/test_tilelang_transform_inject_fence_proxy.py
🧬 Code graph analysis (3)
src/transform/inject_ptx_async_copy.cc (1)
tilelang/language/tir/op.py (1)
  • ptx_cp_async (1346-1393)
src/target/codegen_cuda.cc (1)
src/target/ptx.cc (4)
  • PrintCpAsyncAssembly (1324-1351)
  • PrintCpAsyncAssembly (1324-1328)
  • PrintPredicatedCpAsyncAssembly (1353-1409)
  • PrintPredicatedCpAsyncAssembly (1353-1356)
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py (1)
tilelang/language/tir/op.py (3)
  • ptx_cp_async (1346-1393)
  • tvm_access_ptr (651-676)
  • type_annotation (635-648)
🔇 Additional comments (4)
src/transform/inject_ptx_async_copy.cc (1)

95-124: Access‑ptr–based cp.async construction matches the new API

Scalar and supported vectorized cases now construct dst_access_ptr/src_access_ptr via buffer.access_ptr with appropriate rw masks and feed them into builtin::ptx_cp_async as 3‑ or 4‑argument calls, while predicated vectorized copies are explicitly rejected with a warning. This is consistent with the new (dst_access_ptr, src_access_ptr, bytes[, predicate]) interface and preserves existing semantics.

Also applies to: 155-180

tilelang/language/tir/op.py (1)

1346-1393: Python ptx_cp_async wrapper correctly reflects the new access‑ptr API

The updated signature, documentation, and 3‑vs‑4 argument forwarding to _tvm_op.ptx_cp_async line up with the new (dst_access_ptr, src_access_ptr, bytes, predicate=None) convention and with how the C++ passes now consume this intrinsic.

testing/python/transform/test_tilelang_transform_inject_fence_proxy.py (1)

69-73: Test cp.async invocation correctly updated to access‑ptr form

The test now uses T.tvm_access_ptr(..., rw_mask=2/1) for A_shared/B_shared and calls T.ptx_cp_async(dst_access_ptr, src_access_ptr, 16), which is consistent with the new API while preserving the original “no double fence” assertion.

src/transform/merge_shared_memory_allocations.cc (1)

569-608: [Your rewritten review comment text here]
[Exactly ONE classification tag]

…handling

- Added the `ptx_cp_async` intrinsic for asynchronous memory copy operations in TileLang, supporting both predicated and non-predicated versions.
- Updated the code generation for `ptx_cp_async` to handle offsets correctly and added inline functions for simplified usage.
- Enhanced the `inject_ptx_async_copy` transformation to support predicated vectorized copies.
- Updated example scripts to disable caching and print kernel sources for better debugging and performance insights.
…c copy transformation

- Updated comments for clarity in `codegen_cuda.cc` regarding `ptx_cp_async` arguments.
- Improved formatting in `inject_ptx_async_copy.cc` for better readability of access pointer creation and logging.
- Added a blank line in `op.py` for consistency in the `ptx_cp_async` function definition.
@LeiWang1999 LeiWang1999 marked this pull request as ready for review January 3, 2026 10:17
- Eliminated the implementation for `ptx_cp_async` and `tl::ptx_cp_async` from `codegen_cuda.cc`, streamlining the code by removing unnecessary checks and assembly generation.
- This change focuses on cleaning up the codebase and improving maintainability by removing deprecated functionality.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (4)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (1)

511-512: LGTM! Improved variable naming.

The rename from mod to kernel enhances readability by using a more descriptive identifier for the compiled kernel object.

Note: The file uses both mod_* (lines 543-544) and kernel naming conventions for similar compiled objects. Consider standardizing to kernel throughout for consistency in future refactors.

src/target/codegen_cuda.cc (1)

1454-1497: cp.async emission via access_ptr and tl::cp_async_gs matches the new intrinsic API*

The new builtin::ptx_cp_async / tl::ptx_cp_async handlers correctly:

  • Enforce 3- or 4-argument forms (dst_access_ptr, src_access_ptr, bytes, [predicate]).
  • Treat the first two arguments as fully-lowered tvm_access_ptr expressions and forward them unchanged.
  • Use the third argument as the template byte count for tl::cp_async_gs / tl::cp_async_gs_conditional, aligning with the expected compile-time 4/8/16 sizes.

One minor follow-up: the older inline-PTX branches for builtin::ptx_commit_group / ptx_wait_group later in this function are now unreachable because these new branches handle them first. They could be removed to avoid confusion.

src/transform/inject_ptx_async_copy.cc (1)

173-225: Predicated vectorized cp.async handling and warning fallback look correct, but index_factor is now dead

The predicated vectorized branch:

  • Mirrors the non-predicated index-pattern detection and access_ptr construction.
  • Emits a 4-arg ptx_cp_async only when both base offsets can be determined.
  • Logs a clear warning and falls back to regular buffer store/load when offsets cannot be extracted, which is a safe degradation path.

One minor leftover is index_factor computed earlier in InjectPTX (Lines 77–93) which is no longer used after switching fully to buffer.access_ptr. That variable and the associated comment can be removed or updated to avoid confusion.

tilelang/language/tir/op.py (1)

1390-1395: Consider moving the import to module level.

The implementation logic is correct, but importing tvm.tir inside the function incurs a small overhead on every call. Since line 3 already imports tvm, you could use tvm.tir.call_intrin directly, or add from tvm.tir import call_intrin at the module level alongside the existing imports.

🔎 Suggested refactor

At module level (add after line 8):

 import tvm.tir.op as _tvm_op
+from tvm.tir import call_intrin

Then update the function:

-    from tvm import tir
-
     if predicate is None:
-        return tir.call_intrin("", tir.op.Op.get("tl.ptx_cp_async"), dst_access_ptr, src_access_ptr, bytes)
+        return call_intrin("", _tvm_op.Op.get("tl.ptx_cp_async"), dst_access_ptr, src_access_ptr, bytes)
     else:
-        return tir.call_intrin("", tir.op.Op.get("tl.ptx_cp_async"), dst_access_ptr, src_access_ptr, bytes, predicate)
+        return call_intrin("", _tvm_op.Op.get("tl.ptx_cp_async"), dst_access_ptr, src_access_ptr, bytes, predicate)

Note: If the in-function import is intentional (e.g., to avoid circular dependencies), this can be safely ignored.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e7a5b15 and 36236ae.

📒 Files selected for processing (11)
  • examples/flash_attention/example_gqa_bwd_tma_reduce.py
  • examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
  • examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
  • src/op/builtin.cc
  • src/op/builtin.h
  • src/target/codegen_cuda.cc
  • src/target/ptx.cc
  • src/target/ptx.h
  • src/transform/inject_ptx_async_copy.cc
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
  • tilelang/language/tir/op.py
💤 Files with no reviewable changes (3)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
  • examples/flash_attention/example_gqa_bwd_tma_reduce.py
  • examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
🧰 Additional context used
🧬 Code graph analysis (5)
src/op/builtin.cc (1)
tilelang/language/tir/op.py (1)
  • ptx_cp_async (1346-1395)
src/target/codegen_cuda.cc (1)
tilelang/language/tir/op.py (1)
  • ptx_cp_async (1346-1395)
src/target/ptx.h (1)
src/target/ptx.cc (4)
  • PrintCpAsyncAssembly (1324-1355)
  • PrintCpAsyncAssembly (1324-1328)
  • PrintPredicatedCpAsyncAssembly (1357-1417)
  • PrintPredicatedCpAsyncAssembly (1357-1360)
src/op/builtin.h (1)
tilelang/language/tir/op.py (1)
  • ptx_cp_async (1346-1395)
src/transform/inject_ptx_async_copy.cc (1)
tilelang/language/tir/op.py (1)
  • ptx_cp_async (1346-1395)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (7)
src/op/builtin.cc (1)

201-209: ptx_cp_async builtin registration looks consistent with existing TL intrinsics

Defining ptx_cp_async as kOpaque with set_num_inputs(-1) matches the intended 3/4-arg intrinsic and follows nearby patterns (e.g., other PTX/TMA ops). No issues from the codegen or pass side.

src/op/builtin.h (1)

307-315: Header declaration and doc for tl::ptx_cp_async align with implementation

The intrinsic comment correctly documents the 3- and 4-argument forms, and the TVM_DLL const Op &ptx_cp_async(); declaration matches the registered op in builtin.cc and the Python wrapper.

src/target/ptx.cc (1)

1343-1350: Empty-offset handling in cp.async assembly is correct and enables no-offset overloads

Conditionally emitting shared_ptr/global_ptr when the corresponding offset strings are empty avoids malformed ptr + expressions and matches the new header overloads that pass "" for “no offset” while preserving behavior for existing non-empty offsets.

Also applies to: 1402-1410

src/transform/inject_ptx_async_copy.cc (2)

99-122: Scalar cp.async lowering via access_ptr is consistent and byte-correct

For the scalar case, computing dst_elem_count / src_elem_count from bytes / elem_type->bytes() and constructing buffer.access_ptr for both sides gives the right element ranges for both “normal” and merged-byte-buffer shared memory. The generated call to builtin::ptx_cp_async(dst_access_ptr, src_access_ptr, bytes[, predicate]) matches the new 3/4-argument intrinsic and the CUDA codegen expectations.


153-171: Vectorized non‑predicated cp.async correctly reuses the new access_ptr path

The non-predicated vectorized path:

  • Restricts to recognizable Ramp/Ramp+Broadcast offset patterns.
  • Reuses the same element-count logic as the scalar case.
  • Emits a 3-arg ptx_cp_async(dst_access_ptr, src_access_ptr, bytes) once base offsets are found.

This keeps behavior aligned with the scalar path and the updated CUDA codegen, while safely skipping unsupported index patterns.

src/target/ptx.h (1)

193-201: cp.async header overloads and docs are consistent with the new implementation

  • Updating the comments to allow “empty for no offset” matches the new {smem_addr} / {global_ptr} handling in ptx.cc.
  • The inline no-offset overloads for PrintCpAsyncAssembly and PrintPredicatedCpAsyncAssembly delegate cleanly to the existing 5/6-parameter versions, avoiding duplication and making call sites simpler when no offset is needed.

Everything here aligns with the refactored cp.async path and doesn’t alter existing behavior.

Also applies to: 209-221, 222-231, 238-252

tilelang/language/tir/op.py (1)

1346-1389: LGTM! Excellent documentation and API design.

The refactored signature is cleaner and more idiomatic for TVM. The comprehensive docstring with concrete examples of both unpredicated and predicated usage will help users understand the access_ptr pattern. The parameter descriptions clearly specify the rw_mask requirements (2 for destination, 1 for source).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant