[Example] Add example for mHC inference kernels. #1684
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Adds two new example modules implementing TileLang-based mHC pre and post operators with Python wrappers, reference implementations, deterministic test-data generators, and CUDA-gated tests, plus a test orchestration module invoking both examples.
Sequence Diagram(s)

sequenceDiagram
participant TestRunner as Test Runner
participant PyAPI as Python API (mhc_pre/mhc_post)
participant JIT as TileLang JIT
participant GPU as GPU Kernel
participant Ref as Reference Impl
TestRunner->>PyAPI: call test()/main()
PyAPI->>PyAPI: generate_test_data()
PyAPI->>JIT: compile mhc_*_tilelang(...) (JIT build)
JIT->>GPU: launch compiled kernel with tensors
GPU-->>PyAPI: write output tensor
PyAPI->>Ref: run mhc_*_ref(...) on same inputs
PyAPI->>TestRunner: compare outputs (assert close)
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks (✅ 2 passed | ❌ 1 failed)
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Pull request overview
This PR adds example implementations for mHC (Multi-Head Context) inference kernels, referencing the arXiv paper 2512.24880. The implementation includes fused CUDA kernels written in TileLang for both the "pre" and "post" blocks of the mHC architecture.
Changes:
- Adds `example_mhc_pre.py` with deeply fused kernels for the mHC pre block, including RMS normalization, Sinkhorn normalization, and fused GEMM operations
- Adds `example_mhc_post.py` with fused kernels for the mHC post block operations
- Adds `test_example_mhc.py` with test cases for both implementations
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| examples/deepseek_mhc/example_mhc_pre.py | Implements mHC pre-block kernels with fused GEMM+sqrsum and normalization operations, includes reference implementations and test utilities |
| examples/deepseek_mhc/example_mhc_post.py | Implements mHC post-block kernels with matrix operations, includes reference implementation and test utilities |
| examples/deepseek_mhc/test_example_mhc.py | Provides pytest test wrappers for both mHC implementations with CUDA requirements |
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:
Copilot AI (Jan 16, 2026)
The mhc_post function lacks a docstring explaining its purpose, parameters, and return value. The similar mhc_pre function in example_mhc_pre.py has comprehensive documentation (lines 210-229). Add a docstring following the same format to maintain consistency.
) -> torch.Tensor:
    """
    Apply the MHC post operator using the TileLang implementation.

    This function wraps :func:`mhc_post_tilelang` to compute the mixed
    combination of the residual stream and the post-layer activations.
    Conceptually, it implements the same computation as :func:`mhc_post_ref`
    but executes it via a fused GPU kernel.

    Args:
        x: Input activations of shape ``(n, h)`` and dtype ``bfloat16``.
        residual: Residual stream tensor of shape ``(n, hc, h)`` and
            dtype ``bfloat16``.
        post_layer_mix: Per-head mixing weights of shape ``(n, hc, 1)``
            and dtype ``float32``.
        comb_res_mix: Combination matrix for residual heads of shape
            ``(n, hc, hc)`` and dtype ``float32``.

    Returns:
        torch.Tensor: The mixed output tensor of shape ``(n, hc, h)``
        and dtype ``bfloat16``, matching the shape of ``residual``.
    """
def _mhc_post(
    a: T.Tensor((n, hc, hc), T.float32),
    b: T.Tensor((n, hc, h), T.bfloat16),
    c: T.Tensor((n, hc), T.float32),
    d: T.Tensor((n, h), T.bfloat16),
    x: T.Tensor((n, hc, h), T.bfloat16),
) -> None:
Copilot AI (Jan 16, 2026)
The parameter names a, b, c, d, and x are not descriptive. Based on the reference implementation at lines 79-86, these represent comb_res_mix, residual, post_layer_mix, x (the input), and the output, respectively. Using descriptive names would improve code readability and maintainability.
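A possible renaming, sketched here for illustration only (parameter order and shapes are taken from the excerpt above; treating the last tensor as the output buffer is an assumption based on this mapping):

def _mhc_post(
    comb_res_mix: T.Tensor((n, hc, hc), T.float32),  # was `a`
    residual: T.Tensor((n, hc, h), T.bfloat16),      # was `b`
    post_layer_mix: T.Tensor((n, hc), T.float32),    # was `c`
    x: T.Tensor((n, h), T.bfloat16),                 # was `d`
    out: T.Tensor((n, hc, h), T.bfloat16),           # was `x` (output)
) -> None:
    ...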
for n1 in [512, 1024, 2048, 8192]:
    for hidden_size in [1280, 2560, 4096]:
        for hc_mult in [4]:
            test(n=n1, hidden_size=hidden_size, hc_mult=hc_mult)
Copilot AI (Jan 16, 2026)
The file is missing a trailing newline at the end. Most Python style guides recommend files end with a newline character.
def main():
    for n in [4096]:
        for h in [1280, 2560, 7168]:
            test(n=n, h=h)
Copilot AI (Jan 16, 2026)
The file is missing a trailing newline at the end. Most Python style guides recommend files end with a newline character.
LGTM!
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@examples/deepseek_mhc/example_mhc_pre.py`:
- Around line 159-198: The kernel currently writes full token_block slices
unconditionally (inside T.Kernel with loop var px), causing OOB on the final
partial block; update the code that writes sqrsum and out to guard the tail:
compute tail_len = num_tokens - px * token_block (clamped to [0, token_block])
and use that to limit writes/parallel loops (e.g., replace full token_block
Parallel loops that write sqrsum[px * token_block + i] and out[px * token_block
+ i, j] with loops or conditionals that only iterate/assign when i < tail_len),
and similarly ensure reduce_sum/sqrsum_l and copying into out_frag only
contribute for i < tail_len; use the existing px, token_block, sqrsum_part,
sqrsum, out_frag, and out symbols to locate and restrict the tail writes.
🧹 Nitpick comments (3)
examples/deepseek_mhc/example_mhc_post.py (2)
56-64: Add shape/dtype guards to prevent silent mismatches.
Lines 56–63 currently trust caller shapes; a bad `post_layer_mix` or `comb_res_mix` shape will quietly produce wrong results. Consider cheap asserts like in `mhc_pre`.
♻️ Suggested guardrails
 def mhc_post(
     x: torch.Tensor,
     residual: torch.Tensor,
     post_layer_mix: torch.Tensor,
     comb_res_mix: torch.Tensor,
 ) -> torch.Tensor:
+    assert x.dtype == torch.bfloat16
+    assert residual.dtype == torch.bfloat16
+    assert post_layer_mix.dtype == torch.float32
+    assert comb_res_mix.dtype == torch.float32
+    assert x.shape[0] == residual.shape[0] == post_layer_mix.shape[0] == comb_res_mix.shape[0]
+    assert post_layer_mix.shape[-1] == 1
+    assert comb_res_mix.shape[-1] == comb_res_mix.shape[-2] == residual.shape[-2]
+    assert x.shape[-1] == residual.shape[-1]
     out = torch.empty_like(residual)
     mhc_post_tilelang(comb_res_mix, residual, post_layer_mix.squeeze(-1), x, out,
                       residual.shape[-2], residual.shape[-1])
     return out
99-110: Optional: guard CUDA-only test path.
`generate_test_data` defaults to CUDA; running `main()` on a CPU-only machine will throw. A small guard keeps the example friendlier.
♻️ Optional guard
 def main():
+    if not torch.cuda.is_available():
+        print("CUDA not available; skipping mhc_post example.")
+        return
     for n in [4096]:
         for h in [1280, 2560, 7168]:
             test(n=n, h=h)

examples/deepseek_mhc/example_mhc_pre.py (1)
345-388: Consider parameterizing the device instead of hard-coding CUDA.
Line 360 fixes `device = "cuda"`. Passing a `device` argument (default `"cuda"`) makes the example more reusable.
♻️ Optional parameterization
 def generate_test_data(
     n: int,
     hc_mult: int,
     hidden_size: int,
     rms_eps: float = 1e-6,
     hc_pre_eps: float = 1e-6,
     hc_sinkhorn_eps: float = 1e-6,
     hc_post_mult_value: float = 1.0,
     sinkhorn_repeat: int = 10,
+    device: str = "cuda",
 ) -> dict[str, torch.Tensor | float]:
@@
-    device = "cuda"
-
     residual = (
         torch.randn((n, hc_mult, hidden_size), dtype=torch.float, device=device)
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
examples/deepseek_mhc/example_mhc_post.py
examples/deepseek_mhc/example_mhc_pre.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
📚 Learning: 2026-01-06T05:20:51.649Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
Applied to files:
examples/deepseek_mhc/example_mhc_post.py
🧬 Code graph analysis (1)
examples/deepseek_mhc/example_mhc_post.py (3)
tilelang/language/symbolics.py (1)
  dynamic (12-29)
tilelang/language/allocate.py (2)
  alloc_shared (39-54)
  alloc_fragment (71-82)
tilelang/language/loop.py (1)
  Parallel (13-72)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (8)
examples/deepseek_mhc/example_mhc_post.py (3)
16-53: Kernel tiling & staging look sound.
The gcd-based block sizing (line 21) keeps `h` divisible by `h_blk`, and the pipelined block loop (lines 41–53) cleanly stages shared/local memory.
67-75: Reference path is clear and matches the intended broadcast.
The `bmm` + broadcasted mix (lines 73–74) mirrors the kernel math well.
77-96: Deterministic test data setup looks good.
Lines 83–90 seed the RNG and build consistent shapes/dtypes for the test path.
examples/deepseek_mhc/example_mhc_pre.py (5)
15-137: Fused kernel structure is consistent.
The staged mix computation and split of post/comb vs pre branches reads cleanly and matches the reference flow.
201-296: Wrapper validation and reshaping look solid.
The dtype/shape checks (Line 234–249) and output reshaping (Line 292–294) are clear and robust.
299-305: Reference Sinkhorn helper is straightforward.
Loop structure and eps handling are easy to follow.
308-342: Reference path mirrors fused computation well.
The mix construction and normalization steps align with the fused kernel flow.
391-416: Test harness wiring looks good.
The test loops cover multiple `n`/`hidden_size` combos and reuse shared data generation.
with T.Kernel(T.ceildiv(num_tokens, token_block)) as px:
    out_frag = T.alloc_fragment((token_block, 32), T.float32)
    sqrsum_part = T.alloc_fragment((token_block, 4), T.float32)
    T.clear(out_frag)
    T.clear(sqrsum_part)
    for pz in T.Pipelined(hc_hidden_size // hidden_block, num_stages=2):
        x_smem_16 = T.alloc_shared((token_block, hidden_block), T.bfloat16)
        fn_smem = T.alloc_shared((32, hidden_block), T.float32)

        T.annotate_layout({x_smem_16: tilelang.layout.make_swizzled_layout(x_smem_16)})

        T.copy(x[px * token_block, pz * hidden_block], x_smem_16)
        T.copy(fn[0, pz * hidden_block], fn_smem)

        x_frag_16 = T.alloc_fragment((token_block, hidden_block), T.bfloat16)
        T.copy(x_smem_16, x_frag_16)
        x_frag = T.alloc_fragment((token_block, hidden_block), T.float32)
        T.copy(x_frag_16, x_frag)

        for jj in T.serial(hidden_block // 4):
            for i, j in T.Parallel(token_block, 4):
                sqrsum_part[i, j] += x_frag[i, jj * 4 + j] * x_frag[i, jj * 4 + j]

        # should be TF32 gemm
        T.gemm(
            x_frag,
            fn_smem,
            out_frag,
            transpose_A=False,
            transpose_B=True,
            wg_wait=0,
            clear_accum=False,
        )
    sqrsum_l = T.alloc_fragment(token_block, T.float32)
    T.reduce_sum(sqrsum_part, sqrsum_l)
    for i in T.Parallel(token_block):
        sqrsum[px * token_block + i] = sqrsum_l[i]
    for i, j in T.Parallel(token_block, 32):
        if j < hc_mult3:
            out[px * token_block + i, j] = out_frag[i, j]
🧩 Analysis chain
🏁 Script executed:
find . -name "example_mhc_pre.py" -type f
Repository: tile-ai/tilelang
Length of output: 102
🏁 Script executed:
fd "example_mhc_pre.py"Repository: tile-ai/tilelang
Length of output: 100
🏁 Script executed:
find . -path "*deepseek_mhc*" -type f
Repository: tile-ai/tilelang
Length of output: 190
🏁 Script executed:
cat -n examples/deepseek_mhc/example_mhc_pre.py
Repository: tile-ai/tilelang
Length of output: 17925
🏁 Script executed:
# Check for any recent changes or guards added to this kernel
git log --oneline -20 -- examples/deepseek_mhc/example_mhc_pre.py
Repository: tile-ai/tilelang
Length of output: 91
🏁 Script executed:
# Check if token_block is ever validated elsewhere in the codebase
rg "token_block" examples/deepseek_mhc/ -A 2 -B 2Repository: tile-ai/tilelang
Length of output: 3439
Guard tail blocks to avoid OOB when num_tokens % token_block != 0.
The kernel writes token_block elements unconditionally in lines 194–195 and 196–198. With T.ceildiv(num_tokens, token_block), the last block can write past the end unless num_tokens is a multiple of token_block (32). For example, with num_tokens=33 and token_block=32, the last iteration writes to indices 32–63, but output tensors only have 33 elements.
🐛 Minimal fix: assert the precondition in the wrapper
residual_flat = residual.view(-1, hc_mult, hidden_size)
num_tokens = residual_flat.shape[0]
+ assert num_tokens % 32 == 0, (
+ "mhc_pre_gemm_sqrsum_tilelang requires num_tokens to be a multiple of token_block=32"
+ )
fn_flat = fn🤖 Prompt for AI Agents
In `@examples/deepseek_mhc/example_mhc_pre.py` around lines 159 - 198, The kernel
currently writes full token_block slices unconditionally (inside T.Kernel with
loop var px), causing OOB on the final partial block; update the code that
writes sqrsum and out to guard the tail: compute tail_len = num_tokens - px *
token_block (clamped to [0, token_block]) and use that to limit writes/parallel
loops (e.g., replace full token_block Parallel loops that write sqrsum[px *
token_block + i] and out[px * token_block + i, j] with loops or conditionals
that only iterate/assign when i < tail_len), and similarly ensure
reduce_sum/sqrsum_l and copying into out_frag only contribute for i < tail_len;
use the existing px, token_block, sqrsum_part, sqrsum, out_frag, and out symbols
to locate and restrict the tail writes.
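For illustration only, the in-kernel guard the prompt describes might look roughly like this, reusing the existing `px`, `token_block`, `num_tokens`, `sqrsum`, `sqrsum_l`, `out_frag`, and `out` symbols from the excerpt above; whether the predicate is best kept per element or hoisted into a `tail_len` bound would need checking against TileLang's lowering:

# guard the final partial block: only write lanes whose token index is in range
for i in T.Parallel(token_block):
    if px * token_block + i < num_tokens:
        sqrsum[px * token_block + i] = sqrsum_l[i]
for i, j in T.Parallel(token_block, 32):
    if j < hc_mult3 and px * token_block + i < num_tokens:
        out[px * token_block + i, j] = out_frag[i, j]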
T.reduce_max(cm, row_max, dim=1)
for j, k in T.Parallel(hc_mult, hc_mult):
    cm[j, k] = T.exp(cm[j, k] - row_max[j])
T.reduce_sum(cm, row_sum, dim=1)
Maybe lines 85-103 can be merged into a `sinkhorn_repeat`-times loop?
See the arXiv paper.
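For illustration, the repeated normalization could be hoisted into a serial loop along these lines. This is a sketch only: it assumes `sinkhorn_repeat` is available to the kernel as a compile-time constant, that `row_sum`/`col_sum` fragments of shape `(hc_mult,)` are allocated alongside `row_max`, and that `hc_sinkhorn_eps` is the epsilon passed into the kernel; the exact body depends on the surrounding lines 85-103.

# one exp stabilization up front, then alternate row/column normalization
T.reduce_max(cm, row_max, dim=1)
for j, k in T.Parallel(hc_mult, hc_mult):
    cm[j, k] = T.exp(cm[j, k] - row_max[j])
for _ in T.serial(sinkhorn_repeat):
    T.reduce_sum(cm, row_sum, dim=1)
    for j, k in T.Parallel(hc_mult, hc_mult):
        cm[j, k] = cm[j, k] / (row_sum[j] + hc_sinkhorn_eps)
    T.reduce_sum(cm, col_sum, dim=0)
    for j, k in T.Parallel(hc_mult, hc_mult):
        cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps)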