Revert "[Kernel] Fuse temperature + softmax in sampling for decode speedup"#22046
Revert "[Kernel] Fuse temperature + softmax in sampling for decode speedup (#…)"

This reverts commit 7a59e05.
Code Review
This pull request removes the fused Triton temperature-softmax kernel and its associated warmup logic, reverting to standard PyTorch operations for temperature scaling and softmax. In the sampler implementation, the code was updated to perform these operations using standard PyTorch calls. I have provided a suggestion to optimize the memory usage by avoiding a redundant copy operation during the softmax calculation.
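For context, the unfused path that this revert restores amounts to two standard PyTorch ops. The sketch below is illustrative only (the function name and call shape are assumptions, not the repo's actual sampler code):

```python
import torch

# Illustrative sketch, not the repository's implementation: temperature
# scaling followed by softmax as two separate standard PyTorch operations,
# which is what the revert falls back to instead of the fused Triton kernel.
def temperature_softmax(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    scaled = logits / temperature           # temperature scaling
    return torch.softmax(scaled, dim=-1)    # normalize to probabilities

probs = temperature_softmax(torch.randn(2, 8), temperature=0.7)
```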
```python
# In-place op to save memory
logits[:] = torch.softmax(logits, dim=-1)
probs = logits
```
The comment # In-place op to save memory is misleading because torch.softmax is not an in-place operation and still requires a temporary allocation of the same size as the input. Additionally, logits[:] = ... performs a redundant data copy into the existing buffer.
To improve efficiency, you can avoid the copy by assigning the result of softmax directly to probs and updating the reference in logits_output.next_token_logits. This avoids the element-wise copy while still ensuring the output object contains the probabilities.
```diff
-# In-place op to save memory
-logits[:] = torch.softmax(logits, dim=-1)
-probs = logits
+probs = torch.softmax(logits, dim=-1)
+logits_output.next_token_logits = probs
```
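A minimal standalone demonstration of the reviewer's point: `torch.softmax` always allocates a fresh output tensor, so writing it back through `logits[:] = ...` adds an element-wise copy into the original buffer, while simply keeping the returned tensor does not. (The tensor shapes here are arbitrary examples.)

```python
import torch

logits = torch.randn(4, 16)
buf_ptr = logits.data_ptr()

# Variant 1: copy the softmax result back into the existing buffer.
# torch.softmax still allocated a temporary of the same size; the slice
# assignment then performs an extra element-wise copy into `logits`.
logits[:] = torch.softmax(logits, dim=-1)
assert logits.data_ptr() == buf_ptr  # same storage as before

# Variant 2: rebind to the tensor softmax returns; no copy into the old buffer.
probs = torch.softmax(logits, dim=-1)
assert probs.data_ptr() != buf_ptr   # new allocation, distinct from `logits`
```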
Reverts #20501