[BugFix] Update buffer access in TensorCoreIntrinEmitter to handle variable dimensions correctly #1794
Conversation
📝 Walkthrough: Modified buffer indexing in `tilelang/intrinsics/mma_macro_generator.py` so that leading dimension indices (beyond the last two) are extracted and prepended when accessing buffers with more than two dimensions.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Actionable comments posted: 0
Caution: Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tilelang/intrinsics/mma_macro_generator.py (4)
254-273: ⚠️ Potential issue | 🔴 Critical: Star expression in index causes SyntaxError on Python < 3.11.

The syntax `A_buf[*A_other, ...]`, using star unpacking directly in indexing brackets, was introduced in Python 3.11 (PEP 646). This causes a compilation failure on Python 3.9/3.10, as confirmed by the CI pipeline. Use tuple concatenation to maintain backward compatibility:

🐛 Proposed fix for Python version compatibility

```diff
-A_other = [r.min for r in A_region.region[:-2]]
+A_other = tuple(r.min for r in A_region.region[:-2])

 @T.macro
 def _warp_ld_a_fp64(
     A_local_buf,
     A_shared_buf,
     ki,
     thread_binding,
     rk=0,
 ):
     tx, _, warp_m = self.extract_thread_binding(thread_binding)
     for i in T.serial(warp_rows):
         wi = warp_m * warp_row_tiles + i * micro_size_x
         wk = rk * chunk + ki * micro_size_k
         mi = tx // micro_size_k
         mk = tx % micro_size_k
         if a_transposed:
-            A_local_buf[i * local_size_a] = A_buf[*A_other, A_base0 + wk + mk, A_base1 + wi + mi]
+            A_local_buf[i * local_size_a] = A_buf[A_other + (A_base0 + wk + mk, A_base1 + wi + mi)]
         else:
-            A_local_buf[i * local_size_a] = A_buf[*A_other, A_base0 + wi + mi, A_base1 + wk + mk]
+            A_local_buf[i * local_size_a] = A_buf[A_other + (A_base0 + wi + mi, A_base1 + wk + mk)]
```
308-345: ⚠️ Potential issue | 🔴 Critical: Same Python 3.11+ syntax issue in the non-fp64 `ldmatrix_a` path.

Apply the same tuple concatenation fix for backward compatibility.

🐛 Proposed fix

```diff
-A_other = [r.min for r in A_region.region[:-2]]
+A_other = tuple(r.min for r in A_region.region[:-2])
 A_stride_last = A_buf.shape[-1]

 @T.macro
 def _warp_ldmatrix_a(
     ...
 ):
     ...
     for i in T.serial(warp_rows):
         # Assign A_shared_buf_elem
         wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
-        A_shared_buf_elem = A_buf[*A_other, A_base0 + wk, A_base1 + wi] if a_transposed else A_buf[*A_other, A_base0 + wi, A_base1 + wk]
+        A_shared_buf_elem = A_buf[A_other + (A_base0 + wk, A_base1 + wi)] if a_transposed else A_buf[A_other + (A_base0 + wi, A_base1 + wk)]
         if ldmatrix_available:
             ...
         else:
             for j in T.serial(local_size_a):
                 mi, mk = mma_load_layout(tx, j)
                 if a_transposed:
-                    A_local_buf[i * local_size_a + j] = A_buf[*A_other, A_base0 + wk + mk, A_base1 + wi + mi]
+                    A_local_buf[i * local_size_a + j] = A_buf[A_other + (A_base0 + wk + mk, A_base1 + wi + mi)]
                 else:
-                    A_local_buf[i * local_size_a + j] = A_buf[*A_other, A_base0 + wi + mi, A_base1 + wk + mk]
+                    A_local_buf[i * local_size_a + j] = A_buf[A_other + (A_base0 + wi + mi, A_base1 + wk + mk)]
```
366-385: ⚠️ Potential issue | 🔴 Critical: Same Python 3.11+ syntax issue in the fp64 `ldmatrix_b` path.

Apply the same tuple concatenation fix.

🐛 Proposed fix

```diff
-B_other = [r.min for r in B_region.region[:-2]]
+B_other = tuple(r.min for r in B_region.region[:-2])

 @T.macro
 def _warp_ld_b_fp64(
     ...
 ):
     ...
     for j in T.serial(warp_cols):
         ...
         if b_transposed:
-            B_local_buf[j * local_size_b] = B_buf[*B_other, B_base0 + wi + mi, B_base1 + wk + mk]
+            B_local_buf[j * local_size_b] = B_buf[B_other + (B_base0 + wi + mi, B_base1 + wk + mk)]
         else:
-            B_local_buf[j * local_size_b] = B_buf[*B_other, B_base0 + wk + mk, B_base1 + wi + mi]
+            B_local_buf[j * local_size_b] = B_buf[B_other + (B_base0 + wk + mk, B_base1 + wi + mi)]
```
404-464: ⚠️ Potential issue | 🔴 Critical: Same Python 3.11+ syntax issue in the non-fp64 `ldmatrix_b` path.

Apply the same tuple concatenation fix.

🐛 Proposed fix

```diff
-B_other = [r.min for r in B_region.region[:-2]]
+B_other = tuple(r.min for r in B_region.region[:-2])
 B_stride_last = B_buf.shape[-1]
 ...

 @T.macro
 def _warp_ldmatrix_b(
     ...
 ):
     ...
     for i in T.serial(warp_cols):
         ...
         if ldmatrix_available:
-            B_shared_buf_elem = B_buf[*B_other, B_base0 + wi, B_base1 + wk] if b_transposed else B_buf[*B_other, B_base0 + wk, B_base1 + wi]
+            B_shared_buf_elem = B_buf[B_other + (B_base0 + wi, B_base1 + wk)] if b_transposed else B_buf[B_other + (B_base0 + wk, B_base1 + wi)]
             ...
         else:
             for j in T.serial(local_size_b):
                 mi, mk = mma_load_layout(tx, j)
                 if b_transposed:
-                    B_local_buf[i * local_size_b + j] = B_buf[*B_other, B_base0 + wi + mi, B_base1 + wk + mk]
+                    B_local_buf[i * local_size_b + j] = B_buf[B_other + (B_base0 + wi + mi, B_base1 + wk + mk)]
                 else:
-                    B_local_buf[i * local_size_b + j] = B_buf[*B_other, B_base0 + wk + mk, B_base1 + wi + mi]
+                    B_local_buf[i * local_size_b + j] = B_buf[B_other + (B_base0 + wk + mk, B_base1 + wi + mi)]
```
Pull request overview
This pull request fixes a bug in the TensorCoreIntrinEmitter class in mma_macro_generator.py where buffer access did not correctly handle input matrices with more than two dimensions. The fix adds support for multi-dimensional buffers by extracting and prepending leading dimension indices when accessing buffer elements.
Changes:
- Added extraction of leading dimensions (beyond the last two) into `A_other` and `B_other` variables
- Updated all buffer indexing operations to prepend the leading dimensions using tuple concatenation (sketched below)
- Applied consistently across both `ldmatrix_a` and `ldmatrix_b` functions, for both FP64 and general cases, and for both transposed and non-transposed variants
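For illustration, a minimal sketch of that extract-and-prepend pattern, using a `namedtuple` as a simplified stand-in for the TVM TIR `Range` objects that make up a buffer region (the concrete shapes and values are hypothetical):

```python
from collections import namedtuple

# Simplified stand-in for a TIR Range: each buffer dimension contributes
# one range with a .min offset.
Range = namedtuple("Range", ["min", "extent"])

# A 4-D access region: two leading (batch-like) dims plus the 2-D tile.
region = [Range(2, 1), Range(0, 1), Range(16, 16), Range(32, 16)]

# Extract the fixed indices of every leading dimension...
A_other = tuple(r.min for r in region[:-2])  # -> (2, 0)

# ...and prepend them to each 2-D tile access by tuple concatenation,
# which is valid on every supported Python version.
row, col = 16, 32
index = A_other + (row, col)                 # -> (2, 0, 16, 32)
print(index)
```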
This PR fixes a bug in the TensorCoreIntrinEmitter for the gemm_v2 version: it did not handle cases where the input matrices have more than two dimensions.