[Perf][Kernel] Persistent TopK scheduler: unified CUDAGraph-safe kernel with dynamic per-row dispatch - DeepSeek-V3.2 DSA decode#37421
Conversation
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Code Review
This pull request introduces a unified persistent TopK scheduler for DSA, a significant improvement for CUDAGraph safety and host-side simplicity. The new approach dispatches each row to the appropriate TopK kernel variant at runtime based on its sequence length, replacing the previous host-side kernel selection. The changes span the CUDA kernel implementations, the Python bindings, and their integration into the model executor and attention backend. Several hardcoded constants are introduced in the new CUDA files; while constexpr is beneficial for performance, these values should be justified and documented. Naming RADIX_TOPK_WORKSPACE_SIZE as a Python constant is a good step toward readability and maintainability.
csrc/persistent_topk_medium.cuh (outdated)

```cpp
    int* __restrict__ output_indices,
    int logits_offset,
    int seq_len) {
  alignas(128) __shared__ int shared_histogram[2][RADIX + 128];
```
csrc/persistent_topk_medium.cuh (outdated)

```cpp
    int seq_len) {
  alignas(128) __shared__ int shared_histogram[2][RADIX + 128];
  alignas(128) __shared__ int shared_output_count;
  alignas(128) __shared__ int shared_threshold_bin;
```
csrc/topk.cuh (outdated)

```cpp
// Returns 1, 2, 4, or 8
template <typename DType>
constexpr int ComputeFilteredTopKVecSize(uint32_t max_len) {
  constexpr int MAX_VEC = 16 / sizeof(DType);  // 4 for float32, 8 for fp16/bf16
```
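For readers unfamiliar with the vectorized-load sizing idiom above, here is a hypothetical Python mirror of what a function like `ComputeFilteredTopKVecSize` plausibly does (the real selection logic lives in `csrc/topk.cuh`; the divisibility rule below is an assumption for illustration):

```python
def compute_filtered_topk_vec_size(max_len: int, dtype_bytes: int) -> int:
    """Sketch of per-dtype vector-width selection for 16-byte loads.

    Assumption: the kernel picks the widest power-of-two vector width
    (up to 16 bytes per load) that still evenly divides the row length.
    """
    max_vec = 16 // dtype_bytes  # 4 for float32, 8 for fp16/bf16
    vec = max_vec
    while vec > 1 and max_len % vec != 0:
        vec //= 2
    return vec  # one of 1, 2, 4, or 8
```

Wider vectors amortize memory-transaction overhead, which is why the constant is derived from `16 / sizeof(DType)` rather than fixed per kernel.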
tests/kernels/test_top_k_per_row.py (outdated)

```python
lengths = (seq_lens.unsqueeze(1) - next_n + 1 + offsets).flatten()

if kernel_name == "large_context_topk":
    workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")

max_non_topk = non_topk_vals.max()

# Allow small tolerance for floating point errors
assert min_cuda_val >= max_non_topk - 1e-4, (
assert torch.allclose(
    cuda_vals.sort(descending=True)[0],
    torch_vals.sort(descending=True)[0],
    rtol=1e-4,
    atol=1e-4,
```
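The test's core invariant is that every value the kernel selected must be at least as large as every value it rejected, up to float tolerance. A standalone NumPy sketch of that check (not the actual vLLM test, which runs against torch tensors on CUDA):

```python
import numpy as np

def check_topk(logits: np.ndarray, topk_idx: np.ndarray,
               atol: float = 1e-4) -> bool:
    """Return True iff min(selected) >= max(non-selected) - atol."""
    mask = np.zeros(logits.shape[0], dtype=bool)
    mask[topk_idx] = True
    non_topk_vals = logits[~mask]
    if non_topk_vals.size == 0:
        return True  # everything was selected; trivially valid
    return logits[mask].min() >= non_topk_vals.max() - atol
```

This set-based check is robust to ties and to the kernel returning indices in any order, which an exact index comparison against `torch.topk` would not be.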
Hi @LopezCastroRoberto, the pre-commit checks have failed. Please run:

```shell
uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
This pull request has merge conflicts that must be resolved before it can be merged.
LucasWilkinson left a comment
This is really awesome! Thanks for all the hard work!
One nit: instead of threading the topk_workspace through the whole model definition, can we just use current_workspace_manager().get_simultaneous(...)?
Summary
Redesigns the persistent TopK kernel used by DSA as a true persistent scheduler with dynamic per-row path selection.
This supersedes and closes #34265, which took a CUDAGraph-specialization approach. Instead, this PR follows a persistent scheduler pattern where a single fixed-grid kernel dynamically dispatches each row to the appropriate path at runtime.
Problem
As #34265 demonstrated, there are four topK-per-row kernel variants, each optimal for a different sequence-length regime. This is not an implementation artifact; it reflects fundamental algorithmic trade-offs between the variants.
Since max_seq_len changes at runtime (batches mix short decode sequences with long-context prefills), the initial approach in #34265 handled kernel selection via CUDAGraph specialization. However, this added complexity on the host side and required multiple graph variants. This PR simplifies the problem with a persistent scheduler that handles dispatch on-the-fly inside a single kernel.
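The old approach can be sketched as a host-side selector keyed on the batch-wide maximum length; under CUDAGraph capture this choice is frozen, which is what forced one captured graph per variant. Kernel names and thresholds below are illustrative assumptions, not the real cut-offs (those live in the CUDA sources):

```python
# Hypothetical regime boundaries, for illustration only.
SMALL_MAX, MEDIUM_MAX, LARGE_MAX = 4096, 32768, 131072

def select_topk_variant(max_seq_len: int) -> str:
    """Host-side, batch-wide dispatch as in the superseded approach.

    The whole batch gets one kernel chosen from max_seq_len, so short
    decode rows mixed with a long prefill all pay the long-context cost.
    """
    if max_seq_len <= SMALL_MAX:
        return "small_topk"
    if max_seq_len <= MEDIUM_MAX:
        return "medium_topk"
    if max_seq_len <= LARGE_MAX:
        return "large_context_topk"
    return "filtered_topk"
```

Because `max_seq_len` is a runtime value, every branch here is a different kernel launch, and CUDAGraph must capture each one separately.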
Approach
Single persistent kernel, fixed grid, dynamic dispatch.
This is CUDAGraph-safe by construction: the grid shape never changes, and the captured kernel handles all sequence lengths.
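The scheduler pattern can be illustrated on the CPU with a fixed pool of workers standing in for the fixed CUDA grid; each worker pulls rows and picks the per-row path from that row's own length, never from the batch maximum. Path names and thresholds are assumptions for illustration:

```python
from concurrent.futures import ThreadPoolExecutor

NUM_WORKERS = 4  # stands in for the fixed, capture-time grid size

def topk_row(row, seq_len, k):
    # Per-row dynamic dispatch: the path depends only on this row's
    # length (illustrative thresholds), not on the batch-wide max.
    if seq_len <= 4096:
        path = "small"
    elif seq_len <= 32768:
        path = "medium"
    else:
        path = "large"
    return path, sorted(row[:seq_len], reverse=True)[:k]

def persistent_topk(rows, seq_lens, k):
    """CPU sketch of the persistent scheduler: a fixed worker pool
    drains all rows, so the "launch shape" never changes with the
    input -- the property that makes the kernel CUDAGraph-safe."""
    with ThreadPoolExecutor(max_workers=NUM_WORKERS) as pool:
        return list(pool.map(lambda a: topk_row(a[0], a[1], k),
                             zip(rows, seq_lens)))
```

On the GPU the same idea is a grid-stride loop over rows inside one captured kernel, with a branch per row choosing the small/medium/large code path.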
Microbenchmarking

E2E results (NVIDIA B200)

Example (MAIN vs. PR): this PR improves throughput over MAIN by up to ~17%.

Examples (DP4)

Accuracy