[Perf][Kernel] BF16 input support for persistent topK - DeepSeekV4 #40811

Draft
LopezCastroRoberto wants to merge 4 commits into vllm-project:main from LopezCastroRoberto:perf/bf16_topK

Conversation

@LopezCastroRoberto
Contributor

@LopezCastroRoberto LopezCastroRoberto commented Apr 24, 2026

Motivation

We further quantize the index scores I_{:,:} from FP32 to BF16 during this QAT process. This optimization achieves a 2× speedup for the top-k selector, while preserving a 99.7% recall rate of KV entries.

source: https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro/blob/main/DeepSeek_V4.pdf

depends on: #40760

Microbenchmarks

Speedups vs FP32 topK

k=512 (V4-Flash)

  ┌───────┬──────┬──────┬──────┬──────┬──────┬──────┐
  │       │  8K  │ 16K  │ 32K  │ 64K  │ 128K │ 250K │
  ├───────┼──────┼──────┼──────┼──────┼──────┼──────┤
  │ BS=1  │ 1.23 │ 1.14 │ 1.08 │ 1.41 │ 1.43 │ 1.44 │
  ├───────┼──────┼──────┼──────┼──────┼──────┼──────┤
  │ BS=8  │ 1.18 │ 1.09 │ 1.05 │ 1.40 │ 1.40 │ 1.37 │
  ├───────┼──────┼──────┼──────┼──────┼──────┼──────┤
  │ BS=32 │ 1.25 │ 1.08 │ 1.19 │ 1.37 │ 1.36 │ 1.92 │
  └───────┴──────┴──────┴──────┴──────┴──────┴──────┘

k=1024 (V4-Pro)

  ┌───────┬──────┬──────┬──────┬──────┬──────┬──────┐
  │       │  8K  │ 16K  │ 32K  │ 64K  │ 128K │ 250K │
  ├───────┼──────┼──────┼──────┼──────┼──────┼──────┤
  │ BS=1  │ 1.15 │ 1.13 │ 1.07 │ 1.39 │ 1.42 │ 1.40 │
  ├───────┼──────┼──────┼──────┼──────┼──────┼──────┤
  │ BS=8  │ 1.25 │ 1.16 │ 1.06 │ 1.37 │ 1.39 │ 1.38 │
  ├───────┼──────┼──────┼──────┼──────┼──────┼──────┤
  │ BS=32 │ 1.23 │ 1.10 │ 1.19 │ 1.33 │ 1.35 │ 1.88 │
  └───────┴──────┴──────┴──────┴──────┴──────┴──────┘

k=2048 (V3.2)

  ┌───────┬──────┬──────┬──────┬──────┬──────┬──────┐
  │       │  8K  │ 16K  │ 32K  │ 64K  │ 128K │ 250K │
  ├───────┼──────┼──────┼──────┼──────┼──────┼──────┤
  │ BS=1  │ 1.16 │ 1.18 │ 1.03 │ 1.35 │ 1.41 │ 1.44 │
  ├───────┼──────┼──────┼──────┼──────┼──────┼──────┤
  │ BS=8  │ 1.15 │ 1.07 │ 1.10 │ 1.38 │ 1.36 │ 1.39 │
  ├───────┼──────┼──────┼──────┼──────┼──────┼──────┤
  │ BS=32 │ 1.25 │ 1.15 │ 1.17 │ 1.32 │ 1.31 │ 1.84 │
  └───────┴──────┴──────┴──────┴──────┴──────┴──────┘

E2E

VLLM_ENGINE_READY_TIMEOUT_S=3600 vllm serve deepseek-ai/DeepSeek-V4-Flash -tp 4 --port 8000 --kv-cache-dtype fp8

  vllm bench serve \
    --model deepseek-ai/DeepSeek-V4-Flash \
    --host localhost --port 8000 \
    --dataset-name random \
    --random-input-len 1024000 \
    --random-output-len 4096 \
    --num-prompts 8 \
    --max-concurrency 1
  ┌──────────┬───────────┐
  │  Config  │ TPOT (ms) │
  ├──────────┼───────────┤
  │ Upstream │   9.17    │
  ├──────────┼───────────┤
  │ PR BF16  │   8.95    │
  ├──────────┼───────────┤
  │ Speedup  │   1.02x   │
  └──────────┴───────────┘

2% TPOT improvement


@claude claude Bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@LopezCastroRoberto LopezCastroRoberto marked this pull request as draft April 24, 2026 14:08
@mergify mergify Bot added the deepseek Related to DeepSeek models label Apr 24, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request extends the persistent top-k kernel to support bfloat16 inputs, enabling its use in DeepSeek V3/V4 sparse attention indexers. The implementation introduces data type traits to handle ordered key conversions generically and adds specialized vectorized load helpers for BF16. A key addition is an overflow fallback mechanism for the medium-sized histogram path to handle cases where coarse-grained bins exceed shared memory capacity. Review feedback highlights opportunities to optimize the BF16 data path by utilizing 128-bit loads (VEC_SIZE=8) to match the bandwidth of the float32 implementation and suggests enabling the FilteredTopK path for BF16 now that the required traits are available.

Comment thread csrc/persistent_topk.cuh Outdated
Comment on lines +946 to +954
if constexpr (VEC_SIZE >= 4) {
  // 128-bit load: 8 bf16 values
  // But VEC_SIZE is in elements; load VEC_SIZE bf16s
  __nv_bfloat16 v0, v1, v2, v3;
  load_bf16x4(src, v0, v1, v2, v3);
  shared_u16[i] = Traits::to_ordered(v0);
  shared_u16[i + 1] = Traits::to_ordered(v1);
  shared_u16[i + 2] = Traits::to_ordered(v2);
  shared_u16[i + 3] = Traits::to_ordered(v3);
Contributor


high

The BF16 path in radix_topk is suboptimal and potentially incorrect for VEC_SIZE > 4. It uses load_bf16x4 (64-bit) even though the comment suggests 128-bit loads are intended. To match the performance of the float path (which uses 128-bit float4 loads), BF16 should use VEC_SIZE=8 and load_bf16x8. Additionally, the current implementation loads only 4 elements even when VEC_SIZE is 8, which would produce incorrect results.

        if constexpr (VEC_SIZE == 8) {
          // 128-bit load: 8 bf16 values
          __nv_bfloat16 v0, v1, v2, v3, v4, v5, v6, v7;
          load_bf16x8(src, v0, v1, v2, v3, v4, v5, v6, v7);
          shared_u16[i] = Traits::to_ordered(v0);
          shared_u16[i + 1] = Traits::to_ordered(v1);
          shared_u16[i + 2] = Traits::to_ordered(v2);
          shared_u16[i + 3] = Traits::to_ordered(v3);
          shared_u16[i + 4] = Traits::to_ordered(v4);
          shared_u16[i + 5] = Traits::to_ordered(v5);
          shared_u16[i + 6] = Traits::to_ordered(v6);
          shared_u16[i + 7] = Traits::to_ordered(v7);
        } else if constexpr (VEC_SIZE == 4) {
          // 64-bit load: 4 bf16 values
          __nv_bfloat16 v0, v1, v2, v3;
          load_bf16x4(src, v0, v1, v2, v3);
          shared_u16[i] = Traits::to_ordered(v0);
          shared_u16[i + 1] = Traits::to_ordered(v1);
          shared_u16[i + 2] = Traits::to_ordered(v2);
          shared_u16[i + 3] = Traits::to_ordered(v3);
        }

Comment thread csrc/topk.cu Outdated
ensure_device_props();

constexpr bool is_bf16 = std::is_same_v<scalar_t, __nv_bfloat16>;
if (!is_bf16 && num_rows > 32 && g_max_smem_per_block >= 128 * 1024) {
Contributor


high

FilteredTopKRaggedTransform is explicitly disabled for BF16 (!is_bf16), yet this PR adds BF16 support to it in persistent_topk.cuh (via FilteredTopKTraits<__nv_bfloat16> and overflow fallback logic). This kernel is often more efficient than the persistent one for large batches of rows. If it's implemented and tested, it should be enabled for BF16 as well.

  if (num_rows > 32 && g_max_smem_per_block >= 128 * 1024) {

Comment thread csrc/topk.cu
Comment on lines +70 to +74
uint32_t vec_size = 1;
if (stride % 4 == 0)
  vec_size = 4;
else if (stride % 2 == 0)
  vec_size = 2;
Contributor


high

For BF16, we should attempt to use vec_size = 8 to achieve 128-bit load bandwidth, matching the float4 path. The current logic caps vec_size at 4 regardless of dtype, which limits BF16 to 64-bit loads in the large path.

  uint32_t vec_size = 1;
  uint32_t max_vec = is_bf16 ? 8 : 4;
  if (stride % max_vec == 0)
    vec_size = max_vec;
  else if (stride % 2 == 0)
    vec_size = 2;

Comment thread csrc/topk.cu
Comment on lines +129 to +135
if (vec_size == 4) {
  LAUNCH_PERSISTENT(scalar_t, 4);
} else if (vec_size == 2) {
  LAUNCH_PERSISTENT(scalar_t, 2);
} else {
  LAUNCH_PERSISTENT(scalar_t, 1);
}
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The dispatch logic is missing the vec_size == 8 case, which is necessary if the vec_size calculation is updated to support 128-bit loads for BF16.

  if (vec_size == 8) {
    LAUNCH_PERSISTENT(scalar_t, 8);
  } else if (vec_size == 4) {
    LAUNCH_PERSISTENT(scalar_t, 4);
  } else if (vec_size == 2) {
    LAUNCH_PERSISTENT(scalar_t, 2);
  } else {
    LAUNCH_PERSISTENT(scalar_t, 1);
  }

@zyongye
Member

zyongye commented Apr 24, 2026

Is the output from *_mqa_logits already bf16?

@mergify
Contributor

mergify Bot commented Apr 27, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LopezCastroRoberto.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 27, 2026
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
@LopezCastroRoberto
Contributor Author

Is the output from *_mqa_logits already bf16?

@zyongye it is FP32 by default, but the method accepts a logits_dtype argument to control this.

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
@mergify
Contributor

mergify Bot commented Apr 30, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LopezCastroRoberto.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 30, 2026

Labels

deepseek Related to DeepSeek models needs-rebase


2 participants