[Fix] Add register to read A ptr in test_tilelang_language_cooperative.py #1593
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough
Block size in the cooperative grid-synchronized kernel was reduced from 128 to 64. The kernel now allocates an in-kernel temporary `A_local` fragment to stage reads across the grid synchronization.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 1
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
testing/python/language/test_tilelang_language_cooperative.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
testing/python/language/test_tilelang_language_cooperative.py
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_cooperative.py (2)
tilelang/language/allocate.py (1)
alloc_fragment (71-82)
tilelang/language/builtin.py (1)
sync_grid (721-723)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (3)
testing/python/language/test_tilelang_language_cooperative.py (3)
9-9: Block size reduced to 64. The block size reduction from 128 to 64 is noted. This results in 16 blocks (each with 128 threads) processing the 1024 elements. The change appears intentional for this test.
14-14: Good addition of local fragment buffer. The allocation of `A_local` as a staging buffer is the fix referenced in the PR title. This avoids potential read-after-write hazards when reading from and writing to the same global array in parallel.
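As a plain-NumPy illustration of why staging matters here (this is not the TileLang kernel itself, just a sketch of the hazard): an in-place mirrored update reads elements that were already rewritten, while staging the mirrored reads first keeps the result uniform.

```python
import numpy as np

N = 8
A = np.arange(N, dtype=np.float32)

# In-place update: later iterations read mirrored elements that were
# already rewritten, so the second half of the result is corrupted.
bad = A.copy()
for i in range(N):
    bad[i] = bad[i] + bad[N - i - 1]

# Staged update: snapshot the mirrored reads first, then write back.
staged = A.copy()
tmp = np.array([staged[N - i - 1] for i in range(N)], dtype=np.float32)
staged = staged + tmp

print(bad)     # [ 7.  7.  7.  7. 11. 12. 13. 14.]
print(staged)  # [7. 7. 7. 7. 7. 7. 7. 7.]  -> every element equals N - 1
```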
36-36: Appropriate use of `torch.empty` for kernel-initialized tensor. The change from `torch.rand` to `torch.empty` is appropriate since the kernel explicitly initializes all elements in the first pass (lines 16-18) before reading any values. This avoids unnecessary initialization overhead.
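As a small aside, a minimal comparison of the two allocation calls (standard PyTorch behavior, not code from this PR):

```python
import torch

# torch.empty only allocates storage; contents are undefined until written,
# which is fine when the kernel writes every element before reading.
a_empty = torch.empty(1024, dtype=torch.float32)

# torch.rand additionally fills the tensor with uniform values in [0, 1),
# work that is wasted if the kernel overwrites every element anyway.
a_rand = torch.rand(1024, dtype=torch.float32)
```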
```diff
 for i in T.Parallel(block):
     if n_idx + i < N:
-        A[n_idx + i] = A[n_idx + i] + A[N - n_idx - i - 1]
+        A_local[i] = A[N - n_idx - i - 1]
+    T.sync_grid()
+    if n_idx + i < N:
+        A[n_idx + i] = A[n_idx + i] + A_local[i]
```
Critical: T.sync_grid() must be outside the parallel loop.
The T.sync_grid() call at line 23 is placed inside the T.Parallel(block) loop. Grid synchronization must occur outside parallel constructs to properly synchronize all threads across all blocks. Calling it inside the loop will cause each parallel iteration to execute the sync, leading to deadlocks or undefined behavior.
🔎 Proposed fix: Move sync outside the loop

```diff
-for i in T.Parallel(block):
-    if n_idx + i < N:
-        A_local[i] = A[N - n_idx - i - 1]
-    T.sync_grid()
-    if n_idx + i < N:
-        A[n_idx + i] = A[n_idx + i] + A_local[i]
+for i in T.Parallel(block):
+    if n_idx + i < N:
+        A_local[i] = A[N - n_idx - i - 1]
+T.sync_grid()
+for i in T.Parallel(block):
+    if n_idx + i < N:
+        A[n_idx + i] = A[n_idx + i] + A_local[i]
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```suggestion
for i in T.Parallel(block):
    if n_idx + i < N:
        A_local[i] = A[N - n_idx - i - 1]
T.sync_grid()
for i in T.Parallel(block):
    if n_idx + i < N:
        A[n_idx + i] = A[n_idx + i] + A_local[i]
```
🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_cooperative.py around lines 20
to 25, T.sync_grid() is incorrectly placed inside the T.Parallel(block) loop;
move the T.sync_grid() call to immediately after the parallel loop
(outside/after its block) so the grid synchronization happens once after all
parallel iterations complete, keeping the existing conditional checks and
assignments unchanged.
Actionable comments posted: 0
♻️ Duplicate comments (1)
testing/python/language/test_tilelang_language_cooperative.py (1)
20-25: Critical: `T.sync_grid()` must be outside the parallel loop. The `T.sync_grid()` call at line 23 is placed inside the `T.Parallel(block)` loop. Grid synchronization must occur outside parallel constructs to properly synchronize all threads across all blocks. Calling it inside the loop will cause each parallel iteration to execute the sync, leading to deadlocks or undefined behavior. This issue was already flagged in a previous review and has not been addressed.
🔎 Proposed fix: Move sync outside the loop
```diff
-for i in T.Parallel(block):
-    if n_idx + i < N:
-        A_local[i] = A[N - n_idx - i - 1]
-    T.sync_grid()
-    if n_idx + i < N:
-        A[n_idx + i] = A[n_idx + i] + A_local[i]
+for i in T.Parallel(block):
+    if n_idx + i < N:
+        A_local[i] = A[N - n_idx - i - 1]
+T.sync_grid()
+for i in T.Parallel(block):
+    if n_idx + i < N:
+        A[n_idx + i] = A[n_idx + i] + A_local[i]
```
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_cooperative.py (1)
14-14: Use proper tuple syntax for the shape parameter. The shape parameter uses `(block)`, which is just an integer in Python, not a tuple. According to the `alloc_fragment` signature in the relevant code snippets, the `shape` parameter should be a tuple. For a 1D buffer, use `(block,)` with a trailing comma to create a single-element tuple.

🔎 Proposed fix

```diff
-A_local = T.alloc_fragment((block), dtype=T.float32)
+A_local = T.alloc_fragment((block,), dtype=T.float32)
```
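A quick Python check of the distinction, independent of the `alloc_fragment` API itself:

```python
block = 64
print(type((block)))   # <class 'int'>   -> parentheses alone do not make a tuple
print(type((block,)))  # <class 'tuple'> -> the trailing comma does
```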
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
testing/python/language/test_tilelang_language_cooperative.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
testing/python/language/test_tilelang_language_cooperative.py
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_cooperative.py (2)
tilelang/language/allocate.py (1)
alloc_fragment (71-82)
tilelang/language/builtin.py (1)
sync_grid (721-723)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (1)
testing/python/language/test_tilelang_language_cooperative.py (1)
9-9: Block size and thread count mismatch may be intentional for this grid synchronization test. The kernel launches with `threads=128` but the parallel loops process only `block=64` iterations, which means 64 threads will be idle during those loops. If this is unintended, consider changing `threads=128` to `threads=64`. However, for grid synchronization testing, having extra threads may be deliberate; verify with the team whether this thread count is correct.
Actionable comments posted: 0
♻️ Duplicate comments (1)
testing/python/language/test_tilelang_language_cooperative.py (1)
19-22: Critical: `T.sync_grid()` is still inside the parallel loop. This issue was flagged in a previous review but appears unaddressed. The `T.sync_grid()` at line 21 is inside the `T.Parallel(block)` loop, which will cause each thread to execute the sync multiple times per iteration rather than once collectively across the grid.

🔎 Proposed fix: Split into two parallel loops with sync between them

```diff
 for i in T.Parallel(block):
     A_local[i] = A[N - n_idx - i - 1]
-    T.sync_grid()
+T.sync_grid()
+for i in T.Parallel(block):
     A[n_idx + i] = A[n_idx + i] + A_local[i]
```
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_cooperative.py (1)
16-17: Consider adding bounds checks for general safety. The bounds check `if n_idx + i < N` was removed. While safe for the current test (N=1024 is divisible by block=64), this could cause out-of-bounds writes if N is not a multiple of block in other scenarios.

🔎 Suggested bounds-safe version

```diff
 for i in T.Parallel(block):
-    A[n_idx + i] = n_idx + i
+    if n_idx + i < N:
+        A[n_idx + i] = n_idx + i
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
testing/python/language/test_tilelang_language_cooperative.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
testing/python/language/test_tilelang_language_cooperative.py
🔇 Additional comments (2)
testing/python/language/test_tilelang_language_cooperative.py (2)
9-14: LGTM on kernel setup and local allocation. The block size reduction to 64 and the introduction of the `A_local` fragment for intermediate storage is appropriate for this cooperative kernel pattern.
29-36: Test logic is correct, pending kernel fix. The test correctly expects all elements to equal `N - 1` (1023) after the kernel completes (since `A[i] = i + (N - i - 1)`). However, the test will only pass once the `T.sync_grid()` placement issue in the kernel is resolved.
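For illustration only, a sketch of the expected-value check this implies; the tensor name `A` and the exact assertion are assumptions, not the test file's actual code:

```python
import torch

N = 1024
# After the fixed kernel, A[i] = i + (N - i - 1) = N - 1 for every index i.
expected = torch.full((N,), N - 1, dtype=torch.float32)
# Hypothetical check against the kernel output `A`:
# torch.testing.assert_close(A, expected)
```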
Summary by CodeRabbit