
Conversation

@Elevator14B
Collaborator

@Elevator14B commented Jan 16, 2026

See the arXiv paper.

Summary by CodeRabbit

  • New Features
    • Added MHC pre-processing and deeply fused MHC post-processing operators with validated reference implementations, deterministic test-data generators, and self-test harnesses for accuracy checks.
  • Tests
    • Added CUDA-enabled tests that run both MHC operators end-to-end to verify numerical correctness across multiple configurations.


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Jan 16, 2026

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Adds two new example modules implementing TileLang-based mHC pre and post operators with Python wrappers, reference implementations, deterministic test-data generators, and CUDA-gated tests; plus a test orchestration module invoking both examples.

Changes

Cohort / File(s) — Summary

  • MHC Post-Processing Kernel — examples/deepseek_mhc/example_mhc_post.py
    New TileLang JIT kernel mhc_post_tilelang and public API mhc_post(); reference mhc_post_ref(); deterministic generate_test_data(); test() and main() harnesses.
  • MHC Pre-Block Fused Kernels — examples/deepseek_mhc/example_mhc_pre.py
    New TileLang kernel builders mhc_pre_big_fuse_tilelang() and mhc_pre_gemm_sqrsum_tilelang(), public API mhc_pre(), reference mhc_pre_ref(), sinkhorn_normalize_ref(), test-data generator, test() and main() harnesses.
  • Test Orchestration — examples/deepseek_mhc/test_example_mhc.py
    New test module with test_mhc_post() and test_mhc_pre() gated on CUDA; the script entry runs the tilelang.testing main (see the sketch below).
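
For orientation, a minimal sketch of what such a test orchestration module could look like; the decorator name and import paths are assumptions for illustration, not code copied from the PR.

import tilelang.testing

# Hypothetical imports; the real module may invoke the examples differently.
from example_mhc_post import main as run_mhc_post_example
from example_mhc_pre import main as run_mhc_pre_example


@tilelang.testing.requires_cuda
def test_mhc_post():
    run_mhc_post_example()


@tilelang.testing.requires_cuda
def test_mhc_pre():
    run_mhc_pre_example()


if __name__ == "__main__":
    tilelang.testing.main()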

Sequence Diagram(s)

sequenceDiagram
  participant TestRunner as Test Runner
  participant PyAPI as Python API (mhc_pre/mhc_post)
  participant JIT as TileLang JIT
  participant GPU as GPU Kernel
  participant Ref as Reference Impl

  TestRunner->>PyAPI: call test()/main()
  PyAPI->>PyAPI: generate_test_data()
  PyAPI->>JIT: compile mhc_*_tilelang(...) (JIT build)
  JIT->>GPU: launch compiled kernel with tensors
  GPU-->>PyAPI: write output tensor
  PyAPI->>Ref: run mhc_*_ref(...) on same inputs
  PyAPI->>TestRunner: compare outputs (assert close)
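In concrete terms, each example's test follows the flow above; a rough illustration for the post operator, where the data keys and tolerances are placeholders rather than values from the PR:

import torch

# Assumes the functions from example_mhc_post are in scope; the dict keys and
# tolerances below are illustrative placeholders.
data = generate_test_data(n=4096, h=1280)                             # deterministic inputs
out_gpu = mhc_post(data["x"], data["residual"],
                   data["post_layer_mix"], data["comb_res_mix"])      # fused TileLang kernel
out_ref = mhc_post_ref(data["x"], data["residual"],
                       data["post_layer_mix"], data["comb_res_mix"])  # PyTorch reference
torch.testing.assert_close(out_gpu, out_ref, rtol=1e-2, atol=1e-2)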

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Suggested labels

enhancement

Suggested reviewers

  • LeiWang1999

Poem

🐰 I tiled and threaded through the night,
Shared buffers hummed beneath the light,
Pre and post danced, tests in queue,
Reference checks said "works, it's true!"
Hoppity joy — a CUDA delight. 🥕

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 31.25%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check — ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title check — ✅ Passed: The title '[Example] Add example for mHC inference kernels' directly and clearly summarizes the main change: adding example code for mHC inference kernels. It is specific, concise, and accurately reflects the primary purpose of the changeset.




Contributor

Copilot AI left a comment


Pull request overview

This PR adds example implementations for mHC (Multi-Head Context) inference kernels, referencing the arXiv paper 2512.24880. The implementation includes fused CUDA kernels written in TileLang for both the "pre" and "post" blocks of the mHC architecture.

Changes:

  • Adds example_mhc_pre.py with deeply fused kernels for mHC pre block including RMS normalization, Sinkhorn normalization, and fused GEMM operations
  • Adds example_mhc_post.py with fused kernels for mHC post block operations
  • Adds test_example_mhc.py with test cases for both implementations

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

  • examples/deepseek_mhc/example_mhc_pre.py — Implements mHC pre-block kernels with fused GEMM+sqrsum and normalization operations; includes reference implementations and test utilities.
  • examples/deepseek_mhc/example_mhc_post.py — Implements mHC post-block kernels with matrix operations; includes reference implementation and test utilities.
  • examples/deepseek_mhc/test_example_mhc.py — Provides pytest test wrappers for both mHC implementations with CUDA requirements.


    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:

Copilot AI Jan 16, 2026


The mhc_post function lacks a docstring explaining its purpose, parameters, and return value. The similar mhc_pre function in example_mhc_pre.py has comprehensive documentation (lines 210-229). Add a docstring following the same format to maintain consistency.

Suggested change
) -> torch.Tensor:
) -> torch.Tensor:
    """
    Apply the MHC post operator using the TileLang implementation.

    This function wraps :func:`mhc_post_tilelang` to compute the mixed
    combination of the residual stream and the post-layer activations.
    Conceptually, it implements the same computation as :func:`mhc_post_ref`
    but executes it via a fused GPU kernel.

    Args:
        x: Input activations of shape ``(n, h)`` and dtype ``bfloat16``.
        residual: Residual stream tensor of shape ``(n, hc, h)`` and
            dtype ``bfloat16``.
        post_layer_mix: Per-head mixing weights of shape ``(n, hc, 1)``
            and dtype ``float32``.
        comb_res_mix: Combination matrix for residual heads of shape
            ``(n, hc, hc)`` and dtype ``float32``.

    Returns:
        torch.Tensor: The mixed output tensor of shape ``(n, hc, h)``
            and dtype ``bfloat16``, matching the shape of ``residual``.
    """

Comment on lines 24 to 30
def _mhc_post(
    a: T.Tensor((n, hc, hc), T.float32),
    b: T.Tensor((n, hc, h), T.bfloat16),
    c: T.Tensor((n, hc), T.float32),
    d: T.Tensor((n, h), T.bfloat16),
    x: T.Tensor((n, hc, h), T.bfloat16),
) -> None:

Copilot AI Jan 16, 2026


The parameter names a, b, c, d, and x are not descriptive. Based on the reference implementation at lines 79–86, these represent comb_res_mix, residual, post_layer_mix, x (input), and output, respectively. Using descriptive names would improve code readability and maintainability.
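
A hedged sketch of the renaming that mapping implies (the merged kernel keeps a/b/c/d/x; this only spells out the reviewer's suggestion):

# Illustrative renaming only, following the comment's mapping.
def _mhc_post(
    comb_res_mix: T.Tensor((n, hc, hc), T.float32),   # was `a`
    residual: T.Tensor((n, hc, h), T.bfloat16),       # was `b`
    post_layer_mix: T.Tensor((n, hc), T.float32),     # was `c`
    x_in: T.Tensor((n, h), T.bfloat16),               # was `d` (input activations)
    out: T.Tensor((n, hc, h), T.bfloat16),            # was `x` (output buffer)
) -> None:
    ...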

for n1 in [512, 1024, 2048, 8192]:
    for hidden_size in [1280, 2560, 4096]:
        for hc_mult in [4]:
            test(n=n1, hidden_size=hidden_size, hc_mult=hc_mult)

Copilot AI Jan 16, 2026


The file is missing a trailing newline at the end. Most Python style guides recommend files end with a newline character.

def main():
    for n in [4096]:
        for h in [1280, 2560, 7168]:
            test(n=n, h=h)

Copilot AI Jan 16, 2026


The file is missing a trailing newline at the end. Most Python style guides recommend files end with a newline character.

@LeiWang1999
Member

LGTM!

@LeiWang1999 merged commit 60050f2 into tile-ai:main Jan 16, 2026
5 of 6 checks passed
Contributor

@coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@examples/deepseek_mhc/example_mhc_pre.py`:
- Around line 159-198: The kernel currently writes full token_block slices
unconditionally (inside T.Kernel with loop var px), causing OOB on the final
partial block; update the code that writes sqrsum and out to guard the tail:
compute tail_len = num_tokens - px * token_block (clamped to [0, token_block])
and use that to limit writes/parallel loops (e.g., replace full token_block
Parallel loops that write sqrsum[px * token_block + i] and out[px * token_block
+ i, j] with loops or conditionals that only iterate/assign when i < tail_len),
and similarly ensure reduce_sum/sqrsum_l and copying into out_frag only
contribute for i < tail_len; use the existing px, token_block, sqrsum_part,
sqrsum, out_frag, and out symbols to locate and restrict the tail writes.
🧹 Nitpick comments (3)
examples/deepseek_mhc/example_mhc_post.py (2)

56-64: Add shape/dtype guards to prevent silent mismatches.
Line 56–63 currently trusts caller shapes; a bad post_layer_mix or comb_res_mix shape will quietly produce wrong results. Consider cheap asserts like in mhc_pre.

♻️ Suggested guardrails
 def mhc_post(
     x: torch.Tensor,
     residual: torch.Tensor,
     post_layer_mix: torch.Tensor,
     comb_res_mix: torch.Tensor,
 ) -> torch.Tensor:
+    assert x.dtype == torch.bfloat16
+    assert residual.dtype == torch.bfloat16
+    assert post_layer_mix.dtype == torch.float32
+    assert comb_res_mix.dtype == torch.float32
+    assert x.shape[0] == residual.shape[0] == post_layer_mix.shape[0] == comb_res_mix.shape[0]
+    assert post_layer_mix.shape[-1] == 1
+    assert comb_res_mix.shape[-1] == comb_res_mix.shape[-2] == residual.shape[-2]
+    assert x.shape[-1] == residual.shape[-1]
     out = torch.empty_like(residual)
     mhc_post_tilelang(comb_res_mix, residual, post_layer_mix.squeeze(-1), x, out, residual.shape[-2], residual.shape[-1])
     return out

99-110: Optional: guard CUDA-only test path.
generate_test_data defaults to CUDA; running main() on a CPU-only machine will throw. A small guard keeps the example friendlier.

♻️ Optional guard
 def main():
+    if not torch.cuda.is_available():
+        print("CUDA not available; skipping mhc_post example.")
+        return
     for n in [4096]:
         for h in [1280, 2560, 7168]:
             test(n=n, h=h)
examples/deepseek_mhc/example_mhc_pre.py (1)

345-388: Consider parameterizing the device instead of hard-coding CUDA.
Line 360 fixes device = "cuda". Passing a device argument (default "cuda") makes the example more reusable.

♻️ Optional parameterization
 def generate_test_data(
     n: int,
     hc_mult: int,
     hidden_size: int,
     rms_eps: float = 1e-6,
     hc_pre_eps: float = 1e-6,
     hc_sinkhorn_eps: float = 1e-6,
     hc_post_mult_value: float = 1.0,
     sinkhorn_repeat: int = 10,
+    device: str = "cuda",
 ) -> dict[str, torch.Tensor | float]:
@@
-    device = "cuda"
-
     residual = (
         torch.randn((n, hc_mult, hidden_size), dtype=torch.float, device=device)
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d7f524e and 7b1f00b.

📒 Files selected for processing (2)
  • examples/deepseek_mhc/example_mhc_post.py
  • examples/deepseek_mhc/example_mhc_pre.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
📚 Learning: 2026-01-06T05:20:51.649Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.

Applied to files:

  • examples/deepseek_mhc/example_mhc_post.py
🧬 Code graph analysis (1)
examples/deepseek_mhc/example_mhc_post.py (3)
tilelang/language/symbolics.py (1)
  • dynamic (12-29)
tilelang/language/allocate.py (2)
  • alloc_shared (39-54)
  • alloc_fragment (71-82)
tilelang/language/loop.py (1)
  • Parallel (13-72)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (8)
examples/deepseek_mhc/example_mhc_post.py (3)

16-53: Kernel tiling & staging look sound.
The gcd-based block sizing (Line 21) keeps h divisible by h_blk, and the pipelined block loop (Line 41–53) cleanly stages shared/local memory.
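
For context, a gcd-based block width can be chosen along these lines; the cap of 256 is an assumed placeholder, not the constant used in the example:

import math

# Largest block width that divides h while staying under a cap, so the block
# loop over the hidden dimension never needs a partial tail.
h_blk = math.gcd(h, 256)
assert h % h_blk == 0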


67-75: Reference path is clear and matches the intended broadcast.
The bmm + broadcasted mix (Line 73–74) mirrors the kernel math well.
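
For readers without the diff open, a shape-level reconstruction of that reference math (an assumption pieced together from the docstring suggestion and the tensor shapes, not copied from mhc_post_ref):

# Assumed reconstruction; casting details in the actual reference may differ.
mixed_residual = torch.bmm(comb_res_mix, residual.float())        # (n, hc, h)
out = mixed_residual + post_layer_mix * x.float().unsqueeze(1)    # broadcast over hc
out = out.to(torch.bfloat16)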


77-96: Deterministic test data setup looks good.
Line 83–90 seeds RNG and builds consistent shapes/dtypes for the test path.

examples/deepseek_mhc/example_mhc_pre.py (5)

15-137: Fused kernel structure is consistent.
The staged mix computation and split of post/comb vs pre branches reads cleanly and matches the reference flow.


201-296: Wrapper validation and reshaping look solid.
The dtype/shape checks (Line 234–249) and output reshaping (Line 292–294) are clear and robust.


299-305: Reference Sinkhorn helper is straightforward.
Loop structure and eps handling are easy to follow.


308-342: Reference path mirrors fused computation well.
The mix construction and normalization steps align with the fused kernel flow.


391-416: Test harness wiring looks good.
The test loops cover multiple n/hidden_size combos and reuse shared data generation.


Comment on lines +159 to +198
with T.Kernel(T.ceildiv(num_tokens, token_block)) as px:
    out_frag = T.alloc_fragment((token_block, 32), T.float32)
    sqrsum_part = T.alloc_fragment((token_block, 4), T.float32)
    T.clear(out_frag)
    T.clear(sqrsum_part)
    for pz in T.Pipelined(hc_hidden_size // hidden_block, num_stages=2):
        x_smem_16 = T.alloc_shared((token_block, hidden_block), T.bfloat16)
        fn_smem = T.alloc_shared((32, hidden_block), T.float32)

        T.annotate_layout({x_smem_16: tilelang.layout.make_swizzled_layout(x_smem_16)})

        T.copy(x[px * token_block, pz * hidden_block], x_smem_16)
        T.copy(fn[0, pz * hidden_block], fn_smem)

        x_frag_16 = T.alloc_fragment((token_block, hidden_block), T.bfloat16)
        T.copy(x_smem_16, x_frag_16)
        x_frag = T.alloc_fragment((token_block, hidden_block), T.float32)
        T.copy(x_frag_16, x_frag)

        for jj in T.serial(hidden_block // 4):
            for i, j in T.Parallel(token_block, 4):
                sqrsum_part[i, j] += x_frag[i, jj * 4 + j] * x_frag[i, jj * 4 + j]

        # should be TF32 gemm
        T.gemm(
            x_frag,
            fn_smem,
            out_frag,
            transpose_A=False,
            transpose_B=True,
            wg_wait=0,
            clear_accum=False,
        )
    sqrsum_l = T.alloc_fragment(token_block, T.float32)
    T.reduce_sum(sqrsum_part, sqrsum_l)
    for i in T.Parallel(token_block):
        sqrsum[px * token_block + i] = sqrsum_l[i]
    for i, j in T.Parallel(token_block, 32):
        if j < hc_mult3:
            out[px * token_block + i, j] = out_frag[i, j]
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find . -name "example_mhc_pre.py" -type f

Repository: tile-ai/tilelang

Length of output: 102


🏁 Script executed:

fd "example_mhc_pre.py"

Repository: tile-ai/tilelang

Length of output: 100


🏁 Script executed:

find . -path "*deepseek_mhc*" -type f

Repository: tile-ai/tilelang

Length of output: 190


🏁 Script executed:

cat -n examples/deepseek_mhc/example_mhc_pre.py

Repository: tile-ai/tilelang

Length of output: 17925


🏁 Script executed:

# Check for any recent changes or guards added to this kernel
git log --oneline -20 -- examples/deepseek_mhc/example_mhc_pre.py

Repository: tile-ai/tilelang

Length of output: 91


🏁 Script executed:

# Check if token_block is ever validated elsewhere in the codebase
rg "token_block" examples/deepseek_mhc/ -A 2 -B 2

Repository: tile-ai/tilelang

Length of output: 3439


Guard tail blocks to avoid OOB when num_tokens % token_block != 0.

The kernel writes token_block elements unconditionally in lines 194–195 and 196–198. With T.ceildiv(num_tokens, token_block), the last block can write past the end unless num_tokens is a multiple of token_block (32). For example, with num_tokens=33 and token_block=32, the last iteration writes to indices 32–63, but output tensors only have 33 elements.

🐛 Minimal fix: assert the precondition in the wrapper
     residual_flat = residual.view(-1, hc_mult, hidden_size)
     num_tokens = residual_flat.shape[0]
+    assert num_tokens % 32 == 0, (
+        "mhc_pre_gemm_sqrsum_tilelang requires num_tokens to be a multiple of token_block=32"
+    )
     fn_flat = fn
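
Alternatively, the in-kernel tail guard described in the agent prompt below could look roughly like this; it is a sketch only, assuming T.min is available and reusing the names from the excerpt above:

# Sketch of guarding the tail block inside the kernel; not the merged fix.
tail_len = T.min(token_block, num_tokens - px * token_block)
for i in T.Parallel(token_block):
    if i < tail_len:
        sqrsum[px * token_block + i] = sqrsum_l[i]
for i, j in T.Parallel(token_block, 32):
    if i < tail_len and j < hc_mult3:
        out[px * token_block + i, j] = out_frag[i, j]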
🤖 Prompt for AI Agents
In `@examples/deepseek_mhc/example_mhc_pre.py` around lines 159 - 198, The kernel
currently writes full token_block slices unconditionally (inside T.Kernel with
loop var px), causing OOB on the final partial block; update the code that
writes sqrsum and out to guard the tail: compute tail_len = num_tokens - px *
token_block (clamped to [0, token_block]) and use that to limit writes/parallel
loops (e.g., replace full token_block Parallel loops that write sqrsum[px *
token_block + i] and out[px * token_block + i, j] with loops or conditionals
that only iterate/assign when i < tail_len), and similarly ensure
reduce_sum/sqrsum_l and copying into out_frag only contribute for i < tail_len;
use the existing px, token_block, sqrsum_part, sqrsum, out_frag, and out symbols
to locate and restrict the tail writes.

T.reduce_max(cm, row_max, dim=1)
for j, k in T.Parallel(hc_mult, hc_mult):
    cm[j, k] = T.exp(cm[j, k] - row_max[j])
T.reduce_sum(cm, row_sum, dim=1)


Maybe lines 85–103 can be merged into a sinkhorn_repeat-times loop?
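
A rough sketch of what that merged loop might look like, assuming the unrolled block alternates row and column normalization (col_sum and the eps handling below are assumptions, not the code on lines 85–103):

# Hypothetical sinkhorn_repeat-times loop; buffer names and eps are illustrative.
for _ in T.serial(sinkhorn_repeat):
    T.reduce_sum(cm, row_sum, dim=1)
    for j, k in T.Parallel(hc_mult, hc_mult):
        cm[j, k] = cm[j, k] / (row_sum[j] + hc_sinkhorn_eps)
    T.reduce_sum(cm, col_sum, dim=0)
    for j, k in T.Parallel(hc_mult, hc_mult):
        cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps)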

