[Perf][Kernel] BF16 input support for persistent topK - DeepSeekV4 #40811
LopezCastroRoberto wants to merge 4 commits into vllm-project:main
Conversation
Code Review
This pull request extends the persistent top-k kernel to support bfloat16 inputs, enabling its use in DeepSeek V3/V4 sparse attention indexers. The implementation introduces data type traits to handle ordered key conversions generically and adds specialized vectorized load helpers for BF16. A key addition is an overflow fallback mechanism for the medium-sized histogram path to handle cases where coarse-grained bins exceed shared memory capacity. Review feedback highlights opportunities to optimize the BF16 data path by utilizing 128-bit loads (VEC_SIZE=8) to match the bandwidth of the float32 implementation and suggests enabling the FilteredTopK path for BF16 now that the required traits are available.
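For context on the "ordered key" conversion the new traits provide, here is a minimal sketch of the standard monotonic bf16-to-uint16 mapping (the struct and method names below are illustrative, not necessarily the PR's FilteredTopKTraits API):

#include <cuda_bf16.h>
#include <cstdint>

// Sketch only: maps bf16 bit patterns to uint16 keys whose unsigned integer
// order matches the floating-point order, so radix/histogram passes can
// compare keys as plain integers. Negative values flip all bits;
// non-negative values flip just the sign bit.
struct OrderedTraitsBF16 {
  __device__ static uint16_t to_ordered(__nv_bfloat16 v) {
    uint16_t bits = __bfloat16_as_ushort(v);
    return (bits & 0x8000) ? static_cast<uint16_t>(~bits)
                           : static_cast<uint16_t>(bits | 0x8000);
  }
  __device__ static __nv_bfloat16 from_ordered(uint16_t k) {
    uint16_t bits = (k & 0x8000) ? static_cast<uint16_t>(k & 0x7FFF)
                                 : static_cast<uint16_t>(~k);
    return __ushort_as_bfloat16(bits);
  }
};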
if constexpr (VEC_SIZE >= 4) {
  // 128-bit load: 8 bf16 values
  // But VEC_SIZE is in elements; load VEC_SIZE bf16s
  __nv_bfloat16 v0, v1, v2, v3;
  load_bf16x4(src, v0, v1, v2, v3);
  shared_u16[i] = Traits::to_ordered(v0);
  shared_u16[i + 1] = Traits::to_ordered(v1);
  shared_u16[i + 2] = Traits::to_ordered(v2);
  shared_u16[i + 3] = Traits::to_ordered(v3);
The BF16 path in radix_topk is suboptimal and potentially incorrect for VEC_SIZE > 4. It uses load_bf16x4 (64-bit) even when the comment suggests 128-bit loads are intended. To match the performance of the float path (which uses 128-bit float4 loads), BF16 should use VEC_SIZE=8 and load_bf16x8. Additionally, the current implementation only loads 4 elements even if VEC_SIZE is 8, which would lead to incorrect results.
if constexpr (VEC_SIZE == 8) {
// 128-bit load: 8 bf16 values
__nv_bfloat16 v0, v1, v2, v3, v4, v5, v6, v7;
load_bf16x8(src, v0, v1, v2, v3, v4, v5, v6, v7);
shared_u16[i] = Traits::to_ordered(v0);
shared_u16[i + 1] = Traits::to_ordered(v1);
shared_u16[i + 2] = Traits::to_ordered(v2);
shared_u16[i + 3] = Traits::to_ordered(v3);
shared_u16[i + 4] = Traits::to_ordered(v4);
shared_u16[i + 5] = Traits::to_ordered(v5);
shared_u16[i + 6] = Traits::to_ordered(v6);
shared_u16[i + 7] = Traits::to_ordered(v7);
} else if constexpr (VEC_SIZE == 4) {
// 64-bit load: 4 bf16 values
__nv_bfloat16 v0, v1, v2, v3;
load_bf16x4(src, v0, v1, v2, v3);
shared_u16[i] = Traits::to_ordered(v0);
shared_u16[i + 1] = Traits::to_ordered(v1);
shared_u16[i + 2] = Traits::to_ordered(v2);
shared_u16[i + 3] = Traits::to_ordered(v3);
}
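For reference, a load_bf16x8 helper along the lines this suggestion assumes could look like the sketch below (the actual helper in the PR may differ; this assumes src is 16-byte aligned):

#include <cuda_bf16.h>

// Sketch: one 128-bit load, then unpack the eight bf16 lanes. This matches
// the bandwidth of the float path's float4 loads.
__device__ __forceinline__ void load_bf16x8(
    const __nv_bfloat16* src, __nv_bfloat16& v0, __nv_bfloat16& v1,
    __nv_bfloat16& v2, __nv_bfloat16& v3, __nv_bfloat16& v4,
    __nv_bfloat16& v5, __nv_bfloat16& v6, __nv_bfloat16& v7) {
  uint4 raw = *reinterpret_cast<const uint4*>(src);
  const __nv_bfloat162* p = reinterpret_cast<const __nv_bfloat162*>(&raw);
  v0 = __low2bfloat16(p[0]); v1 = __high2bfloat16(p[0]);
  v2 = __low2bfloat16(p[1]); v3 = __high2bfloat16(p[1]);
  v4 = __low2bfloat16(p[2]); v5 = __high2bfloat16(p[2]);
  v6 = __low2bfloat16(p[3]); v7 = __high2bfloat16(p[3]);
}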
ensure_device_props();

constexpr bool is_bf16 = std::is_same_v<scalar_t, __nv_bfloat16>;
if (!is_bf16 && num_rows > 32 && g_max_smem_per_block >= 128 * 1024) {
FilteredTopKRaggedTransform is explicitly disabled for BF16 (!is_bf16), yet this PR adds BF16 support to it in persistent_topk.cuh (via FilteredTopKTraits<__nv_bfloat16> and overflow fallback logic). This kernel is often more efficient than the persistent one for large batches of rows. If it's implemented and tested, it should be enabled for BF16 as well.
if (num_rows > 32 && g_max_smem_per_block >= 128 * 1024) {
uint32_t vec_size = 1;
if (stride % 4 == 0)
  vec_size = 4;
else if (stride % 2 == 0)
  vec_size = 2;
For BF16, we should attempt to use vec_size = 8 to achieve 128-bit load bandwidth, matching the float4 path. The current logic caps vec_size at 4 regardless of dtype, which limits BF16 to 64-bit loads in the large path.
uint32_t vec_size = 1;
uint32_t max_vec = is_bf16 ? 8 : 4;
if (stride % max_vec == 0)
vec_size = max_vec;
else if (stride % 2 == 0)
vec_size = 2;
if (vec_size == 4) {
  LAUNCH_PERSISTENT(scalar_t, 4);
} else if (vec_size == 2) {
  LAUNCH_PERSISTENT(scalar_t, 2);
} else {
  LAUNCH_PERSISTENT(scalar_t, 1);
}
The dispatch logic is missing the vec_size == 8 case, which is necessary if the vec_size calculation is updated to support 128-bit loads for BF16.
if (vec_size == 8) {
LAUNCH_PERSISTENT(scalar_t, 8);
} else if (vec_size == 4) {
LAUNCH_PERSISTENT(scalar_t, 4);
} else if (vec_size == 2) {
LAUNCH_PERSISTENT(scalar_t, 2);
} else {
LAUNCH_PERSISTENT(scalar_t, 1);
}
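One related caveat (an assumption on my side, not something shown in the diff): 8-wide bf16 loads read 16 bytes per access, so a launcher adopting vec_size = 8 should also confirm 16-byte alignment of each row's base address, not just the element stride. A hypothetical guard (input_ptr is an assumed name for the tensor's data pointer):

// Hypothetical: require both the base pointer and the row stride (in
// elements) to support 16-byte accesses before selecting 8-wide loads.
bool aligned16 =
    reinterpret_cast<uintptr_t>(input_ptr) % 16 == 0 && stride % 8 == 0;
uint32_t max_vec = (is_bf16 && aligned16) ? 8 : 4;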
Is the output from
This pull request has merge conflicts that must be resolved before it can be merged.
@zyongye it is FP32 by default, but the method accepts an
This pull request has merge conflicts that must be resolved before it can be merged.
Motivation
source: https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro/blob/main/DeepSeek_V4.pdf
depends on: #40760
Microbenchmarks
Speedups vs FP32 topK
E2E
VLLM_ENGINE_READY_TIMEOUT_S=3600 vllm serve deepseek-ai/DeepSeek-V4-Flash -tp 4 --port 8000 --kv-cache-dtype fp8

2% TPOT improvement