
amd/deepseek_v4 integration 4/N - TilelangAttn 0428#24033

Merged

HaiShaw merged 1 commit into sgl-project:amd/deepseek_v4 from HaiShaw:amd/deepseek_v4_0428_tilelang_attn on Apr 29, 2026
Conversation

1am9trash (Collaborator) commented Apr 29, 2026

Motivation

Update the amd/deepseek_v4 integration branch.

The following PRs have a large set of conflicts, so we use this PR and the upstream amd/deepseek_v4 branch to integrate in parallel:
#23600
#23608

The original flash_mla_with_kvcache_torch launches 101 kernels per call; switching to the tilelang kernel reduces this to 2. The switch is controlled by export SGLANG_HACK_FLASHMLA_BACKEND=tilelang; a profiling sketch to verify the kernel count follows below.
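One way to check the kernel-count claim is to profile a single decode step under each backend. The snippet below is a hedged sketch: run_one_decode_step is a hypothetical stand-in for whatever issues one attention call in your setup; only the environment variable name comes from this PR.

import os
import torch
from torch.profiler import profile, ProfilerActivity

# Flag from this PR; must be set before the backend is selected.
os.environ["SGLANG_HACK_FLASHMLA_BACKEND"] = "tilelang"

def run_one_decode_step():
    # Hypothetical placeholder: issue exactly one attention call here.
    pass

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    run_one_decode_step()
    torch.cuda.synchronize()

# Each row is a distinct CUDA kernel; compare row counts across backends.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=25))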

Modifications

Accuracy Tests

Speed Tests and Profiling

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

gemini-code-assist (Bot) left a comment

Code Review

This pull request replaces the BF16 sparse attention kernel with a more efficient 2-stage FP8 attention kernel system consisting of partial and combine kernels. The new implementation supports in-kernel dequantization and dual KV caches. Additionally, it includes a workaround for a tilelang mutation bug and utility functions for cache reinterpretation. Review feedback focuses on improving code maintainability by replacing magic numbers with constants, using more idiomatic PyTorch APIs for integer limits, and addressing the use of deprecated storage methods.
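For readers unfamiliar with the pattern, the partial/combine split is the flash-decoding style of attention: each partial kernel computes attention over one slice of the KV cache along with running softmax statistics, and the combine kernel merges the slices with a log-sum-exp rescaling. The sketch below illustrates the math in plain PyTorch; it is only an illustration of the general technique, not the PR's tilelang implementation, and it omits FP8 dequantization and the dual KV caches.

import torch

def partial(q, k, v):
    # Attention over one KV slice, keeping running softmax statistics.
    scale = q.shape[-1] ** -0.5
    s = (q @ k.transpose(-1, -2)) * scale      # (heads, q_len, split_len)
    m = s.max(dim=-1, keepdim=True).values     # running max
    p = torch.exp(s - m)
    l = p.sum(dim=-1, keepdim=True)            # running sum of exp
    o = p @ v                                  # unnormalized partial output
    return o, m, l

def combine(parts):
    # Merge per-split results with log-sum-exp rescaling.
    m_all = torch.stack([m for _, m, _ in parts]).max(dim=0).values
    num = sum(o * torch.exp(m - m_all) for o, m, _ in parts)
    den = sum(l * torch.exp(m - m_all) for _, m, l in parts)
    return num / den

# Usage: split the KV cache, run partial per split, then combine once.
# out = combine([partial(q, k_i, v_i) for k_i, v_i in kv_splits])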

Comment on lines +1404 to +1426
def _build_fp8_combined_view(k_cache: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
    """
    Reinterpret a MODEL1_FP8Sparse KV cache as a contiguous uint32 view.
    Input: k_cache (num_blocks, block_size, 1, d_qk) fp8/uint8
    — per-block storage also holds scales + padding past d_qk.
    Output: (num_blocks, block_pad_u32) uint32 covering the full block
    stride. Same storage as the input, no copy.
    """
    k_u8 = k_cache.view(torch.uint8) if k_cache.dtype != torch.uint8 else k_cache
    num_blocks = k_u8.shape[0]
    block_size = k_u8.shape[1]
    # Full per-block stride in uint32 elements (payload + scales + padding).
    block_pad_u32 = k_u8.stride(0) // 4
    # Build a flat uint32 alias over the same untyped storage (no copy).
    storage = k_u8.untyped_storage()
    flat_u32 = torch.empty(0, dtype=torch.uint32, device=k_u8.device).set_(
        storage, 0, (storage.nbytes() // 4,), (1,)
    )
    # Re-stride the alias into one row per cache block.
    k_combined = torch.as_strided(
        flat_u32,
        size=(num_blocks, block_pad_u32),
        stride=(block_pad_u32, 1),
        storage_offset=k_u8.storage_offset() // 4,
    )
    return k_combined, num_blocks, block_size
gemini-code-assist (Bot) commented (severity: medium)

The function _build_fp8_combined_view relies on low-level storage APIs (untyped_storage() plus Tensor.set_()). Note that it is the typed Tensor.storage() API that is deprecated in newer PyTorch versions, with untyped_storage() as its replacement, so the call itself is the supported one; the fragility here is rather in hand-building an alias with set_() and as_strided, which silently depends on the cache's exact memory layout. A comment or an assertion on the expected strides would make this safer.
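A small hedged sanity check can guard the layout assumption. The shapes below are hypothetical, only _build_fp8_combined_view comes from the PR, and a PyTorch build with torch.uint32 support is assumed:

import torch

# Hypothetical cache shape; d_qk and block layout vary by model config.
k_cache = torch.zeros(8, 64, 1, 576, dtype=torch.uint8, device="cuda")

k_combined, num_blocks, block_size = _build_fp8_combined_view(k_cache)

# The view must cover the full per-block stride and share storage (no copy).
assert k_combined.shape == (8, k_cache.stride(0) // 4)
assert k_combined.untyped_storage().data_ptr() == k_cache.untyped_storage().data_ptr()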

-    return max(tilelang.math.next_power_of_2(head_kv), 16)
+    return max(_next_power_of_2(head_kv), 16)

 _TOPK_LEN_SENTINEL_CACHE: dict = {}
 _INT32_MAX = 2**31 - 1
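The diff swaps the tilelang.math helper for a local _next_power_of_2. The PR does not show its body in this excerpt; a plausible minimal implementation (an assumption, not the PR's code) is:

def _next_power_of_2(n: int) -> int:
    # Smallest power of two >= n (e.g. 5 -> 8, 8 -> 8).
    return 1 if n <= 0 else 1 << (n - 1).bit_length()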
gemini-code-assist (Bot) commented (severity: medium)

Using 2**31 - 1 is fine, but torch.iinfo(torch.int32).max is more idiomatic and readable for defining the maximum value of a 32-bit integer.
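The suggested form:

import torch

_INT32_MAX = torch.iinfo(torch.int32).max  # 2147483647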

cur = _TOPK_LEN_SENTINEL_CACHE.get(device)
if cur is None or cur.numel() < batch:
    cur = torch.full(
        (max(batch, 256),), _INT32_MAX, dtype=torch.int32, device=device
    )
gemini-code-assist (Bot) commented (severity: medium)

The use of max(batch, 256) for the sentinel tensor size is a magic number. It should be defined as a named constant at the module level to improve maintainability.
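A hedged sketch of the suggested fix; the names _SENTINEL_MIN_LEN and _get_sentinel are hypothetical, not from the PR:

import torch

_INT32_MAX = torch.iinfo(torch.int32).max
# Hypothetical module-level constant; the name is illustrative only.
_SENTINEL_MIN_LEN = 256  # floor on sentinel length, to amortize reallocation

def _get_sentinel(batch: int, device: torch.device) -> torch.Tensor:
    return torch.full(
        (max(batch, _SENTINEL_MIN_LEN),), _INT32_MAX, dtype=torch.int32, device=device
    )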

HaiShaw merged commit 35634dd into sgl-project:amd/deepseek_v4 on Apr 29, 2026
1 check passed
