Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions examples/deepseek_mhc/example_mhc_post.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import math

import torch

import tilelang
import tilelang.language as T


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
    },
)
def mhc_post_tilelang(a, b, c, d, x, hc: int, hidden: int, n_thr: int = 128, h_blk: int = 1024) -> tilelang.JITKernel:
    """Fused MHC post-mix kernel.

    For each token ``i_n`` this computes, tile by tile over the hidden dim:
    ``x[i_n] = outer(c[i_n], d[i_n]) + a[i_n]^T @ b[i_n]``
    which matches :func:`mhc_post_ref` with ``a = comb_res_mix``,
    ``b = residual``, ``c = post_layer_mix.squeeze(-1)`` and ``d = x``
    (see the wrapper :func:`mhc_post` for the argument mapping).

    Args:
        a: ``(n, hc, hc)`` float32 head-combination matrices.
        b: ``(n, hc, h)`` bfloat16 residual stream.
        c: ``(n, hc)`` float32 per-head mixing weights.
        d: ``(n, h)`` bfloat16 post-layer activations.
        x: ``(n, hc, h)`` bfloat16 output buffer, written in place.
        hc: Number of residual heads.
        hidden: Hidden size ``h``.
        n_thr: Threads per kernel block.
        h_blk: Requested tile width along the hidden dimension (clamped to a
            divisor of ``hidden`` below).

    Returns:
        The compiled TileLang JIT kernel.
    """
    # rename for shorter code
    n = T.dynamic("num_tokens")
    h = hidden

    # Force the tile width to divide the hidden size exactly, so the
    # pipelined loop below needs no tail-tile handling.
    h_blk = math.gcd(hidden, h_blk)
    a: T.Tensor((n, hc, hc), T.float32)
    b: T.Tensor((n, hc, h), T.bfloat16)
    c: T.Tensor((n, hc), T.float32)
    d: T.Tensor((n, h), T.bfloat16)
    x: T.Tensor((n, hc, h), T.bfloat16)
    # One kernel block per token.
    with T.Kernel(n, threads=n_thr) as i_n:
        # Staging buffers for one (hc, h_blk) tile; alloc_shared/alloc_fragment
        # are tilelang's shared-memory and register-fragment helpers.
        x_shared = T.alloc_shared((hc, h_blk), T.bfloat16)
        b_shared = T.alloc_shared((hc, h_blk), T.bfloat16)
        d_shared = T.alloc_shared(h_blk, T.bfloat16)

        x_local = T.alloc_fragment((hc, h_blk), T.float32)
        b_local = T.alloc_fragment((hc, h_blk), T.float32)
        d_local = T.alloc_fragment(h_blk, T.float32)

        # a and c are small per-token operands reused by every tile,
        # so they are loaded once, outside the pipelined loop.
        a_local = T.alloc_fragment((hc, hc), T.float32)
        c_local = T.alloc_fragment(hc, T.float32)
        T.copy(a[i_n, 0, 0], a_local)
        T.copy(c[i_n, 0], c_local)

        # Sweep the hidden dimension in h_blk-wide tiles with 2-stage pipelining.
        for i0_h in T.Pipelined(T.ceildiv(h, h_blk), num_stages=2):
            T.copy(b[i_n, 0, i0_h * h_blk], b_shared)
            T.copy(d[i_n, i0_h * h_blk], d_shared)

            T.copy(b_shared, b_local)
            T.copy(d_shared, d_local)
            for i_hco, i1_h in T.Parallel(hc, h_blk):
                # x[o, j] = c[o] * d[j] + sum_i a[i, o] * b[i, j]
                # Note a is read transposed (a[i, o]), matching
                # comb_res_mix.mT in mhc_post_ref.
                x_local[i_hco, i1_h] = c_local[i_hco] * d_local[i1_h]
                for i_hci in T.serial(hc):
                    x_local[i_hco, i1_h] += a_local[i_hci, i_hco] * b_local[i_hci, i1_h]
            T.copy(x_local, x_shared)

            T.copy(x_shared, x[i_n, 0, i0_h * h_blk])


def mhc_post(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:
    """Apply the MHC post operator using the fused TileLang kernel.

    Computes the same result as :func:`mhc_post_ref`:
    ``out = x.unsqueeze(-2) * post_layer_mix + comb_res_mix.mT @ residual``,
    accumulated in float32 and returned as bfloat16.

    Args:
        x: Post-layer activations of shape ``(n, h)``, dtype ``bfloat16``.
        residual: Residual stream of shape ``(n, hc, h)``, dtype ``bfloat16``.
        post_layer_mix: Per-head mixing weights of shape ``(n, hc, 1)``,
            dtype ``float32``.
        comb_res_mix: Head-combination matrices of shape ``(n, hc, hc)``,
            dtype ``float32``.

    Returns:
        torch.Tensor: Mixed output of shape ``(n, hc, h)``, dtype ``bfloat16``.
    """
    out = torch.empty_like(residual)
    # The kernel writes its result into ``out`` in place; post_layer_mix is
    # squeezed to (n, hc) to match the kernel's ``c`` operand.
    mhc_post_tilelang(comb_res_mix, residual, post_layer_mix.squeeze(-1), x, out, residual.shape[-2], residual.shape[-1])
    return out


def mhc_post_ref(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:
    """Pure-PyTorch reference for the MHC post operator.

    Returns ``post_layer_mix * x.unsqueeze(-2) + comb_res_mix^T @ residual``,
    computed in float32 and cast back to bfloat16.
    """
    mixed_residual = torch.matmul(comb_res_mix.transpose(-2, -1), residual.float())
    scaled_x = post_layer_mix * x.float().unsqueeze(-2)
    return (scaled_x + mixed_residual).to(torch.bfloat16)


def generate_test_data(
    n: int,
    h: int,
    hc_mult: int,
    device: str = "cuda",
) -> dict[str, torch.Tensor]:
    """Generate test data for post operator (fixed seed for reproducibility)."""
    torch.random.manual_seed(42)

    bf16 = torch.bfloat16
    f32 = torch.float32
    return {
        "x": torch.randn((n, h), dtype=bf16, device=device),
        "residual": torch.randn((n, hc_mult, h), dtype=bf16, device=device),
        "post_layer_mix": torch.randn((n, hc_mult, 1), dtype=f32, device=device),
        "comb_res_mix": torch.randn((n, hc_mult, hc_mult), dtype=f32, device=device),
    }


def test(n: int, h: int) -> None:
    """Check the TileLang implementation against the PyTorch reference for one size."""
    print(f"Testing mhc_post with {n=} {h=}")
    data = generate_test_data(n=n, h=h, hc_mult=4)
    actual = mhc_post(**data)
    expected = mhc_post_ref(**data)
    torch.testing.assert_close(actual, expected)


def main():
    """Run the correctness check over a sweep of token counts and hidden sizes."""
    token_counts = [4096]
    hidden_sizes = [1280, 2560, 7168]
    for n in token_counts:
        for h in hidden_sizes:
            test(n=n, h=h)
Copy link

Copilot AI Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The file is missing a trailing newline at the end. Most Python style guides recommend files end with a newline character.

Copilot uses AI. Check for mistakes.


# Run the example sweep only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Loading
Loading