amd/deepseek_v4 integration 5/N - indexer TilelangAttn 0428 #24050
Code Review
This pull request introduces a TileLang-based FP8 paged MQA logits kernel and its corresponding wrapper function. Key changes include the implementation of the fp8_paged_mqa_logits_kernel with shared memory optimization and pipelining, as well as updates to FP8 data type definitions. Feedback highlights a potential out-of-bounds write when max_seq_len is not aligned with the block size, a logic contradiction in the clean_logits parameter's default value and assertion, a redundant variable assignment, and the use of a hardcoded compute unit count which may affect performance across different GPU architectures.
T.reduce_sum(logits, logits_sum, dim=1)
for j2 in T.Parallel(B):
    logits_sum[j2] *= k_s_frag[j2]
T.copy(logits_sum, o[bx, i * B])
Potential out-of-bounds write: T.copy(logits_sum, o[bx, i * B]) writes B (block_size) elements to the output buffer. If max_seq_len (symbolic S) is not a multiple of block_size, the last block write will exceed the tensor bounds. Consider adding an assertion in the wrapper function to ensure max_seq_len is aligned to block_size or implementing a masked copy in the kernel.
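The two mitigations suggested in this comment can be sketched in plain Python, readable without a GPU. The helper names `round_up`, `check_aligned`, and `valid_elems_in_block` are hypothetical and not part of this PR. The sketch assumes the wrapper either rejects unaligned lengths, or allocates the output with a padded length and lets the kernel write only the in-bounds elements of the last block.

```python
def round_up(n: int, multiple: int) -> int:
    """Return the smallest multiple of `multiple` that is >= n."""
    if multiple <= 0:
        raise ValueError("multiple must be positive")
    return ((n + multiple - 1) // multiple) * multiple


def check_aligned(max_seq_len: int, block_size: int) -> None:
    """Option 1 (hypothetical): reject unaligned lengths before launching the kernel."""
    assert max_seq_len % block_size == 0, (
        f"max_seq_len={max_seq_len} must be a multiple of block_size={block_size}"
    )


def valid_elems_in_block(block_idx: int, block_size: int, max_seq_len: int) -> int:
    """Option 2 (hypothetical): count of in-bounds elements a masked copy may write."""
    start = block_idx * block_size
    return max(0, min(block_size, max_seq_len - start))
```

For example, with `max_seq_len=100` and `block_size=64`, the second block has only 36 valid elements, so an unmasked copy of 64 elements would write past the end of a row unless the buffer is padded to `round_up(100, 64)`, which is 128.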
page_table: torch.Tensor,
deep_gemm_metadata: Any,
max_seq_len: int,
clean_logits: bool = True,
logits = page_table.new_empty((batch_size, max_seq_len), dtype=torch.float32)

NUM_CU = 256
split_kv = split_kv = max(1, min(max_seq_len // block_size, NUM_CU // batch_size))
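As a rough illustration of the two points raised about this snippet (the doubled assignment and the hardcoded compute unit count), the heuristic could be factored into a small function that takes the compute unit count as a parameter. The function name `choose_split_kv` and the `num_cu` parameter are hypothetical, not code from this PR. One possible source for `num_cu` is `torch.cuda.get_device_properties(device).multi_processor_count`; whether that matches the intended compute unit count on every target GPU is an assumption to verify.

```python
def choose_split_kv(max_seq_len: int, block_size: int, batch_size: int, num_cu: int) -> int:
    """Hypothetical helper: pick the KV split factor from the snippet's heuristic.

    The split count is capped by the number of whole KV blocks and by the
    compute units available per request, and is never less than 1.
    """
    if block_size <= 0 or batch_size <= 0:
        raise ValueError("block_size and batch_size must be positive")
    num_blocks = max_seq_len // block_size   # whole KV blocks per sequence
    cus_per_request = num_cu // batch_size   # compute units available per request
    return max(1, min(num_blocks, cus_per_request))
```

Passing `num_cu=256` reproduces the hardcoded behavior in the snippet, while a queried value lets the same formula adapt to devices with a different compute unit count.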
Motivation
Update the amd/deepseek_v4 integration branch.
The following PRs have a large set of conflicts, so we use this PR and the upstream amd/deepseek_v4 branch to integrate them in parallel:
#23600
#23608
The original fp8_paged_mqa_logits_torch launches 12 kernels per call; switching to this tilelang kernel reduces this to 1.
Controlled by export SGLANG_OPT_USE_TILELANG_INDEXER=true.
Modifications
Accuracy Tests
Speed Tests and Profiling
Checklist
Review and Merge Process