[Refactor] Use access_ptr instead of buffer and offsets for cp async params #1590
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough
This PR refactors ptx_cp_async from a 5–6 argument offset-based API to a 3-argument access-pointer-based API with an optional 4th predicate, updating codegen, transforms, public TileLang API, tests, and PTX emission helpers.
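To make the argument-shape change concrete, here is a minimal pure-Python sketch of mapping a legacy offset-based argument list onto the access-pointer form. `AccessPtr` and `migrate_cp_async_args` are hypothetical names used only for illustration, the legacy argument order shown (dst, dst_offset, src, src_offset, bytes[, predicate]) is an assumption, and none of this is TileLang's actual implementation.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class AccessPtr:
    """Hypothetical stand-in for a tvm_access_ptr: buffer name plus element offset."""
    buffer: str
    offset: int
    rw_mask: int  # 1 = read, 2 = write


def migrate_cp_async_args(legacy: Tuple) -> Tuple:
    """Map (dst, dst_off, src, src_off, bytes[, pred]) to (dst_ptr, src_ptr, bytes[, pred])."""
    if len(legacy) not in (5, 6):
        raise ValueError(f"expected 5 or 6 legacy arguments, got {len(legacy)}")
    dst, dst_off, src, src_off, nbytes = legacy[:5]
    new_args = (
        AccessPtr(dst, dst_off, rw_mask=2),  # destination is written
        AccessPtr(src, src_off, rw_mask=1),  # source is read
        nbytes,
    )
    # The optional predicate is carried over unchanged as the 4th argument.
    return new_args + tuple(legacy[5:])
```

The offsets stop being separate arguments and travel inside the pointer objects, which is why the argument count drops from 5–6 to 3–4.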
Sequence Diagram(s)
sequenceDiagram
participant TL as TileLang IR / Transform
participant Inject as inject_ptx_async_copy
participant Merge as merge_shared_memory_allocations
participant Codegen as codegen_cuda
participant PTX as PTX Emitter
participant GPU as GPU/PTX Runtime
Note over TL,Inject: IR constructs cp.async using access_ptrs
TL->>Inject: cp.async(dst_access_ptr, src_access_ptr, bytes[, predicate])
Inject->>Merge: pass/adjust access_ptrs (handle element counts, predicates)
Merge->>Codegen: emit cp.async call with dst/src access_ptrs (merged dst offset adjusted)
Codegen->>PTX: select cp.async intrinsic form (3-arg or 4-arg predicated)
PTX->>GPU: generate CP.ASYNC assembly (with/without offset) and schedule
Note right of GPU: GPU executes cp.async, honors predicate if present
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/target/codegen_cuda.cc (1)
1455-1475: Avoid duplicated `ptx_cp_async` lowering branches
The new top-level `ptx_cp_async` case (lines 1455–1475) fully handles the 3/4-argument, access-ptr based API via `tl::cp_async_gs[_conditional]`, but there is a second `else if (op->op.same_as(builtin::ptx_cp_async()))` later (lines 2283–2301) that will never be reached. Keeping two branches for the same intrinsic (one via `tl::cp_async_gs*`, one via `PrintCpAsyncAssembly`) is confusing and makes it unclear which path is authoritative.
Consider either removing or gating the later branch (or changing its condition to a different op) so there is a single, obvious lowering for `builtin::ptx_cp_async`.
Also applies to: 2283-2301
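The reachability problem is the usual one for if/else-if chains: once an earlier branch matches an op, a later branch testing the same op can never fire. A small Python sketch of spotting such branches, using a hypothetical `find_shadowed_branches` helper that is not part of TileLang:

```python
from typing import List, Sequence


def find_shadowed_branches(branch_ops: Sequence[str]) -> List[int]:
    """Return indices of branches whose op is already matched by an earlier branch."""
    seen = set()
    shadowed = []
    for index, op_name in enumerate(branch_ops):
        if op_name in seen:
            shadowed.append(index)  # an earlier branch wins, so this one is dead
        else:
            seen.add(op_name)
    return shadowed
```

For a chain that tests `ptx_cp_async`, then another op, then `ptx_cp_async` again, the helper reports the third branch as dead.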
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- src/target/codegen_cuda.cc
- src/transform/inject_ptx_async_copy.cc
- src/transform/merge_shared_memory_allocations.cc
- testing/python/transform/test_tilelang_transform_inject_fence_proxy.py
- tilelang/language/tir/op.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py
📚 Learning: 2025-12-15T07:23:50.065Z
Learnt from: cherichy
Repo: tile-ai/tilelang PR: 1421
File: tilelang/contrib/cutedsl/cpasync.py:45-55
Timestamp: 2025-12-15T07:23:50.065Z
Learning: In tilelang/contrib/cutedsl/cpasync.py, using AddressSpace.generic for TMA descriptor pointers (tensormap_ptr) in the extract_tensormap_ptr function is correct. When creating ptr_type with _cute_ir.PtrType.get for TMA descriptors in CuTeDSL, AddressSpace.generic should be used, not a device-specific or constant address space.
Applied to files:
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py
🧬 Code graph analysis (3)
src/transform/inject_ptx_async_copy.cc (1)
tilelang/language/tir/op.py (1)
ptx_cp_async(1346-1393)
src/target/codegen_cuda.cc (1)
src/target/ptx.cc (4)
- PrintCpAsyncAssembly (1324-1351)
- PrintCpAsyncAssembly (1324-1328)
- PrintPredicatedCpAsyncAssembly (1353-1409)
- PrintPredicatedCpAsyncAssembly (1353-1356)
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py (1)
tilelang/language/tir/op.py (3)
- ptx_cp_async (1346-1393)
- tvm_access_ptr (651-676)
- type_annotation (635-648)
🔇 Additional comments (4)
src/transform/inject_ptx_async_copy.cc (1)
95-124: Access-ptr-based cp.async construction matches the new API
Scalar and supported vectorized cases now construct `dst_access_ptr`/`src_access_ptr` via `buffer.access_ptr` with appropriate rw masks and feed them into `builtin::ptx_cp_async` as 3- or 4-argument calls, while predicated vectorized copies are explicitly rejected with a warning. This is consistent with the new `(dst_access_ptr, src_access_ptr, bytes[, predicate])` interface and preserves existing semantics.
Also applies to: 155-180
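The case split described in this comment can be sketched in a few lines of Python. `plan_cp_async` is a hypothetical helper written for this note; it only models which argument tuple would be built and when the pass declines to rewrite, and it uses Python's `logging` in place of the pass's own warning mechanism.

```python
import logging
from typing import Optional, Tuple

logger = logging.getLogger("inject_ptx_async_copy_sketch")


def plan_cp_async(dst_ptr: str, src_ptr: str, nbytes: int,
                  predicate: Optional[str] = None,
                  vectorized: bool = False) -> Optional[Tuple]:
    """Return the cp.async argument tuple, or None when the copy is left unchanged."""
    if vectorized and predicate is not None:
        # Mirrors the review note: predicated vectorized copies are rejected with a warning.
        logger.warning("predicated vectorized cp.async is not supported; skipping")
        return None
    if predicate is None:
        return (dst_ptr, src_ptr, nbytes)          # 3-argument form
    return (dst_ptr, src_ptr, nbytes, predicate)   # 4-argument form
```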
tilelang/language/tir/op.py (1)
1346-1393: Python `ptx_cp_async` wrapper correctly reflects the new access-ptr API
The updated signature, documentation, and 3-vs-4 argument forwarding to `_tvm_op.ptx_cp_async` line up with the new `(dst_access_ptr, src_access_ptr, bytes, predicate=None)` convention and with how the C++ passes now consume this intrinsic.
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py (1)
69-73: Test cp.async invocation correctly updated to access-ptr form
The test now uses `T.tvm_access_ptr(..., rw_mask=2/1)` for A_shared/B_shared and calls `T.ptx_cp_async(dst_access_ptr, src_access_ptr, 16)`, which is consistent with the new API while preserving the original "no double fence" assertion.
src/transform/merge_shared_memory_allocations.cc (1)
569-608: [Your rewritten review comment text here]
[Exactly ONE classification tag]
…handling
- Added the `ptx_cp_async` intrinsic for asynchronous memory copy operations in TileLang, supporting both predicated and non-predicated versions.
- Updated the code generation for `ptx_cp_async` to handle offsets correctly and added inline functions for simplified usage.
- Enhanced the `inject_ptx_async_copy` transformation to support predicated vectorized copies.
- Updated example scripts to disable caching and print kernel sources for better debugging and performance insights.
…c copy transformation
- Updated comments for clarity in `codegen_cuda.cc` regarding `ptx_cp_async` arguments.
- Improved formatting in `inject_ptx_async_copy.cc` for better readability of access pointer creation and logging.
- Added a blank line in `op.py` for consistency in the `ptx_cp_async` function definition.
- Eliminated the implementation for `ptx_cp_async` and `tl::ptx_cp_async` from `codegen_cuda.cc`, streamlining the code by removing unnecessary checks and assembly generation.
- This change focuses on cleaning up the codebase and improving maintainability by removing deprecated functionality.
Actionable comments posted: 0
🧹 Nitpick comments (4)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (1)
511-512: LGTM! Improved variable naming.
The rename from `mod` to `kernel` enhances readability by using a more descriptive identifier for the compiled kernel object.
Note: The file uses both `mod_*` (lines 543-544) and `kernel` naming conventions for similar compiled objects. Consider standardizing to `kernel` throughout for consistency in future refactors.
src/target/codegen_cuda.cc (1)
1454-1497: cp.async emission via access_ptr and tl::cp_async_gs* matches the new intrinsic API
The new `builtin::ptx_cp_async`/`tl::ptx_cp_async` handlers correctly:
- Enforce 3- or 4-argument forms (dst_access_ptr, src_access_ptr, bytes, [predicate]).
- Treat the first two arguments as fully-lowered `tvm_access_ptr` expressions and forward them unchanged.
- Use the third argument as the template byte count for `tl::cp_async_gs`/`tl::cp_async_gs_conditional`, aligning with the expected compile-time `4/8/16` sizes.
One minor follow-up: the older inline-PTX branches for `builtin::ptx_commit_group`/`ptx_wait_group` later in this function are now unreachable because these new branches handle them first. They could be removed to avoid confusion.
src/transform/inject_ptx_async_copy.cc (1)
173-225: Predicated vectorized cp.async handling and warning fallback look correct, but `index_factor` is now dead
The predicated vectorized branch:
- Mirrors the non-predicated index-pattern detection and access_ptr construction.
- Emits a 4-arg `ptx_cp_async` only when both base offsets can be determined.
- Logs a clear warning and falls back to regular buffer store/load when offsets cannot be extracted, which is a safe degradation path.
One minor leftover is `index_factor` computed earlier in `InjectPTX` (Lines 77–93) which is no longer used after switching fully to `buffer.access_ptr`. That variable and the associated comment can be removed or updated to avoid confusion.
tilelang/language/tir/op.py (1)
1390-1395: Consider moving the import to module level.
The implementation logic is correct, but importing `tvm.tir` inside the function incurs a small overhead on every call. Since line 3 already imports `tvm`, you could use `tvm.tir.call_intrin` directly, or add `from tvm.tir import call_intrin` at the module level alongside the existing imports.
🔎 Suggested refactor
At module level (add after line 8):

     import tvm.tir.op as _tvm_op
    +from tvm.tir import call_intrin

Then update the function:

    -    from tvm import tir
    -
         if predicate is None:
    -        return tir.call_intrin("", tir.op.Op.get("tl.ptx_cp_async"), dst_access_ptr, src_access_ptr, bytes)
    +        return call_intrin("", _tvm_op.Op.get("tl.ptx_cp_async"), dst_access_ptr, src_access_ptr, bytes)
         else:
    -        return tir.call_intrin("", tir.op.Op.get("tl.ptx_cp_async"), dst_access_ptr, src_access_ptr, bytes, predicate)
    +        return call_intrin("", _tvm_op.Op.get("tl.ptx_cp_async"), dst_access_ptr, src_access_ptr, bytes, predicate)

Note: If the in-function import is intentional (e.g., to avoid circular dependencies), this can be safely ignored.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
- examples/flash_attention/example_gqa_bwd_tma_reduce.py
- examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
- examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
- src/op/builtin.cc
- src/op/builtin.h
- src/target/codegen_cuda.cc
- src/target/ptx.cc
- src/target/ptx.h
- src/transform/inject_ptx_async_copy.cc
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
- tilelang/language/tir/op.py
💤 Files with no reviewable changes (3)
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
- examples/flash_attention/example_gqa_bwd_tma_reduce.py
- examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
🧰 Additional context used
🧬 Code graph analysis (5)
src/op/builtin.cc (1)
tilelang/language/tir/op.py (1)
ptx_cp_async(1346-1395)
src/target/codegen_cuda.cc (1)
tilelang/language/tir/op.py (1)
ptx_cp_async(1346-1395)
src/target/ptx.h (1)
src/target/ptx.cc (4)
- PrintCpAsyncAssembly (1324-1355)
- PrintCpAsyncAssembly (1324-1328)
- PrintPredicatedCpAsyncAssembly (1357-1417)
- PrintPredicatedCpAsyncAssembly (1357-1360)
src/op/builtin.h (1)
tilelang/language/tir/op.py (1)
ptx_cp_async(1346-1395)
src/transform/inject_ptx_async_copy.cc (1)
tilelang/language/tir/op.py (1)
ptx_cp_async(1346-1395)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (7)
src/op/builtin.cc (1)
201-209: ptx_cp_async builtin registration looks consistent with existing TL intrinsics
Defining `ptx_cp_async` as kOpaque with `set_num_inputs(-1)` matches the intended 3/4-arg intrinsic and follows nearby patterns (e.g., other PTX/TMA ops). No issues from the codegen or pass side.
src/op/builtin.h (1)
307-315: Header declaration and doc for tl::ptx_cp_async align with implementation
The intrinsic comment correctly documents the 3- and 4-argument forms, and the `TVM_DLL const Op &ptx_cp_async();` declaration matches the registered op in `builtin.cc` and the Python wrapper.
src/target/ptx.cc (1)
1343-1350: Empty-offset handling in cp.async assembly is correct and enables no-offset overloads
Conditionally emitting `shared_ptr`/`global_ptr` when the corresponding offset strings are empty avoids malformed `ptr +` expressions and matches the new header overloads that pass `""` for "no offset" while preserving behavior for existing non-empty offsets.
Also applies to: 1402-1410
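The empty-offset rule is easy to state in code. The following Python sketch covers the string-building step only, with `address_expr` as a hypothetical helper name rather than anything in `ptx.cc`:

```python
def address_expr(ptr: str, offset: str = "") -> str:
    """Build a pointer expression, omitting ' + offset' when the offset is empty."""
    if not offset:
        return ptr  # avoids emitting a dangling 'ptr + '
    return f"{ptr} + {offset}"
```

Defaulting `offset` to the empty string plays the same role as the no-offset overloads: callers that have a ready-made pointer simply leave the offset out.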
src/transform/inject_ptx_async_copy.cc (2)
99-122: Scalar cp.async lowering via access_ptr is consistent and byte-correct
For the scalar case, computing `dst_elem_count`/`src_elem_count` from `bytes / elem_type->bytes()` and constructing `buffer.access_ptr` for both sides gives the right element ranges for both "normal" and merged-byte-buffer shared memory. The generated call to `builtin::ptx_cp_async(dst_access_ptr, src_access_ptr, bytes[, predicate])` matches the new 3/4-argument intrinsic and the CUDA codegen expectations.
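The element-count arithmetic is integer division of the copy size by the element width. A sketch with a small, assumed dtype table (`DTYPE_BYTES` and `elem_count` are illustrative names, not identifiers from the pass):

```python
# Assumed element widths in bytes for a few common dtypes.
DTYPE_BYTES = {"int8": 1, "uint8": 1, "float16": 2, "float32": 4}


def elem_count(copy_bytes: int, dtype: str) -> int:
    """Number of elements covered by a copy of `copy_bytes` bytes."""
    width = DTYPE_BYTES[dtype]
    if copy_bytes % width != 0:
        raise ValueError("copy size must be a multiple of the element width")
    return copy_bytes // width
```

For a 16-byte copy this gives 8 elements of `float16`, and 16 elements when the destination is viewed as a byte buffer (`uint8`), which is the merged-shared-memory case the comment mentions.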
153-171: Vectorized non-predicated cp.async correctly reuses the new access_ptr path
The non-predicated vectorized path:
- Restricts to recognizable `Ramp`/`Ramp+Broadcast` offset patterns.
- Reuses the same element-count logic as the scalar case.
- Emits a 3-arg `ptx_cp_async(dst_access_ptr, src_access_ptr, bytes)` once base offsets are found.
This keeps behavior aligned with the scalar path and the updated CUDA codegen, while safely skipping unsupported index patterns.
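A toy model of the index-pattern check, assuming the vectorized index is either a unit-stride ramp or a ramp plus a broadcast scalar. The `Ramp`, `Broadcast`, and `Add` classes below are simplified stand-ins for the TIR nodes of the same names, and the exact patterns the pass accepts may differ from this sketch.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class Ramp:
    base: int
    stride: int
    lanes: int


@dataclass
class Broadcast:
    value: int
    lanes: int


@dataclass
class Add:
    lhs: Union[Ramp, Broadcast]
    rhs: Union[Ramp, Broadcast]


def base_offset(index) -> Optional[int]:
    """Return the scalar base of a recognized vector index, or None if unsupported."""
    if isinstance(index, Ramp):
        return index.base if index.stride == 1 else None
    if isinstance(index, Add):
        ramp, bcast = index.lhs, index.rhs
        if isinstance(ramp, Broadcast):
            ramp, bcast = bcast, ramp  # accept either operand order
        if isinstance(ramp, Ramp) and isinstance(bcast, Broadcast) and ramp.stride == 1:
            return ramp.base + bcast.value
    return None  # unsupported pattern: the caller leaves the copy untouched
```

Returning `None` for anything unrecognized corresponds to the "safely skipping unsupported index patterns" behavior noted above.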
src/target/ptx.h (1)
193-201: cp.async header overloads and docs are consistent with the new implementation
- Updating the comments to allow "empty for no offset" matches the new `{smem_addr}`/`{global_ptr}` handling in `ptx.cc`.
- The inline no-offset overloads for `PrintCpAsyncAssembly` and `PrintPredicatedCpAsyncAssembly` delegate cleanly to the existing 5/6-parameter versions, avoiding duplication and making call sites simpler when no offset is needed.
Everything here aligns with the refactored cp.async path and doesn't alter existing behavior.
Also applies to: 209-221, 222-231, 238-252
tilelang/language/tir/op.py (1)
1346-1389: LGTM! Excellent documentation and API design.
The refactored signature is cleaner and more idiomatic for TVM. The comprehensive docstring with concrete examples of both unpredicated and predicated usage will help users understand the access_ptr pattern. The parameter descriptions clearly specify the rw_mask requirements (2 for destination, 1 for source).
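The rw_mask values mentioned here follow the `tvm_access_ptr` convention of bit 1 for read and bit 2 for write. A minimal sketch of that convention, where `RWMask` and `masks_for_cp_async` are illustrative names rather than TVM or TileLang APIs:

```python
from enum import IntFlag


class RWMask(IntFlag):
    READ = 1
    WRITE = 2


def masks_for_cp_async():
    """cp.async writes its destination (shared memory) and reads its source (global memory)."""
    return {"dst": RWMask.WRITE, "src": RWMask.READ}
```

Getting the masks the wrong way round would misdescribe which buffer is written, which matters to any pass that reasons about reads and writes through access pointers.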
This pull request refactors the `ptx_cp_async` intrinsic throughout the codebase to use access pointer arguments (`tvm_access_ptr`) instead of raw buffer pointers and offsets. This change modernizes the interface, improves safety, and aligns with best practices for memory access in TVM's TIR. The update affects the intrinsic's Python API, its injection and transformation passes, and related code generation logic.
Key changes include:
API and Usage Changes
The `ptx_cp_async` Python API in `tilelang/language/tir/op.py` now takes `dst_access_ptr`, `src_access_ptr`, and `bytes` (with an optional `predicate`), replacing the previous signature that used raw pointers and offsets. The docstring and all usages have been updated accordingly.
Transformation and Injection Passes
`PTXAsyncCopyInjector` in `src/transform/inject_ptx_async_copy.cc` now constructs and passes `tvm_access_ptr` access pointers for both source and destination buffers, instead of passing buffer data and offsets separately. The logic for predicated and non-predicated copies is updated to match the new signature. Vectorized predicated copies are now explicitly unsupported with the new API, and a warning is logged if attempted.
Shared Memory Merging
`SharedMemoryRewriter` in `src/transform/merge_shared_memory_allocations.cc` now rewrites `ptx_cp_async` calls to adjust the `dst_access_ptr` for merged shared memory buffers, ensuring correct offset calculation and buffer usage. The function checks and rewrites the access pointer instead of manipulating buffer and offset arguments directly.
Code Generation
CUDA code generation (`src/target/codegen_cuda.cc`) is updated to expect and handle the new `ptx_cp_async` argument structure, generating the appropriate code for both predicated and non-predicated async copies. Argument validation is added to ensure correct usage.
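The argument validation and predicated/non-predicated dispatch described under Code Generation can be pictured roughly as follows. This is a Python sketch under stated assumptions: `emit_cp_async` is a hypothetical function, and the emitted call syntax (byte count as a template argument to `tl::cp_async_gs` / `tl::cp_async_gs_conditional`) is inferred from the review comments on this PR rather than copied from `codegen_cuda.cc`.

```python
def emit_cp_async(args):
    """Return a CUDA call string for a 3- or 4-argument cp.async intrinsic."""
    if len(args) not in (3, 4):
        raise ValueError("ptx_cp_async expects 3 or 4 arguments")
    dst, src, nbytes = args[:3]
    if nbytes not in (4, 8, 16):
        raise ValueError("cp.async copy size must be 4, 8, or 16 bytes")
    if len(args) == 3:
        # Non-predicated form: (dst_access_ptr, src_access_ptr, bytes)
        return f"tl::cp_async_gs<{nbytes}>({dst}, {src});"
    predicate = args[3]
    # Predicated form: the 4th argument guards the copy.
    return f"tl::cp_async_gs_conditional<{nbytes}>({dst}, {src}, {predicate});"
```

Validating the argument count up front means a stale 5- or 6-argument call from the old offset-based API fails loudly instead of being lowered with misread operands.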