[Perf & Feat] Add deepseek32 topk opt : Introduction to the ultra low latency attention#23761
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a distributed TopK Radix Indexer implemented in CUDA to accelerate decoding for DeepSeek V3.2, including specialized kernels for ragged and paged attention layouts and a Python JIT wrapper. The review identified several critical issues in the CUDA implementation: a potential shared memory buffer overflow due to missing bounds checks when elements exceed the bin capacity, a logic error in the transform kernels where the unrolled copy loop only handles half of the required TopK indices, and hardcoded sequence length constants that limit the kernel's flexibility. Additionally, the radix conversion logic currently assumes pre-scaled input, which may lead to incorrect sorting for standard attention logit ranges.
| const unsigned int pos = ::atomicAdd(&s_num_input[0], 1); | ||
| // if (pos < SMEM_INPUT_SIZE) { | ||
| s_input[0][pos] = val_scale; |
| const auto pos_0 = s_indices[idx_0]; | ||
| dst_page_entry[idx_0] = src_page_entry[pos_0]; | ||
| const auto idx_1 = tid + kThreadsPerBlock; | ||
| const auto pos_1 = s_indices[idx_1]; | ||
| dst_page_entry[idx_1] = src_page_entry[pos_1]; | ||
| } |
There was a problem hiding this comment.
The copy loop only handles idx_0 and idx_1, which corresponds to 2 * kThreadsPerBlock = 1024 elements. Since TopK is hardcoded to 2048, this loop will only copy half of the indices to the destination page table. You should use a loop or additional unrolled steps to cover the full TopK range.
for (int i = tid; i < TopK; i += kThreadsPerBlock) {
dst_page_entry[i] = src_page_entry[s_indices[i]];
}
|
|
||
| // TODO (yiakwy) : test | ||
| __device__ __forceinline__ auto convert_to_monotonic_8bit(float x) -> uint8_t { | ||
| int bin = __float2int_rd(x); |
There was a problem hiding this comment.
The current implementation of convert_to_monotonic_8bit using __float2int_rd(x) assumes that the input scores are already scaled to the range [0, 255]. In a production environment, attention logits or probabilities are typically in a different range (e.g., (-inf, inf) or [0, 1]). If the input is not pre-scaled, this will result in all values falling into the same bin, breaking the radix sort logic. Consider using a proper bit-manipulation approach for floating-point radix sort to ensure monotonicity across the full range of float values.
There was a problem hiding this comment.
Yep topk data indexer should be normalized in 0.1, hence we mapping it to 0, 256 for . This dramatically reduced radix rounds from 4 to 1 or 2.
| const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional<at::Tensor> row_starts_opt) { | ||
| CHECK_CUDA(score); | ||
| const auto B = score.size(0); | ||
| const auto L = 65536; // score.size(1); |
There was a problem hiding this comment.
|
Hi. We also implement a low-latency topk for DSv4 where k <= 1024. Could you please refer to that and maybe also port that? I believe that could be useful. |
|
Have you compared the kernel implementations in SGLang? I see your comparison only includes Torch's native implementation. |
Sure, definitely, TopK 1024 could be supported. This week we are busy in supporting DeepSeek V4 RL, @DarkSharpness Thanks for the reminding ! Ref : https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro/blob/main/config.json#L18 |
Yes, the column with name "radix" represents "fast_topk_v2" in sglang . |
|
Hi, we can later integrate this as a backend for #22851 |
|
Features with TVM-FFI intreface will be upstreamed with support of Top2048/1024/512 Wish your attention @hnyls2002, cc @merrymercy |
Motivation
DeeSeek V4's 1 M-context makes decoding stage topk introduced in DS32 a prohibitive bottlenect . We reduce the latency by :
(1) First we compute historgram in parallel to reduce collision rates per block and then accumulate via NoC network before N-ways prefix sum and prove this is effectively method to reduce latency in a throughput oriented hardware.
(2) Second, we enhance the linear mapping properties for radix sort in NSA problem for reduction iteration. Instead of traditional top 8/11/13 bits of IEEE FP32, FP16 format, we redesign a linear mapping such that$b(y) > b(x)$ $x > y$ . With this linear mapping design, we greatly reduced per block elements dropped in the bin to determine the residule numbers. This further facilitate cache friendly visiting over 1-M context length : we hence enable less SMEM revisiting more elements.
, naturally deducing
(3) When remainder elements reduced to 8/16, we can simply use CAS operations to performa a parallel sorting in few cycles. This further reduce the latency overhead in the last round.
See details from https://github.com/yiakwy-xpu-ml-framework-team/flash-float-jit-kernels
The work is adpation from flash-float-jit-kernel distributed radix topk indexer.
Modifications
Add Topk JIT implementation (currently using torch jit, later we will adapt the code to TVM FFI interface).
Accuracy Tests
Passed. (< 5)
Speed Tests and Profiling
50% latency reduced !
Checklist
Review and Merge Process
/tag-and-rerun-ci,/tag-run-ci-label,/rerun-failed-ci