amd/deepseek_v4 integration 4/N - TilelangAttn 0428 #24033
Merged
HaiShaw merged 1 commit on Apr 29, 2026
Conversation
Contributor
Code Review
This pull request replaces the BF16 sparse attention kernel with a more efficient 2-stage FP8 attention kernel system consisting of partial and combine kernels. The new implementation supports in-kernel dequantization and dual KV caches. Additionally, it includes a workaround for a tilelang mutation bug and utility functions for cache reinterpretation. Review feedback focuses on improving code maintainability by replacing magic numbers with constants, using more idiomatic PyTorch APIs for integer limits, and addressing the use of deprecated storage methods.
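For intuition on the combine stage, here is a minimal PyTorch sketch of the standard log-sum-exp merge used for split-KV attention. This is an illustration of the general technique, not this PR's kernel code; the shapes and names are assumptions.

```python
import torch

def combine_partial_attention(partial_out: torch.Tensor, partial_lse: torch.Tensor):
    """
    partial_out: (num_splits, batch, heads, d_v) per-split attention outputs,
                 each already softmax-normalized over its own KV slice.
    partial_lse: (num_splits, batch, heads) per-split log-sum-exp of the scores.
    """
    # Global log-sum-exp across splits, computed stably via the running max.
    lse_max = partial_lse.max(dim=0, keepdim=True).values
    lse = lse_max.squeeze(0) + torch.log(torch.exp(partial_lse - lse_max).sum(dim=0))
    # Each split's contribution is reweighted by exp(lse_i - lse_global).
    weights = torch.exp(partial_lse - lse.unsqueeze(0))      # (num_splits, batch, heads)
    out = (weights.unsqueeze(-1) * partial_out).sum(dim=0)   # (batch, heads, d_v)
    return out, lse
```

The partial kernel produces `partial_out` and `partial_lse` per KV split; the combine kernel performs this reduction in a single launch.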
Comment on lines +1404 to +1426
```python
def _build_fp8_combined_view(k_cache: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
    """
    Reinterpret a MODEL1_FP8Sparse KV cache as a contiguous uint32 view.
    Input:  k_cache (num_blocks, block_size, 1, d_qk) fp8/uint8
            -- per-block storage also holds scales + padding past d_qk.
    Output: (num_blocks, block_pad_u32) uint32 covering the full block
            stride. Same storage as the input, no copy.
    """
    k_u8 = k_cache.view(torch.uint8) if k_cache.dtype != torch.uint8 else k_cache
    num_blocks = k_u8.shape[0]
    block_size = k_u8.shape[1]
    block_pad_u32 = k_u8.stride(0) // 4
    storage = k_u8.untyped_storage()
    flat_u32 = torch.empty(0, dtype=torch.uint32, device=k_u8.device).set_(
        storage, 0, (storage.nbytes() // 4,), (1,)
    )
    k_combined = torch.as_strided(
        flat_u32,
        size=(num_blocks, block_pad_u32),
        stride=(block_pad_u32, 1),
        storage_offset=k_u8.storage_offset() // 4,
    )
    return k_combined, num_blocks, block_size
```
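Related to the review note on deprecated storage methods: one hypothetical alternative avoids `Tensor.set_`/`untyped_storage()` by re-striding a uint8 view over the padded per-block byte range and reinterpreting it as uint32. This sketch assumes a recent PyTorch with `torch.uint32`, a per-block byte stride that is a multiple of 4, and a cache starting at storage offset 0; it is not the PR's code.

```python
import torch

def _build_fp8_combined_view_alt(k_cache: torch.Tensor):
    # Hypothetical rewrite of _build_fp8_combined_view, for illustration only.
    k_u8 = k_cache.view(torch.uint8) if k_cache.dtype != torch.uint8 else k_cache
    num_blocks, block_size = k_u8.shape[0], k_u8.shape[1]
    block_pad_bytes = k_u8.stride(0)  # full per-block byte stride incl. scales/padding
    # Cover each block's full padded byte range with a dense 2-D view.
    flat_u8 = torch.as_strided(
        k_u8, size=(num_blocks, block_pad_bytes), stride=(block_pad_bytes, 1)
    )
    # The re-strided view is dense, so dtype reinterpretation is legal
    # whenever block_pad_bytes is a multiple of 4.
    return flat_u8.view(torch.uint32), num_blocks, block_size
```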
Contributor
```diff
-    return max(tilelang.math.next_power_of_2(head_kv), 16)
+    return max(_next_power_of_2(head_kv), 16)

 _TOPK_LEN_SENTINEL_CACHE: dict = {}
 _INT32_MAX = 2**31 - 1
```
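For reference, a minimal sketch of the local helper referenced in this hunk. The name `_next_power_of_2` is taken from the diff; its semantics are assumed to match `tilelang.math.next_power_of_2`, i.e. the smallest power of two greater than or equal to `n`.

```python
def _next_power_of_2(n: int) -> int:
    # Assumed semantics: smallest power of two >= n, returning 1 for n <= 0.
    return 1 if n <= 0 else 1 << (n - 1).bit_length()
```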
Contributor
```python
cur = _TOPK_LEN_SENTINEL_CACHE.get(device)
if cur is None or cur.numel() < batch:
    cur = torch.full(
        (max(batch, 256),), _INT32_MAX, dtype=torch.int32, device=device
    )
```
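Applying the review feedback above, one possible shape for this sentinel cache follows. The `torch.iinfo` call and the named minimum-length constant are suggestions, and the helper name `_get_topk_len_sentinel` is assumed for illustration; none of this is the PR's code as written.

```python
import torch

_INT32_MAX = torch.iinfo(torch.int32).max  # idiomatic replacement for 2**31 - 1
_MIN_SENTINEL_LEN = 256                    # named constant instead of a magic number
_TOPK_LEN_SENTINEL_CACHE: dict = {}

def _get_topk_len_sentinel(batch: int, device: torch.device) -> torch.Tensor:
    # Grow-only per-device cache of int32 sentinel values.
    cur = _TOPK_LEN_SENTINEL_CACHE.get(device)
    if cur is None or cur.numel() < batch:
        cur = torch.full(
            (max(batch, _MIN_SENTINEL_LEN),), _INT32_MAX,
            dtype=torch.int32, device=device,
        )
        _TOPK_LEN_SENTINEL_CACHE[device] = cur
    return cur[:batch]
```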
Motivation
Update the amd/deepseek_v4 integration branch. The following PRs have a large set of conflicts, so we use this PR and the upstream amd/deepseek_v4 branch to integrate in parallel:
#23600
#23608
The original `flash_mla_with_kvcache_torch` launches 101 kernels per call; switching to the tilelang kernel reduces this to 2. The new path is opt-in, controlled by `export SGLANG_HACK_FLASHMLA_BACKEND=tilelang`, as sketched below.
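A minimal sketch of this kind of environment-variable dispatch. The two backend functions are placeholders standing in for the real entry points, not the actual implementation:

```python
import os

def _mla_torch_ref(*args, **kwargs):
    raise NotImplementedError  # placeholder for flash_mla_with_kvcache_torch

def _mla_tilelang(*args, **kwargs):
    raise NotImplementedError  # placeholder for the 2-kernel tilelang path

def pick_mla_backend():
    # Opt in with: export SGLANG_HACK_FLASHMLA_BACKEND=tilelang
    if os.environ.get("SGLANG_HACK_FLASHMLA_BACKEND") == "tilelang":
        return _mla_tilelang
    return _mla_torch_ref
```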
Modifications
Accuracy Tests
Speed Tests and Profiling
Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci