[Perf & Feat] Add deepseek32 topk opt : Introduction to the ultra low latency attention #23761

Open
yiakwy-xpu-ml-framework-team wants to merge 3 commits into sgl-project:main from yiakwy-xpu-ml-framework-team:add_deepseek32_topk_opt
Conversation

@yiakwy-xpu-ml-framework-team
Contributor

@yiakwy-xpu-ml-framework-team yiakwy-xpu-ml-framework-team commented Apr 26, 2026

Motivation

DeepSeek V4's 1M-token context makes the decoding-stage top-k introduced in DS32 a prohibitive bottleneck. We reduce the latency as follows:

(1) First, we compute histograms in parallel to reduce the per-block collision rate, then accumulate them over the NoC network before an N-way prefix sum. We show that this is an effective way to reduce latency on throughput-oriented hardware.

(2) Second, we strengthen the linear-mapping property of radix sort in the NSA problem to reduce the number of iterations. Instead of the traditional top 8/11/13 bits of the IEEE FP32 or FP16 format, we design a linear mapping $b$ such that $b(y) > b(x)$ implies $x > y$. With this mapping, we greatly reduce the number of elements per block that fall into the bin determining the remainder. This also makes access over the 1M-token context more cache friendly, so fewer SMEM revisits are needed.

(3) When the remainder is reduced to 8/16 elements, we can simply use CAS operations to perform a parallel sort in a few cycles. This further reduces the latency overhead of the last round.
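
The histogram step in (1) and the remainder handled in (2) and (3) can be illustrated with a sequential, host-side sketch of a single radix round. This is not the PR's kernel: `RadixRound` and `radix_select_round` are hypothetical names, the keys are assumed to be already mapped to 8 bits, and the per-block histograms with NoC accumulation are collapsed into one loop.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical result of one radix round (illustrative, not from the PR).
struct RadixRound {
  int threshold_bin;       // bin that contains the k-th largest key
  std::size_t num_above;   // keys in bins strictly above it (all selected)
  std::size_t num_in_bin;  // keys in the threshold bin (left for the next round)
};

// One round of histogram-based top-k selection over 8-bit keys.
inline RadixRound radix_select_round(const std::vector<std::uint8_t>& keys,
                                     std::size_t k) {
  assert(k > 0 && k <= keys.size());

  // Histogram of all keys. On a GPU this would be one histogram per block,
  // merged before the scan; here it is a single sequential pass.
  std::array<std::size_t, 256> hist{};
  for (std::uint8_t key : keys) ++hist[key];

  // Scan from the largest bin downwards until k keys are covered.
  std::size_t above = 0;
  for (int bin = 255; bin > 0; --bin) {
    if (above + hist[bin] >= k) return {bin, above, hist[bin]};
    above += hist[bin];
  }
  return {0, above, hist[0]};
}
```

After the round, `k - num_above` elements still have to be picked from the `num_in_bin` candidates in the threshold bin. The smaller that bin, the cheaper the final step, which is the motivation given for the mapping in (2) and the small parallel sort in (3).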

See details at https://github.com/yiakwy-xpu-ml-framework-team/flash-float-jit-kernels

This work is an adaptation of the flash-float-jit-kernel distributed radix top-k indexer.

Modifications

Add a TopK JIT implementation (currently using torch JIT; we will later adapt the code to the TVM FFI interface).

Accuracy Tests

Passed. (< 5)

Screenshot 2026-04-26 16 42 38

Speed Tests and Profiling

Latency reduced by 50%!

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a distributed TopK Radix Indexer implemented in CUDA to accelerate decoding for DeepSeek V3.2, including specialized kernels for ragged and paged attention layouts and a Python JIT wrapper. The review identified several critical issues in the CUDA implementation: a potential shared memory buffer overflow due to missing bounds checks when elements exceed the bin capacity, a logic error in the transform kernels where the unrolled copy loop only handles half of the required TopK indices, and hardcoded sequence length constants that limit the kernel's flexibility. Additionally, the radix conversion logic currently assumes pre-scaled input, which may lead to incorrect sorting for standard attention logit ranges.

Comment on lines +690 to +692
    const unsigned int pos = ::atomicAdd(&s_num_input[0], 1);
    // if (pos < SMEM_INPUT_SIZE) {
    s_input[0][pos] = val_scale;
Contributor

critical

There is no bounds check on pos before writing to s_input[0][pos]. If the number of elements in a single radix bin exceeds SMEM_INPUT_SIZE (2048), this will cause a shared memory buffer overflow and potentially crash the kernel or produce incorrect results.
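
A minimal host-side sketch of the guarded write this comment asks for is shown below. It is an illustration only: `CandidateBuffer`, `try_push`, and `valid_size` are hypothetical names, `std::atomic::fetch_add` stands in for `atomicAdd`, and counting rejected elements is an assumption, since the PR does not say how overflow should be handled.

```cpp
#include <algorithm>
#include <atomic>
#include <cassert>

// Mirrors SMEM_INPUT_SIZE (2048) as quoted in the review comment.
constexpr unsigned int kCapacity = 2048;

// Hypothetical stand-in for the shared-memory candidate buffer.
struct CandidateBuffer {
  float data[kCapacity] = {};
  std::atomic<unsigned int> count{0};    // slots reserved, may run past kCapacity
  std::atomic<unsigned int> dropped{0};  // writes rejected by the bounds check
};

// Reserve a slot atomically and write only if the slot is inside the buffer.
inline bool try_push(CandidateBuffer& buf, float value) {
  const unsigned int pos = buf.count.fetch_add(1);  // plays the role of atomicAdd
  if (pos >= kCapacity) {
    buf.dropped.fetch_add(1);  // record the overflow instead of writing out of bounds
    return false;
  }
  assert(pos < kCapacity);  // invariant established by the check above
  buf.data[pos] = value;
  return true;
}

// Readers must clamp the counter, because it keeps growing after overflow.
inline unsigned int valid_size(const CandidateBuffer& buf) {
  return std::min(buf.count.load(), kCapacity);
}
```

Because the counter is incremented before the check, any later loop over the buffer should use the clamped size rather than the raw counter. A nonzero `dropped` value signals that the bin was too large for one pass and that the result may be incomplete.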

Comment on lines +967 to +972
    const auto pos_0 = s_indices[idx_0];
    dst_page_entry[idx_0] = src_page_entry[pos_0];
    const auto idx_1 = tid + kThreadsPerBlock;
    const auto pos_1 = s_indices[idx_1];
    dst_page_entry[idx_1] = src_page_entry[pos_1];
    }
Contributor

critical

The copy loop only handles idx_0 and idx_1, which corresponds to 2 * kThreadsPerBlock = 1024 elements. Since TopK is hardcoded to 2048, this loop will only copy half of the indices to the destination page table. You should use a loop or additional unrolled steps to cover the full TopK range.

    for (int i = tid; i < TopK; i += kThreadsPerBlock) {
      dst_page_entry[i] = src_page_entry[s_indices[i]];
    }


    // TODO (yiakwy) : test
    __device__ __forceinline__ auto convert_to_monotonic_8bit(float x) -> uint8_t {
    int bin = __float2int_rd(x);
Contributor

high

The current implementation of convert_to_monotonic_8bit using __float2int_rd(x) assumes that the input scores are already scaled to the range [0, 255]. In a production environment, attention logits or probabilities are typically in a different range (e.g., (-inf, inf) or [0, 1]). If the input is not pre-scaled, this will result in all values falling into the same bin, breaking the radix sort logic. Consider using a proper bit-manipulation approach for floating-point radix sort to ensure monotonicity across the full range of float values.
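
For reference, the bit-manipulation approach mentioned here is commonly written as follows. This is a generic host-side sketch of the standard order-preserving float-to-integer key, not code from the PR; `float_to_ordered_key` and `top_radix_digit` are hypothetical names, and NaN inputs are assumed to be filtered out beforehand.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

static_assert(sizeof(float) == sizeof(std::uint32_t), "expects 32-bit IEEE floats");

// Map a float to a uint32 key whose unsigned order matches the float order:
// negative values get all bits flipped, non-negative values get the sign bit set.
inline std::uint32_t float_to_ordered_key(float x) {
  assert(!std::isnan(x));  // NaN ordering is outside the scope of this sketch
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));  // well-defined type pun
  const std::uint32_t mask = (bits & 0x80000000u) ? 0xFFFFFFFFu : 0x80000000u;
  return bits ^ mask;
}

// Most significant 8 bits of the key, usable as the first radix digit.
inline std::uint8_t top_radix_digit(float x) {
  return static_cast<std::uint8_t>(float_to_ordered_key(x) >> 24);
}
```

With this key the top digit mostly encodes the sign and exponent, so scores confined to a narrow range tend to share one digit and need further rounds.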

Contributor Author

@yiakwy-xpu-ml-framework-team yiakwy-xpu-ml-framework-team Apr 26, 2026

Yes, the top-k indexer scores should be normalized to [0, 1], so we map them to [0, 256]. This dramatically reduces the number of radix rounds from 4 to 1 or 2.
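
A linear score-to-bin mapping of this kind could look like the host-side sketch below. It is an assumption about the intent rather than the PR's code: `score_to_bin` is a hypothetical name, 256 bins indexed 0 to 255 are assumed, and out-of-range or NaN inputs are clamped so that the bin index is always valid.

```cpp
#include <cassert>
#include <cstdint>

// Map a score expected in [0, 1] to one of 256 bins, preserving order.
inline std::uint8_t score_to_bin(float score) {
  if (!(score > 0.0f)) return 0;   // zero, negative, and NaN scores go to bin 0
  if (score >= 1.0f) return 255;   // 1.0 and anything above saturate at the top bin
  const int bin = static_cast<int>(score * 256.0f);  // truncation toward zero
  assert(bin >= 0 && bin <= 255);
  return static_cast<std::uint8_t>(bin);
}
```

The mapping is non-decreasing, so a larger score never lands in a lower bin. Scores that differ by less than 1/256 can share a bin and still need a refinement round, and inputs that are not normalized collapse into the first or last bin, which is the concern raised in the review comment.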

    const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional<at::Tensor> row_starts_opt) {
    CHECK_CUDA(score);
    const auto B = score.size(0);
    const auto L = 65536; // score.size(1);
Contributor

high

The sequence length L is hardcoded to 65536. This will lead to incorrect split_kv calculations and potential out-of-bounds access or performance degradation if the actual sequence length differs from this constant. It should be derived from score.size(1).

  const auto L = score.size(1);

@yiakwy-xpu-ml-framework-team yiakwy-xpu-ml-framework-team changed the title Add deepseek32 topk opt : Introduction to the ultra low latency attention [Perf & Feat] Add deepseek32 topk opt : Introduction to the ultra low latency attention Apr 27, 2026
@DarkSharpness
Collaborator

Hi. We also implement a low-latency topk for DSv4 where k <= 1024. Could you please refer to that and maybe also port that? I believe that could be useful.

@MichoChan

Have you compared the kernel implementations in SGLang? I see your comparison only includes Torch's native implementation.

@yiakwy-xpu-ml-framework-team
Contributor Author

yiakwy-xpu-ml-framework-team commented May 1, 2026

> Hi. We also implement a low-latency topk for DSv4 where k <= 1024. Could you please refer to that and maybe also port that? I believe that could be useful.

Sure, definitely, TopK 1024 could be supported.

This week we are busy supporting DeepSeek V4 RL. @DarkSharpness, thanks for the reminder!

Ref : https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro/blob/main/config.json#L18

@yiakwy-xpu-ml-framework-team
Contributor Author

> Have you compared the kernel implementations in SGLang? I see your comparison only includes Torch's native implementation.

Yes, the column named "radix" represents "fast_topk_v2" in SGLang.

@zianglih
Contributor

zianglih commented May 5, 2026

Hi, we can later integrate this as a backend for #22851

@yiakwy-xpu-ml-framework-team
Contributor Author

Features with the TVM-FFI interface will be upstreamed with support for Top2048/1024/512.

Requesting your attention, @hnyls2002; cc @merrymercy
