Conversation

@zhandaz (Contributor) commented on Jul 28, 2025

What does this PR do?

TL;DR: Support top-k and top-p sampling for the dtensor worker with vLLM V0. This PR supports TP > 1 on top of #773.

Instead of using _compute_distributed_log_softmax, we implement _compute_distributed_log_softmax_with_sampling to support top-k and top-p when TP is enabled.
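
For illustration, here is a minimal sketch of one way such a sampling-aware distributed log softmax can work, assuming a gather-then-mask strategy (the function name, signature, and approach below are assumptions for this sketch, not the code in this PR): all-gather the vocab-parallel shards across the TP group, mask out tokens excluded by top-k / top-p over the full vocabulary, renormalize with a log softmax, and return the local shard.

# Minimal sketch only -- the function name, arguments, and gather-then-mask
# strategy are illustrative assumptions, not the code in this PR.
from typing import Optional

import torch
import torch.distributed as dist


def distributed_log_softmax_with_sampling(
    vocab_parallel_logits: torch.Tensor,   # [batch, seq, vocab_shard]
    top_k: Optional[int] = None,           # None or -1 means top-k is disabled
    top_p: Optional[float] = None,         # None or 1.0 means top-p is disabled
    group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)

    # Reassemble the full vocabulary so top-k / top-p can be applied globally
    # (assumes the vocab is evenly sharded across the TP group).
    shards = [torch.empty_like(vocab_parallel_logits) for _ in range(world_size)]
    dist.all_gather(shards, vocab_parallel_logits.contiguous(), group=group)
    full_logits = torch.cat(shards, dim=-1)   # [batch, seq, vocab]

    remove = torch.zeros_like(full_logits, dtype=torch.bool)
    if top_k is not None and top_k > 0:
        # Drop everything below the k-th largest logit.
        kth = torch.topk(full_logits, top_k, dim=-1).values[..., -1:]
        remove |= full_logits < kth
    if top_p is not None and top_p < 1.0:
        # Drop tokens whose higher-ranked neighbours already cover top_p
        # probability mass; the top-ranked token is always kept.
        sorted_logits, sorted_idx = torch.sort(full_logits, dim=-1, descending=True)
        probs = torch.softmax(sorted_logits, dim=-1)
        sorted_remove = (probs.cumsum(dim=-1) - probs) >= top_p
        remove |= torch.zeros_like(remove).scatter(-1, sorted_idx, sorted_remove)

    # Renormalize over the kept tokens and return this rank's vocab shard.
    log_probs = torch.log_softmax(
        full_logits.masked_fill(remove, float("-inf")), dim=-1
    )
    shard = vocab_parallel_logits.shape[-1]
    return log_probs[..., rank * shard : (rank + 1) * shard]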

Note: This change depends on #773 and should be merged after it. We should also decide whether to merge this implementation or, alternatively, add a warning to users about a potential mismatch between this inference logic and the logic used in policy training for vLLM engine V0 and dtensor with TP > 1.

Tests for the distributed functionality and documentation will be added once that decision is made.

Issues

Related Issue: #69

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this
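
A hypothetical example of the intended usage (the config keys below are assumptions and may not match the actual configuration schema): set top_k / top_p in the generation config consumed by the dtensor policy worker so the same sampling constraints are applied when logprobs are recomputed with TP > 1.

# Hypothetical usage sketch -- the exact keys are assumptions, not verified
# against the repository.
generation_config = {
    "backend": "vllm",
    "temperature": 1.0,
    "top_k": 50,    # -1 or None would disable top-k filtering
    "top_p": 0.95,  # 1.0 would disable top-p filtering
}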

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

@zhandaz requested a review from Copilot on July 28, 2025 at 21:23

Copilot AI left a comment

Pull Request Overview

This PR adds support for top-k and top-p sampling in the dtensor policy worker with vLLM V0 when tensor parallelism (TP) is greater than 1. The implementation introduces a new distributed log softmax function that handles sampling parameters and modifies existing functions to propagate these parameters through the call stack.

  • Implements _compute_distributed_log_softmax_with_sampling to handle top-k/top-p sampling in distributed environments
  • Adds sampling parameter extraction and propagation through the dtensor policy worker
  • Updates function signatures across the distributed model utilities to support sampling parameters

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

Reviewed files:

  • nemo_rl/models/policy/dtensor_policy_worker.py: adds sampling parameter extraction and passes the parameters to the logprob computation functions
  • nemo_rl/models/dtensor/parallelize.py: updates the function signature to accept and forward sampling parameters
  • nemo_rl/distributed/model_utils.py: implements the new sampling-aware distributed log softmax and updates the related functions
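
For context, a hypothetical sketch of how sampling parameters might be threaded from the logprob entry point down to the sampling-aware log softmax (the signatures below are illustrative assumptions, not the repository's actual API; it reuses distributed_log_softmax_with_sampling from the sketch in the PR description above):

# Hypothetical signatures for illustration only -- not the actual functions in
# nemo_rl.distributed.model_utils.
from typing import Optional

import torch
import torch.distributed as dist


def get_logprobs_from_vocab_parallel_logits(
    vocab_parallel_logits: torch.Tensor,   # [batch, seq, vocab_shard]
    target_ids: torch.Tensor,              # [batch, seq], global vocab ids
    group: Optional[dist.ProcessGroup] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
) -> torch.Tensor:
    # Forward the sampling parameters so the recomputed logprobs match the
    # distribution the inference engine actually sampled from.
    log_probs = distributed_log_softmax_with_sampling(
        vocab_parallel_logits, top_k=top_k, top_p=top_p, group=group
    )

    # Each token id lives on exactly one vocab shard: gather locally, zero out
    # ids owned by other ranks, then sum across the TP group.
    shard = log_probs.shape[-1]
    vocab_start = dist.get_rank(group) * shard
    local_ids = target_ids - vocab_start
    in_shard = (local_ids >= 0) & (local_ids < shard)
    gathered = torch.gather(
        log_probs, dim=-1, index=local_ids.clamp(0, shard - 1).unsqueeze(-1)
    ).squeeze(-1)
    gathered = torch.where(in_shard, gathered, torch.zeros_like(gathered))
    dist.all_reduce(gathered, group=group)
    return gathered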

Returns:
Log softmax output with sampling applied, same shape as input
"""
if (top_k is not None and top_k == -1) and (top_p is not None and top_p == 1.0):

Copilot AI Jul 28, 2025

The condition uses and between two parenthesized conditions, but logically this should be or since either condition being true (top_k disabled OR top_p disabled) should trigger the fallback to regular log softmax.

Suggested change
if (top_k is not None and top_k == -1) and (top_p is not None and top_p == 1.0):
if (top_k is not None and top_k == -1) or (top_p is not None and top_p == 1.0):

log_softmax_output = _compute_distributed_log_softmax(
vocab_parallel_logits, group=group
)
# Use sampling-aware distributed log softmax if sampling parameters are provided

Copilot AI Jul 28, 2025

The condition on line 142 uses or logic but the comment suggests both parameters need to be provided. The logic is correct (either parameter being active should use sampling), but the comment is misleading.

Suggested change
# Use sampling-aware distributed log softmax if sampling parameters are provided
# Use sampling-aware distributed log softmax if either top_k or top_p is provided

Args:
vocab_parallel_logits (orch.Tensor): Logits distributed across tensor parallel workers,
vocab_parallel_logits (torch.Tensor): Logits distributed across tensor parallel workers,

Copilot AI Jul 28, 2025

The docstring has a typo - 'orch.Tensor' was partially corrected to 'torch.Tensor' but the diff shows this was already fixed.


Labels

enhancement (New feature or request)

Development

Successfully merging this pull request may close this issue: Top-p/Top-k Sampling Params handling in VLLM v1