[Feature] Support region as input of T.cumsum#1426
Conversation
- Extend T.cumsum to accept BufferRegion and BufferLoad inputs in addition to Buffer - This enables operations on buffer slices/regions like: T.cumsum(InputG_fragment[i * chunk_size:(i + 1) * chunk_size], dim=0) - Update cumsum_fragment to handle region inputs properly - Add comprehensive tests for 1D and 2D region inputs including normal and reverse modes Fixes tile-ai#879
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
|
Note Other AI code review bot(s) detectedCodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review. WalkthroughAccept region/slice inputs for cumsum and copy destinations; cumsum/cumsum_fragment now accept tir.Buffer, tir.BufferRegion, or tir.BufferLoad, infer shapes/dtypes and validate dst, route local.fragment cases to cumsum_fragment or emit tl.tileop.cumsum otherwise. Added 1D and 2D region-sliced tests (dim/reverse variants) validating against PyTorch. Changes
Sequence Diagram(s)sequenceDiagram
participant Caller as Caller / API user
participant cumsum as tilelang.cumsum()
participant infer as retrieve_shape/_get_buffer
participant scope as ScopeCheck (base buffer)
participant fragment as cumsum_fragment
participant intrinsic as tl.tileop.cumsum
Caller->>cumsum: call cumsum(src_region_or_buf, dst?, dim, reverse)
cumsum->>infer: infer shape, dtype, base_buffer from src/dst
infer-->>cumsum: shape, dtype, base_buffer
cumsum->>scope: check base_buffer.scope
alt base_buffer.scope == "local.fragment"
cumsum->>fragment: invoke cumsum_fragment(src_region, dst_region, dim, reverse)
fragment-->>Caller: write result into fragment region (in-place)
else
cumsum->>intrinsic: emit tl.tileop.cumsum(src_region, dst_region, dim, reverse)
intrinsic-->>Caller: intrinsic emitted (out-of-place or to dst region)
end
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
@Dayuxiaoshui Thank! we're good to go to execute |
- Add comprehensive docstring for cumsum_fragment function - Format code according to ruff style guidelines
There was a problem hiding this comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tilelang/language/reduce.py (3)
244-284: Fix negative-dimension bounds check (dim == -ndimis currently rejected).
Line 337 usesdim <= -len(shape), which incorrectly rejects valid Python-styledim == -len(shape).- if dim >= len(shape) or dim <= -len(shape): + if dim >= len(shape) or dim < -len(shape): raise ValueError(f"Dimension {dim} is out of bounds for buffer with shape {shape}")
286-356: Addsrc/dstshape compatibility validation for out-of-place cumsum.
Right nowcumsum()validatesdimagainstsrcbut doesn’t ensureretrieve_shape(dst)matchesretrieve_shape(src)whendstis provided (esp. important withBufferRegionextents).def cumsum( src: tir.Buffer | tir.BufferRegion | tir.BufferLoad, dst: tir.Buffer | tir.BufferRegion | tir.BufferLoad | None = None, dim: int = 0, reverse: bool = False, ): @@ - if dst is None: - dst = src + if dst is None: + dst = src + else: + dst_shape = retrieve_shape(dst) + if len(dst_shape) != len(shape) or any(list(dst_shape)[i] != list(shape)[i] for i in range(len(shape))): + raise ValueError(f"cumsum dst shape {dst_shape} must match src shape {shape}")
245-284: Fix type annotation forcopy()dstparameter to accepttir.BufferRegion.
cumsum_fragment()allowsdst: tir.BufferRegion, butcopy()at line 16 oftilelang/language/copy.pydeclaresdst: tir.Buffer | tir.BufferLoad, excludingBufferRegion. The docstring (line 35) claims both sides acceptBufferRegion, andto_buffer_region()can handle it, but the type signature prevents callers from passingBufferRegionasdstwithout type checker errors. Update the type annotation to:dst: tir.Buffer | tir.BufferLoad | tir.BufferRegion.
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_cumsum.py (1)
278-297: Add at least one non-divisible size test to cover tail regions (potential OOB / masking bugs).
All current calls use sizes divisible bychunk_size/block_*, so you’re not exercising last-tile tail handling.def test_cumsum_region_1d(): @@ run_cumsum_region_1d(2048, 256) + # Tail coverage (non-divisible) + run_cumsum_region_1d(1000, 128)(And similarly consider one 2D tail case like
M=1000, N=1000, block=128if the backend is expected to support it.)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
testing/python/language/test_tilelang_language_cumsum.py(1 hunks)tilelang/language/reduce.py(4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/language/reduce.py (2)
tilelang/utils/language.py (3)
to_buffer_region(193-236)retrieve_shape(239-256)_get_buffer(12-27)tilelang/language/v2/annot.py (2)
Buffer(562-585)shape(577-577)
testing/python/language/test_tilelang_language_cumsum.py (2)
tilelang/language/reduce.py (1)
cumsum(286-356)tilelang/language/copy.py (1)
copy(14-95)
🪛 Ruff (0.14.8)
tilelang/language/reduce.py
338-338: Avoid specifying long messages outside the exception class
(TRY003)
- Fix negative dimension bounds check (dim < -len(shape) instead of dim <= -len(shape)) - Add src/dst shape compatibility validation for out-of-place cumsum - Update copy() type annotation to accept BufferRegion as dst parameter - Fix test in-place mutation issues by using out-of-place cumsum operations - Add non-divisible size test cases for tail region coverage
There was a problem hiding this comment.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
testing/python/language/test_tilelang_language_cumsum.py(1 hunks)tilelang/language/copy.py(2 hunks)tilelang/language/reduce.py(4 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/language/reduce.py (1)
tilelang/utils/language.py (2)
retrieve_shape(239-256)_get_buffer(12-27)
testing/python/language/test_tilelang_language_cumsum.py (1)
tilelang/language/reduce.py (1)
cumsum(286-365)
tilelang/language/copy.py (1)
tilelang/language/v2/annot.py (1)
Buffer(562-585)
🪛 Ruff (0.14.8)
tilelang/language/reduce.py
338-338: Avoid specifying long messages outside the exception class
(TRY003)
348-348: Avoid specifying long messages outside the exception class
(TRY003)
352-352: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (8)
tilelang/language/copy.py (1)
16-16: LGTM! Type annotation expanded to match existing implementation.The
dstparameter type annotation now correctly reflects thatBufferRegionis already supported by the existingget_extenthelper andto_buffer_regioncall. This aligns the API documentation with actual functionality.Also applies to: 25-25
tilelang/language/reduce.py (4)
6-6: LGTM! Required helpers imported for region support.The
retrieve_shapeand_get_bufferutilities are correctly imported to enable shape extraction and buffer access fromBufferRegionandBufferLoadtypes.
245-283: LGTM!cumsum_fragmentcorrectly extended to handle region inputs.The function now properly:
- Extracts shape via
retrieve_shapeto handle all three input types- Derives dtype correctly - directly from
srcif it's aBuffer, otherwise from the underlying buffer- Allocates shared memory with the extracted shape/dtype for the copy-cumsum-copy pattern
344-352: LGTM! Shape validation correctly handles symbolic dimensions.Using
tir.analysis.expr_deep_equalfor dimension comparison is the right approach to handle both concrete and symbolic (dynamic) shape expressions in TVM IR.
354-357: LGTM! Scope-based routing correctly dispatches fragment inputs.The logic extracts the underlying buffer to check its scope, ensuring that
BufferRegionandBufferLoadinputs derived from fragment buffers are correctly routed tocumsum_fragment.testing/python/language/test_tilelang_language_cumsum.py (3)
177-199: LGTM! Out-of-place operation correctly addresses the previous review feedback.The test now uses separate
srcanddstregions (InputG_fragment[...]→OutputG_fragment[...]), which correctly validates the region-based cumsum feature without in-place mutation issues.
228-250: LGTM! 2D region test correctly structured.The test properly demonstrates region-based cumsum on 2D slices with out-of-place operation. The same boundary concern noted for the 1D test applies here for non-divisible test cases (line 305).
283-305: Good test coverage for region-based cumsum.The tests cover:
- 1D and 2D regions
- Forward and reverse modes
- Different chunk/block sizes
- Both dimensions for 2D
The non-divisible size tests (lines 293, 305) should be verified to ensure they pass or are intentionally testing edge case handling.
- Add bounds clamping using T.min() for chunk_end calculations - Prevents accessing beyond tensor bounds for non-divisible sizes - Matches reference implementation behavior - Fixes both 1D and 2D region test cases
- Remove T.min() which cannot be used directly in slice indices - Use chunk_start + chunk_size form instead - Rely on system's automatic bounds checking for non-divisible sizes - Update comments to reflect this approach
|
@codex review |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
|
@LeiWang1999 Interesting question, thank you for reviewing. I will try to solve this. Thank you for reviewing, it is an honor to participate in the work on tilelang. |
There was a problem hiding this comment.
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/op/reduce.cc(1 hunks)testing/python/language/test_tilelang_language_cumsum.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_cumsum.py (2)
tilelang/language/copy.py (1)
copy(14-95)tilelang/language/reduce.py (1)
cumsum(286-365)
🔇 Additional comments (3)
src/op/reduce.cc (1)
536-558: LGTM! Region extents correctly replace buffer shape for slice support.The change properly extracts extents from
srcRegion_->regionand uses them to derivendimand pass the correct dimensions toCumSum1D::runandCumSum2D::run. This enables cumsum to operate on buffer slices/regions as intended by the PR objectives.testing/python/language/test_tilelang_language_cumsum.py (2)
286-296: Good test coverage for region-based cumsum.The tests cover normal operation, reverse mode, various chunk sizes, and importantly include tail cases with non-divisible sizes (line 296). This comprehensive coverage will help validate the bounds handling once the concern raised in earlier comments is addressed.
299-308: Good test coverage for 2D region-based cumsum.The 2D tests comprehensively cover both dimensions (dim=0 and dim=1), reverse mode, and non-divisible sizes (line 308), providing thorough validation of the region-based cumsum functionality.
| with T.Kernel(T.ceildiv(N, chunk_size), threads=chunk_size) as bx: | ||
| i = bx | ||
| chunk_start = i * chunk_size | ||
| # Copy region to shared memory first (cumsum only supports shared memory) | ||
| A_shared = T.alloc_shared((chunk_size,), dtype) | ||
| T.copy(InputG_fragment[chunk_start : chunk_start + chunk_size], A_shared) | ||
| # Test cumsum with region input - in-place operation on shared memory | ||
| # This demonstrates the feature: T.cumsum(region, dim=0) | ||
| T.cumsum(src=A_shared, dim=0, reverse=reverse) | ||
| # Copy result back to global memory | ||
| T.copy(A_shared, OutputG_fragment[chunk_start : chunk_start + chunk_size]) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
wc -l testing/python/language/test_tilelang_language_cumsum.pyRepository: tile-ai/tilelang
Length of output: 120
🏁 Script executed:
# First, let's see the test structure and the reference implementation
head -220 testing/python/language/test_tilelang_language_cumsum.py | tail -50Repository: tile-ai/tilelang
Length of output: 2074
🏁 Script executed:
# Get lines around 212 (reference implementation mentioned)
sed -n '200,220p' testing/python/language/test_tilelang_language_cumsum.pyRepository: tile-ai/tilelang
Length of output: 881
🏁 Script executed:
# Get the test function and full context around line 296
sed -n '177,198p' testing/python/language/test_tilelang_language_cumsum.pyRepository: tile-ai/tilelang
Length of output: 1109
🏁 Script executed:
# Find where N=1000 and chunk_size=128 test is defined
sed -n '290,310p' testing/python/language/test_tilelang_language_cumsum.pyRepository: tile-ai/tilelang
Length of output: 774
🏁 Script executed:
cd testing/python/language && python -m pytest test_tilelang_language_cumsum.py::test_cumsum_region_1d -xvs 2>&1 | head -100Repository: tile-ai/tilelang
Length of output: 146
🏁 Script executed:
# Also check if there are any other uses of this pattern that might inform how slicing works
rg "T\.copy.*\[.*:.*\]" testing/python/language/ -A 2 -B 2 | head -60Repository: tile-ai/tilelang
Length of output: 5700
Add bounds clamping to prevent out-of-bounds access on tail blocks.
The kernel does not clamp slice bounds (lines 191, 196), but the reference implementation (lines 212-213) explicitly uses end = min(start + chunk_size, N) to handle non-divisible tensor sizes. For the test case run_cumsum_region_1d(1000, 128) at line 306, the last block (i=7) computes chunk_start=896 and attempts to access [896:1024], which exceeds the tensor bounds [0:999].
The kernel must be updated to clamp the slice bounds. Replace lines 188-196 with:
chunk_start = i * chunk_size
chunk_end = T.min(chunk_start + chunk_size, N)
A_shared = T.alloc_shared((chunk_size,), dtype)
T.copy(InputG_fragment[chunk_start : chunk_end], A_shared[:chunk_end - chunk_start])
T.cumsum(src=A_shared[:chunk_end - chunk_start], dim=0, reverse=reverse)
T.copy(A_shared[:chunk_end - chunk_start], OutputG_fragment[chunk_start : chunk_end])| with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): | ||
| chunk_start_M = by * block_M | ||
| chunk_start_N = bx * block_N | ||
| # Copy region to shared memory first (cumsum only supports shared memory) | ||
| A_shared = T.alloc_shared((block_M, block_N), dtype) | ||
| T.copy( | ||
| InputG_fragment[chunk_start_M : chunk_start_M + block_M, chunk_start_N : chunk_start_N + block_N], | ||
| A_shared, | ||
| ) | ||
| # Test cumsum with 2D region input - in-place operation on shared memory | ||
| T.cumsum(src=A_shared, dim=dim, reverse=reverse) | ||
| # Copy result back to global memory | ||
| T.copy( | ||
| A_shared, | ||
| OutputG_fragment[chunk_start_M : chunk_start_M + block_M, chunk_start_N : chunk_start_N + block_N], | ||
| ) |
There was a problem hiding this comment.
Add bounds clamping to the 2D kernel for tail blocks
The 2D kernel (lines 242, 250) slices without clamping: chunk_start_M + block_M and chunk_start_N + block_N will exceed array dimensions for non-aligned tail blocks. For example, with the test parameters M=1000, N=1000, block_M=128, block_N=128, the last blocks (index 7) would compute slices [896:1024] when the valid range is only [0:1000]. The reference implementation (lines 269, 271) correctly uses min() for clamping—the kernel should match this pattern:
chunk_end_M = T.min(chunk_start_M + block_M, M)
chunk_end_N = T.min(chunk_start_N + block_N, N)
# Then use chunk_end_M and chunk_end_N in slices instead of direct addition🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_cumsum.py around lines 236 to
251, the 2D kernel uses slices like chunk_start_M + block_M and chunk_start_N +
block_N which can exceed array bounds for tail blocks; compute chunk_end_M =
T.min(chunk_start_M + block_M, M) and chunk_end_N = T.min(chunk_start_N +
block_N, N) and use those clamp values in all slice ranges for the two T.copy
calls and the T.cumsum invocation so the kernel only touches valid elements
(adjust slice endpoints and any dependent sizes to use chunk_end_M/chunk_end_N
instead of raw addition).
LeiWang1999
left a comment
There was a problem hiding this comment.
Thanks for your contribution! I left some comments :) and I think we can merge this pull request after those comments are addressed.
There was a problem hiding this comment.
Actionable comments posted: 0
♻️ Duplicate comments (1)
tilelang/language/reduce.py (1)
286-290: Critical issue from previous review remains unresolved.As noted in the previous review, the CUDA lowering for cumsum slices uses the underlying buffer shape instead of region extents. This affects both the
cumsum_fragmentpath (which calls the intrinsic on line 275-282) and the direct intrinsic path (lines 357-364). When operating on a region smaller than the full buffer, this will scan beyond the intended region boundaries, potentially causing out-of-bounds access and incorrect results.
🧹 Nitpick comments (2)
tilelang/language/reduce.py (2)
269-272: Consider simplifying dtype extraction.Since
_get_buffer(src)returns the underlying buffer for all input types (including whensrcis already a Buffer), you could simplify this to always usesrc_buffer.dtype:- # Get dtype from the buffer - if isinstance(src, tir.Buffer): - dtype = src.dtype - else: - dtype = src_buffer.dtype + dtype = src_buffer.dtype
338-338: Consider extracting error messages to constants (optional).Static analysis flags these long error messages (TRY003). While the current messages are informative and helpful for users, you could optionally extract them to module-level constants if you prefer to follow this style guideline strictly.
Also applies to: 348-348, 352-352
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tilelang/language/reduce.py(4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/language/reduce.py (1)
tilelang/utils/language.py (4)
to_buffer_region(193-236)retrieve_shape(239-256)_get_buffer(12-27)is_fragment(104-115)
🪛 Ruff (0.14.8)
tilelang/language/reduce.py
338-338: Avoid specifying long messages outside the exception class
(TRY003)
348-348: Avoid specifying long messages outside the exception class
(TRY003)
352-352: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (3)
tilelang/language/reduce.py (3)
6-6: LGTM!The added imports correctly support the new region/slice input functionality for cumsum operations.
344-352: Shape validation is thorough and correct.The validation correctly checks both rank and per-dimension compatibility between src and dst using
expr_deep_equal, which properly handles symbolic TIR expressions.
292-330: Excellent documentation with clear examples.The updated docstring comprehensively covers the new region/slice functionality with practical examples, making it easy for users to understand how to operate on buffer slices.
Fixes #879
Summary by CodeRabbit
New Features
Bug Fixes
Tests
✏️ Tip: You can customize this high-level summary in your review settings.