
Revert "[Kernel] Fuse temperature + softmax in sampling for decode speedup" #22046

Merged
BBuf merged 1 commit into main from revert-20501-fused_sampling on Apr 3, 2026

Conversation

@BBuf
Collaborator

BBuf commented Apr 3, 2026

Reverts #20501

@BBuf BBuf merged commit ee9d922 into main Apr 3, 2026
58 of 67 checks passed
@BBuf BBuf deleted the revert-20501-fused_sampling branch April 3, 2026 13:32
Contributor

@gemini-code-assist bot left a comment

Code Review

This pull request removes the fused Triton temperature-softmax kernel and its associated warmup logic, reverting temperature scaling and softmax in the sampler to standard PyTorch operations. I have provided a suggestion to reduce memory traffic by avoiding a redundant copy operation during the softmax calculation.
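To make the reverted path concrete, here is a minimal pure-Python sketch of the unfused "temperature scaling, then softmax" computation that the standard PyTorch operations perform (in PyTorch this would be `torch.softmax(logits / temperature, dim=-1)`); the function name and the max-subtraction detail are illustrative, not taken from the sglang codebase:

```python
import math

def temperature_softmax(logits, temperature):
    """Reference for the unfused path: scale by temperature, then softmax.

    Pure-Python sketch; mirrors torch.softmax(logits / temperature, dim=-1)
    for a single 1-D row of logits.
    """
    scaled = [x / temperature for x in logits]
    m = max(scaled)                       # subtract max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Lower temperature sharpens the distribution toward the max logit.
probs = temperature_softmax([1.0, 2.0, 3.0], temperature=0.5)
assert abs(sum(probs) - 1.0) < 1e-9      # valid probability distribution
```

The fused Triton kernel being reverted combined these two element-wise passes into one; the revert simply pays for them separately with stock PyTorch ops.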

Comment on lines +157 to +159

```python
# In-place op to save memory
logits[:] = torch.softmax(logits, dim=-1)
probs = logits
```
Contributor

Severity: medium

The comment # In-place op to save memory is misleading because torch.softmax is not an in-place operation and still requires a temporary allocation of the same size as the input. Additionally, logits[:] = ... performs a redundant data copy into the existing buffer.

To improve efficiency, you can avoid the copy by assigning the result of softmax directly to probs and updating the reference in logits_output.next_token_logits. This avoids the element-wise copy while still ensuring the output object contains the probabilities.

Suggested change

```diff
-# In-place op to save memory
-logits[:] = torch.softmax(logits, dim=-1)
-probs = logits
+probs = torch.softmax(logits, dim=-1)
+logits_output.next_token_logits = probs
```
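The aliasing distinction behind this suggestion can be sketched with plain Python lists; the semantics carry over to tensor buffers (slice-assignment writes into existing storage, rebinding only moves the name). The name `alias` below is a hypothetical stand-in for a field like `logits_output.next_token_logits` that still references the old buffer:

```python
logits = [0.1, 0.7, 0.2]
alias = logits  # stand-in for logits_output.next_token_logits

# Pattern from the reverted code: logits[:] = ... performs an
# element-wise copy back into the same buffer. The alias sees the
# new values, but a full extra copy was paid.
logits[:] = [x * 2 for x in logits]
assert alias is logits           # same object, values updated in place

# Pattern from the suggestion: bind the result to a new name (no
# copy into the old buffer), then update the alias explicitly.
probs = [x / 2 for x in logits]
alias = probs                    # mirrors logits_output.next_token_logits = probs
assert alias is probs
```

Note that `torch.softmax` allocates a fresh output tensor either way; the suggestion only removes the second, redundant copy into the old `logits` buffer, not the softmax temporary itself.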

JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
xiezhq-hermann pushed a commit to antgroup/sglang that referenced this pull request Apr 7, 2026
