
[Do not merge][Test] Revert "[Attention] Use sparse prefill kernel for fp8 kv-cache in DeepSeek-v3.2" #30715

Closed
LucasWilkinson wants to merge 1 commit into main from revert-27532-lwilkinson/upconvert-all-2

Conversation

@LucasWilkinson
Collaborator

Reverts #27532

@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

@LucasWilkinson added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Dec 15, 2025
@mergify bot added the deepseek (Related to DeepSeek models), gpt-oss (Related to GPT-OSS models), and nvidia labels on Dec 15, 2025
@mergify bot added the v1 label on Dec 15, 2025
Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request reverts the introduction of the sparse prefill kernel for fp8 kv-cache in DeepSeek-v3.2. The changes primarily involve removing the cp_gather_and_upconvert_fp8_kv_cache kernel, its associated logic, and the global WorkspaceManager. The WorkspaceManager is replaced by a more encapsulated SharedResizableBuffer within the FusedMoEModularKernel. While the revert is largely correct, I've identified a critical issue with the new buffer implementation that could break CUDA graph compatibility due to the lack of a locking mechanism. Additionally, there's a performance concern in deepseek_v2.py where repeated memory allocations have been introduced inside a loop.

Comment on lines +666 to +682
class SharedResizableBuffer:
    def __init__(self):
        self.buffer = None

    def get(
        self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        assert shape != ()
        shape_numel = prod(shape)
        if (
            self.buffer is None
            or self.buffer.numel() < shape_numel
            or self.buffer.device != device
            or self.buffer.dtype != dtype
        ):
            self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
        return self.buffer[:shape_numel].view(*shape)
Contributor


critical

The new SharedResizableBuffer implementation lacks a locking mechanism, which was present in the removed WorkspaceManager. The WorkspaceManager was locked after profiling/warmup runs to prevent the workspace from resizing during execution, which is crucial for CUDA graph stability. Without a lock, if a request requiring a larger buffer arrives after a CUDA graph has been captured, SharedResizableBuffer.get() will reallocate the buffer, breaking the captured graph. This can lead to runtime errors or incorrect computations. A locking mechanism should be re-introduced to prevent buffer reallocation after CUDA graph capture.
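The lock semantics the review asks for can be sketched as follows. This is a hedged, minimal illustration of the control flow only: `LockableResizableBuffer`, `lock()`, and `_locked` are hypothetical names (not vLLM's actual API), and a plain Python list stands in for the torch backing tensor so the logic is easy to follow.

```python
# Hypothetical sketch: restoring the "lock after warmup" behavior the
# removed WorkspaceManager had. A Python list stands in for the tensor.
class LockableResizableBuffer:
    def __init__(self):
        self._storage = None   # stand-in for the torch.Tensor backing buffer
        self._locked = False   # set once profiling/warmup finishes

    def lock(self):
        """Freeze the buffer so CUDA graph capture sees a stable allocation."""
        self._locked = True

    def get(self, numel):
        if self._storage is None or len(self._storage) < numel:
            if self._locked:
                # Growing here would silently invalidate a captured CUDA
                # graph, so fail loudly instead of reallocating.
                raise RuntimeError(
                    f"buffer locked; cannot grow to {numel} elements"
                )
            self._storage = [0] * numel
        return self._storage[:numel]
```

Usage would mirror the profiling flow: run warmup with the maximum request size, call `lock()`, and afterwards any `get()` that fits the existing allocation succeeds while a larger one raises rather than reallocating under a captured graph.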

Comment on lines 654 to +664
 for chunk in prefill_metadata.chunks:
-    k_fp8 = k_fp8_full[: chunk.total_seq_lens]
-    k_scale = k_scale_full[: chunk.total_seq_lens]
+    k_fp8 = torch.empty(
+        [chunk.total_seq_lens, head_dim],
+        device=k.device,
+        dtype=fp8_dtype,
+    )
+    k_scale = torch.empty(
+        [chunk.total_seq_lens, 4],
+        device=k.device,
+        dtype=torch.uint8,
+    )
Contributor


high

The change from using a pre-allocated workspace to calling torch.empty inside the loop for each prefill chunk introduces repeated memory allocations. This can lead to performance degradation and memory fragmentation, especially when processing many small chunks. The previous approach of allocating a single large buffer and slicing it for each chunk is generally more efficient. Consider reverting to a similar pattern of pre-allocating a buffer outside the loop to avoid repeated allocations in the hot path.
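The pre-allocate-and-slice pattern the review suggests can be sketched as below. This is an illustrative stand-in, not the deepseek_v2.py code: `Chunk` and `process_chunks` are hypothetical names, and Python lists replace the torch tensors (note that slicing a torch tensor yields a view of the same storage, whereas slicing a list copies; only the allocation pattern is being shown).

```python
# Hypothetical sketch: one allocation before the loop sized for the
# largest chunk, sliced per chunk, instead of two torch.empty calls
# inside the loop on every iteration.
from dataclasses import dataclass


@dataclass
class Chunk:
    total_seq_lens: int


def process_chunks(chunks, head_dim):
    # Size the shared buffers once, for the largest chunk.
    max_len = max(c.total_seq_lens for c in chunks)
    k_fp8_full = [0] * (max_len * head_dim)   # stand-in for the fp8 buffer
    k_scale_full = [0] * (max_len * 4)        # stand-in for the scale buffer

    sizes = []
    for chunk in chunks:
        # Per-chunk slices of the pre-allocated storage; no allocation
        # happens in the hot path.
        k_fp8 = k_fp8_full[: chunk.total_seq_lens * head_dim]
        k_scale = k_scale_full[: chunk.total_seq_lens * 4]
        sizes.append((len(k_fp8), len(k_scale)))
    return sizes
```

With torch tensors the per-chunk slices would additionally be zero-copy views, which is what made the original single-buffer approach cheap.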

@LucasWilkinson
Collaborator Author

resolved by: #30744

@github-project-automation bot moved this from To Triage to Done in gpt-oss Issues & Enhancements on Dec 17, 2025
@github-project-automation bot moved this to Done in NVIDIA on Dec 17, 2025
@LucasWilkinson deleted the revert-27532-lwilkinson/upconvert-all-2 branch on January 19, 2026 16:44

Labels

deepseek: Related to DeepSeek models
gpt-oss: Related to GPT-OSS models
nvidia
ready: ONLY add when PR is ready to merge/full CI is needed
v1

Projects

Status: Done
