[2/2] Top-k and Top-p support for dtensor worker with vLLM V0 when TP>1
#774
base: zhanda/top-p-k
Conversation
Pull Request Overview
This PR adds support for top-k and top-p sampling in the dtensor policy worker with vLLM V0 when tensor parallelism (TP) is greater than 1. The implementation introduces a new distributed log softmax function that handles sampling parameters and modifies existing functions to propagate these parameters through the call stack.
- Implements `_compute_distributed_log_softmax_with_sampling` to handle top-k/top-p sampling in distributed environments (see the sketch after this list for the underlying TP log-softmax pattern)
- Adds sampling parameter extraction and propagation through the dtensor policy worker
- Updates function signatures across the distributed model utilities to support sampling parameters
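
For context, here is a minimal sketch of the typical tensor-parallel log-softmax pattern that a sampling-aware variant would extend: the global max and the global normalizer are obtained with two all-reduces, so no rank ever materializes the full vocabulary. This is an illustration only; the repository's `_compute_distributed_log_softmax` may differ in its details.

```python
import torch
import torch.distributed as dist

def distributed_log_softmax_sketch(vocab_parallel_logits, group=None):
    """Log-softmax over vocab-sharded logits using two all-reduces (illustrative sketch)."""
    # Global max across all vocab shards, for numerical stability.
    logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)

    shifted = vocab_parallel_logits - logits_max
    # Global normalizer: sum of exp over the full (sharded) vocabulary.
    sum_exp = shifted.exp().sum(dim=-1, keepdim=True)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)

    return shifted - sum_exp.log()
```

The appeal of this pattern is that only per-token scalars cross ranks; the difficulty this PR addresses is that top-k and top-p need a view of the global token distribution, which the two all-reduces alone do not provide.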
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| nemo_rl/models/policy/dtensor_policy_worker.py | Adds sampling parameter extraction and passes them to logprob computation functions |
| nemo_rl/models/dtensor/parallelize.py | Updates function signature to accept and forward sampling parameters |
| nemo_rl/distributed/model_utils.py | Implements new sampling-aware distributed log softmax and updates all related functions |
```python
    Returns:
        Log softmax output with sampling applied, same shape as input
    """
    if (top_k is not None and top_k == -1) and (top_p is not None and top_p == 1.0):
```
Copilot AI commented on Jul 28, 2025
The condition uses `and` between two parenthesized conditions, but logically this should be `or`, since either condition being true (`top_k` disabled OR `top_p` disabled) should trigger the fallback to regular log softmax.
Suggested change:
```diff
-if (top_k is not None and top_k == -1) and (top_p is not None and top_p == 1.0):
+if (top_k is not None and top_k == -1) or (top_p is not None and top_p == 1.0):
```
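
For reference, a quick illustration (with hypothetical values) of how the two forms differ when only one of the parameters is active:

```python
# Hypothetical values: top_k == -1 means "top-k disabled", top_p == 1.0 means "top-p disabled".
top_k, top_p = -1, 0.9  # top-k disabled, top-p active

(top_k is not None and top_k == -1) and (top_p is not None and top_p == 1.0)  # False -> sampling path
(top_k is not None and top_k == -1) or (top_p is not None and top_p == 1.0)   # True  -> plain log softmax
```

With the `and` form, mixed settings still take the sampling path; with the suggested `or` form, they fall back to the plain log softmax.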
```python
log_softmax_output = _compute_distributed_log_softmax(
    vocab_parallel_logits, group=group
)
# Use sampling-aware distributed log softmax if sampling parameters are provided
```
Copilot AI commented on Jul 28, 2025
The condition on line 142 uses `or` logic, but the comment suggests both parameters need to be provided. The logic is correct (either parameter being active should use sampling), but the comment is misleading.
Suggested change:
```diff
-# Use sampling-aware distributed log softmax if sampling parameters are provided
+# Use sampling-aware distributed log softmax if either top_k or top_p is provided
```
```diff
 Args:
-    vocab_parallel_logits (orch.Tensor): Logits distributed across tensor parallel workers,
+    vocab_parallel_logits (torch.Tensor): Logits distributed across tensor parallel workers,
```
Copilot AI commented on Jul 28, 2025
The docstring has a typo: `orch.Tensor` was partially corrected to `torch.Tensor`, but the diff shows this was already fixed.
What does this PR do?
tl;dr: Support top-k and top-p for the dtensor worker with vLLM V0. This PR supports TP>1 on top of #773.

Instead of using `_compute_distributed_log_softmax`, we implement `_compute_distributed_log_softmax_with_sampling` to support top-k and top-p when TP is enabled.

Note: This change depends on #773 and should be merged after it. We should also decide whether to merge this implementation or, alternatively, add a warning to users about a potential mismatch between this inference logic and the logic used in policy training for vLLM engine V0 and dtensor with TP>1.
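
For illustration only, one way to realize these semantics is to all-gather the vocab shards, mask logits outside the top-k/top-p sets, and take a regular log softmax. The sketch below assumes equal vocab partitions concatenated in rank order; it is not the PR's `_compute_distributed_log_softmax_with_sampling`, which may avoid the full-vocab gather.

```python
import torch
import torch.distributed as dist

def log_softmax_with_sampling_sketch(vocab_parallel_logits, top_k=None, top_p=None, group=None):
    """Gather vocab shards, apply top-k/top-p masking, then log-softmax (illustrative sketch)."""
    world_size = dist.get_world_size(group)
    shards = [torch.empty_like(vocab_parallel_logits) for _ in range(world_size)]
    dist.all_gather(shards, vocab_parallel_logits.contiguous(), group=group)
    logits = torch.cat(shards, dim=-1)  # full vocabulary on every rank

    if top_k is not None and top_k > 0:
        # Mask everything below the k-th largest logit.
        kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))

    if top_p is not None and top_p < 1.0:
        # Keep the smallest prefix of tokens whose cumulative probability reaches top_p.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        sorted_remove = cum_probs > top_p
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()  # always keep the first token
        sorted_remove[..., 0] = False
        remove = torch.zeros_like(sorted_remove).scatter(-1, sorted_idx, sorted_remove)
        logits = logits.masked_fill(remove, float("-inf"))

    return torch.log_softmax(logits, dim=-1)
```

A full-vocab gather is the simplest way to let top-k/top-p see the global distribution, at the cost of extra memory per rank; whether the actual implementation in `nemo_rl/distributed/model_utils.py` takes this route is up to the PR.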
Tests for distributed functionalities and docs will be added after we make the decision.
Issues
Related Issue: #69
Usage
# Add a code snippet demonstrating how to use this
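
A hypothetical illustration, assuming the sampling parameters reach the dtensor worker through the generation/sampling config (the key names below are assumptions, not NeMo-RL's actual schema):

```python
# Hypothetical config values for illustration; key names are assumptions, not NeMo-RL's actual schema.
sampling_params = {
    "top_k": 50,    # -1 would disable top-k
    "top_p": 0.95,  # 1.0 would disable top-p
}
# With this PR, these values also take effect for logprob computation in the dtensor
# policy worker when using the vLLM V0 engine with tensor parallelism > 1.
```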
Before your PR is "Ready for review"

Pre checks: