
Conversation


@xwhzz xwhzz commented Feb 5, 2026

This PR fixes a bug in the gemm_v2 version of TensorCoreIntrinEmitter: it did not handle cases where the input matrices have more than two dimensions.

Summary by CodeRabbit

  • Refactor
    • Updated internal buffer indexing for matrix multiplication to support multi-region reads, improving handling of multiple contiguous data regions.
    • Applied across numeric precision paths and load routines to standardize indexing behavior without changing public interfaces or error handling.


github-actions bot commented Feb 5, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai bot commented Feb 5, 2026

📝 Walkthrough

Walkthrough

Modified buffer indexing in tilelang/intrinsics/mma_macro_generator.py to support multi-region loading by introducing A_other and B_other indices and updating affected FP64 and ldmatrix loading paths to use multi-entry A_buf/B_buf indexing.

Changes

Cohort / File(s) Summary
MMA Buffer Indexing
tilelang/intrinsics/mma_macro_generator.py
Added A_other and B_other to collect region-min values for all but the last two regions. Rewrote FP64 and ldmatrix (A/B, transposed and non-transposed) loads to use A_buf[*A_other, ...] and B_buf[*B_other, ...] multi-entry indexing instead of direct A_base0/A_base1 indexing. Net diff: +22/-10 lines.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers

  • LeiWang1999

Poem

🐰 I nibbled through indices, a curious chore,
A_other, B_other now fetch regions galore,
Buffers aligned, hopping through rows,
Loads relay smoothly where data flows,
A rabbit's cheer for indexing more!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed — Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The title '[BugFix] Update buffer access in TensorCoreIntrinEmitter to handle variable dimensions correctly' accurately describes the main change: updating buffer access patterns to handle variable dimensions, which aligns with the summary's description of introducing A_other and B_other indices and rewiring buffer access logic.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tilelang/intrinsics/mma_macro_generator.py (4)

254-273: ⚠️ Potential issue | 🔴 Critical

Critical: Star expression in index causes SyntaxError on Python < 3.11.

The syntax A_buf[*A_other, ...] using star unpacking directly in indexing brackets was introduced in Python 3.11 (PEP 646). This causes a compilation failure on Python 3.9/3.10 as confirmed by the CI pipeline.

Use tuple concatenation to maintain backward compatibility:

🐛 Proposed fix for Python version compatibility
-            A_other = [r.min for r in A_region.region[:-2]]
+            A_other = tuple(r.min for r in A_region.region[:-2])
 
             @T.macro
             def _warp_ld_a_fp64(
                 A_local_buf,
                 A_shared_buf,
                 ki,
                 thread_binding,
                 rk=0,
             ):
                 tx, _, warp_m = self.extract_thread_binding(thread_binding)
                 for i in T.serial(warp_rows):
                     wi = warp_m * warp_row_tiles + i * micro_size_x
                     wk = rk * chunk + ki * micro_size_k
                     mi = tx // micro_size_k
                     mk = tx % micro_size_k
                     if a_transposed:
-                        A_local_buf[i * local_size_a] = A_buf[*A_other, A_base0 + wk + mk, A_base1 + wi + mi]
+                        A_local_buf[i * local_size_a] = A_buf[A_other + (A_base0 + wk + mk, A_base1 + wi + mi)]
                     else:
-                        A_local_buf[i * local_size_a] = A_buf[*A_other, A_base0 + wi + mi, A_base1 + wk + mk]
+                        A_local_buf[i * local_size_a] = A_buf[A_other + (A_base0 + wi + mi, A_base1 + wk + mk)]
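The version concern above can be checked in isolation. The following is a minimal standalone sketch (not TileLang code; `Buf` and `other` are stand-ins for illustration): star unpacking inside subscript brackets is a SyntaxError before Python 3.11 (PEP 646), while tuple concatenation builds the identical index tuple on every supported version.

```python
class Buf:
    """Toy buffer that records the index tuple it was accessed with."""
    def __getitem__(self, idx):
        return idx

buf = Buf()
other = (0, 1)  # leading-dimension indices, analogous to A_other

# Portable form: works on Python 3.9/3.10/3.11+.
compat = buf[other + (2, 3)]

# buf[*other, 2, 3] is a SyntaxError before 3.11; on 3.11+ it yields
# the same index tuple, so the two spellings are interchangeable there.
assert compat == (0, 1, 2, 3)
```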

308-345: ⚠️ Potential issue | 🔴 Critical

Same Python 3.11+ syntax issue in non-fp64 ldmatrix_a path.

Apply the same tuple concatenation fix for backward compatibility.

🐛 Proposed fix
-        A_other = [r.min for r in A_region.region[:-2]]
+        A_other = tuple(r.min for r in A_region.region[:-2])
         A_stride_last = A_buf.shape[-1]
 
         @T.macro
         def _warp_ldmatrix_a(
             ...
         ):
             ...
             for i in T.serial(warp_rows):
                 # Assign A_shared_buf_elem
                 wi, wk = warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * micro_size_k
-                A_shared_buf_elem = A_buf[*A_other, A_base0 + wk, A_base1 + wi] if a_transposed else A_buf[*A_other, A_base0 + wi, A_base1 + wk]
+                A_shared_buf_elem = A_buf[A_other + (A_base0 + wk, A_base1 + wi)] if a_transposed else A_buf[A_other + (A_base0 + wi, A_base1 + wk)]
 
                 if ldmatrix_available:
                     ...
                 else:
                     for j in T.serial(local_size_a):
                         mi, mk = mma_load_layout(tx, j)
                         if a_transposed:
-                            A_local_buf[i * local_size_a + j] = A_buf[*A_other, A_base0 + wk + mk, A_base1 + wi + mi]
+                            A_local_buf[i * local_size_a + j] = A_buf[A_other + (A_base0 + wk + mk, A_base1 + wi + mi)]
                         else:
-                            A_local_buf[i * local_size_a + j] = A_buf[*A_other, A_base0 + wi + mi, A_base1 + wk + mk]
+                            A_local_buf[i * local_size_a + j] = A_buf[A_other + (A_base0 + wi + mi, A_base1 + wk + mk)]

366-385: ⚠️ Potential issue | 🔴 Critical

Same Python 3.11+ syntax issue in fp64 ldmatrix_b path.

Apply the same tuple concatenation fix.

🐛 Proposed fix
-            B_other = [r.min for r in B_region.region[:-2]]
+            B_other = tuple(r.min for r in B_region.region[:-2])
 
             @T.macro
             def _warp_ld_b_fp64(
                 ...
             ):
                 ...
                 for j in T.serial(warp_cols):
                     ...
                     if b_transposed:
-                        B_local_buf[j * local_size_b] = B_buf[*B_other, B_base0 + wi + mi, B_base1 + wk + mk]
+                        B_local_buf[j * local_size_b] = B_buf[B_other + (B_base0 + wi + mi, B_base1 + wk + mk)]
                     else:
-                        B_local_buf[j * local_size_b] = B_buf[*B_other, B_base0 + wk + mk, B_base1 + wi + mi]
+                        B_local_buf[j * local_size_b] = B_buf[B_other + (B_base0 + wk + mk, B_base1 + wi + mi)]

404-464: ⚠️ Potential issue | 🔴 Critical

Same Python 3.11+ syntax issue in non-fp64 ldmatrix_b path.

Apply the same tuple concatenation fix.

🐛 Proposed fix
-        B_other = [r.min for r in B_region.region[:-2]]
+        B_other = tuple(r.min for r in B_region.region[:-2])
         B_stride_last = B_buf.shape[-1]
         ...
 
         @T.macro
         def _warp_ldmatrix_b(
             ...
         ):
             ...
             for i in T.serial(warp_cols):
                 ...
                 if ldmatrix_available:
-                    B_shared_buf_elem = B_buf[*B_other, B_base0 + wi, B_base1 + wk] if b_transposed else B_buf[*B_other, B_base0 + wk, B_base1 + wi]
+                    B_shared_buf_elem = B_buf[B_other + (B_base0 + wi, B_base1 + wk)] if b_transposed else B_buf[B_other + (B_base0 + wk, B_base1 + wi)]
                     ...
                 else:
                     for j in T.serial(local_size_b):
                         mi, mk = mma_load_layout(tx, j)
                         if b_transposed:
-                            B_local_buf[i * local_size_b + j] = B_buf[*B_other, B_base0 + wi + mi, B_base1 + wk + mk]
+                            B_local_buf[i * local_size_b + j] = B_buf[B_other + (B_base0 + wi + mi, B_base1 + wk + mk)]
                         else:
-                            B_local_buf[i * local_size_b + j] = B_buf[*B_other, B_base0 + wk + mk, B_base1 + wi + mi]
+                            B_local_buf[i * local_size_b + j] = B_buf[B_other + (B_base0 + wk + mk, B_base1 + wi + mi)]


Copilot AI left a comment


Pull request overview

This pull request fixes a bug in the TensorCoreIntrinEmitter class in mma_macro_generator.py where buffer access did not correctly handle input matrices with more than two dimensions. The fix adds support for multi-dimensional buffers by extracting and prepending leading dimension indices when accessing buffer elements.

Changes:

  • Added extraction of leading dimensions (beyond the last 2) into A_other and B_other variables
  • Updated all buffer indexing operations to prepend leading dimensions using tuple concatenation
  • Applied consistently across both ldmatrix_a and ldmatrix_b functions, for both FP64 and general cases, and for both transposed and non-transposed variants
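The "extract leading dimensions, then prepend" pattern can be sketched standalone. In this hypothetical example, `Range` is a stand-in for the TVM buffer-region objects (not the real API): the mins of every dimension except the last two are collected, then concatenated with the 2-D tile coordinates so that 2-D and N-D buffers share one indexing path.

```python
from collections import namedtuple

# Stand-in for a buffer-region range with a start (min) and length (extent).
Range = namedtuple("Range", ["min", "extent"])

# A 4-D buffer region: two leading (batch-like) dims plus the 2-D tile.
region = [Range(0, 1), Range(2, 1), Range(16, 16), Range(32, 16)]

# Collect the mins of all but the last two dimensions ...
A_other = tuple(r.min for r in region[:-2])

# ... and prepend them to the 2-D tile indices when accessing the buffer.
index = A_other + (16 + 1, 32 + 2)
assert index == (0, 2, 17, 34)
```

For a plain 2-D buffer, `region[:-2]` is empty, `A_other` is `()`, and the concatenation degenerates to ordinary 2-D indexing, which is why one code path covers both cases.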


@SiriusNEO SiriusNEO merged commit c1481eb into tile-ai:main Feb 5, 2026
12 checks passed
@xwhzz xwhzz deleted the fix0205 branch February 9, 2026 09:34