42 changes: 33 additions & 9 deletions nemo_rl/models/policy/dtensor_policy_worker.py
@@ -61,6 +61,7 @@
ReferenceLogprobOutputSpec,
)
from nemo_rl.models.policy.utils import (
apply_top_k_top_p,
configure_expandable_segments,
get_gpu_info,
get_runtime_env_for_policy_worker,
@@ -417,19 +418,40 @@ def create_context_parallel_ctx(
no_restore_buffers=cp_no_restore_buffers,
)

# Refer to the NeMo implementation. The original comment follows.
# based on https://github.com/pytorch/torchtitan/blob/cddd7dc809f36fe0ed51cdaaea0671c084d75442/torchtitan/distributed/utils.py#L178

def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor:
# Apply temperature scaling to logits if configured and not using V1 engine.
def _post_processing_logits_with_sampling_params(
self, logits: torch.Tensor
) -> torch.Tensor:
# Apply temperature scaling and top-k/top-p sampling to logits if configured and not using V1 engine.
if "generation" in self.cfg and self.cfg["generation"] is not None:
# The V1 engine returns raw logits before temperature scaling.
# The V0 engine returns scaled logits.
# Therefore, we only divide if we are not using the V1 engine.
if not is_vllm_v1_engine_enabled():
# Convert to float32 to mimic vLLM's behavior.
logits = logits.to(torch.float32)

# Apply temperature scaling.
logits.div_(self.cfg["generation"]["temperature"])

# Apply top-k and top-p sampling
top_k = self.cfg["generation"]["top_k"]
top_p = self.cfg["generation"]["top_p"]
# Skip top-k/top-p filtering if neither is configured
if (top_k is None or top_k == -1) and (top_p is None or top_p == 1.0):
return logits

if self.tp_size == 1:
logits = apply_top_k_top_p(logits, top_k=top_k, top_p=top_p)
else:
raise ValueError(
"Top-k and top-p sampling is not supported for tensor_parallel_size > 1 for vLLM v0 engine. "
"Please use vLLM v1 engine instead."
)

return logits

# Refer to the NeMo implementation. The original comment follows.
# based on https://github.com/pytorch/torchtitan/blob/cddd7dc809f36fe0ed51cdaaea0671c084d75442/torchtitan/distributed/utils.py#L178
@staticmethod
@contextlib.contextmanager
def train_context(cp_context: Optional[Generator[None, None, None]] = None):
@@ -666,8 +688,10 @@ def train(
else:
logits = outputs.logits

# Apply temperature scaling
logits = self._apply_temperature_scaling(logits)
# Apply temperature scaling and sampling parameters
logits = self._post_processing_logits_with_sampling_params(
logits
)

if self.cp_size > 1:
seq_index_dtensor = (
Expand Down Expand Up @@ -947,8 +971,8 @@ def get_logprobs(

logits = outputs.logits

# Apply temperature scaling
logits = self._apply_temperature_scaling(logits)
# Apply temperature scaling and sampling parameters
logits = self._post_processing_logits_with_sampling_params(logits)

if self.cp_size > 1:
seq_index_tensor = (
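For orientation, a minimal standalone sketch of what the new V0 post-processing path above amounts to, assuming a hypothetical generation config (the temperature/top_k/top_p values are illustrative, not taken from this PR):

import torch

from nemo_rl.models.policy.utils import apply_top_k_top_p

# Hypothetical generation config; the worker reads the real values from cfg["generation"].
generation = {"temperature": 0.7, "top_k": 50, "top_p": 0.9}

logits = torch.randn(2, 8, 32000)  # [batch, seq, vocab]

# V0 path: cast to float32, scale by temperature, then filter.
logits = logits.to(torch.float32)
logits.div_(generation["temperature"])
logits = apply_top_k_top_p(logits, top_k=generation["top_k"], top_p=generation["top_p"])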
96 changes: 95 additions & 1 deletion nemo_rl/models/policy/utils.py
@@ -14,14 +14,108 @@

import importlib
import os
from typing import Any
from typing import Any, Optional

import torch
from transformers import AutoConfig

from nemo_rl.distributed.worker_group_utils import get_nsight_config_if_pattern_matches


def apply_top_k_top_p(
logits: torch.Tensor,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
) -> torch.Tensor:
"""Apply top-k and top-p masks to the logits.
Simplified version of vLLM's implementation for scalar parameters.
Based on vLLM's implementation:
https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py
Args:
logits: Input logits tensor of shape [batch_size, seq_len, vocab_size]
top_k: Top-k sampling parameter.
top_p: Top-p (nucleus) sampling parameter.
Returns:
Filtered logits with sampling parameters applied
"""
if top_p is None or top_p == 1.0:
if top_k is None or top_k == -1:
return logits
# Avoid sorting vocab for top-k only case
return apply_top_k_only(logits, top_k)

# Top-p (optionally combined with top-k) requires a sort; note the ascending order.
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

if top_k is not None and top_k > 0:
# Apply top-k first; clamp to the vocab size so the gather index stays in range.
top_k_index = logits_sort.size(-1) - min(top_k, logits_sort.size(-1))
# The threshold is the smallest of the top-k logits; broadcast its index
# across the leading (batch, sequence) dimensions.
index_tensor = torch.full(
logits_sort.shape[:-1],
top_k_index,
device=logits_sort.device,
dtype=torch.long,
)
top_k_threshold = logits_sort.gather(-1, index_tensor.unsqueeze(-1))
top_k_mask = logits_sort < top_k_threshold
logits_sort.masked_fill_(top_k_mask, -float("inf"))

# Apply top-p
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = torch.cumsum(probs_sort, dim=-1)
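# Worked example (hypothetical numbers): for ascending sorted probs
# [0.1, 0.2, 0.3, 0.4] and top_p=0.7, probs_sum = [0.1, 0.3, 0.6, 1.0];
# entries with cumulative mass <= 1 - 0.7 = 0.3 (the two smallest tokens)
# are masked below, leaving the nucleus {0.3, 0.4} whose mass >= top_p.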
top_p_mask = probs_sum <= 1 - top_p
# Keep at least one token; for p = 0.0, keep exactly one.
if top_p == 0.0:
# Keep only the highest-probability token, which is the last entry of the ascending sort.
top_p_mask = torch.ones_like(top_p_mask, dtype=torch.bool)
top_p_mask[..., -1] = False
else:
# at least one
top_p_mask[..., -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))

# Scatter the filtered logits back to their original (unsorted) token positions.
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
return logits


def apply_top_k_only(
logits: torch.Tensor,
top_k: int,
) -> torch.Tensor:
"""Apply top-k mask to the logits.
Simplified version of vLLM's implementation for scalar parameters.
This implementation doesn't involve sorting the entire vocab.
Args:
logits: Input logits tensor of shape [batch_size, seq_len, vocab_size]
top_k: Top-k sampling parameter.
Returns:
Filtered logits with top-k applied
"""
if top_k >= logits.shape[-1] or top_k == -1:
return logits

# Get top-k values and create mask
top_k_values, _ = torch.topk(logits, top_k, dim=-1)
threshold = top_k_values[..., -1:].expand_as(logits)
mask = logits >= threshold

# Apply mask: keep top-k values, set others to -inf
logits = torch.where(
mask,
logits,
torch.tensor(-float("inf"), device=logits.device, dtype=logits.dtype),
)
return logits


def is_vllm_v1_engine_enabled() -> bool:
"""Check if vLLM V1 engine is enabled.
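A small usage sketch of the two helpers above on a toy vocabulary (the tensor values are illustrative, not from the PR):

import torch

from nemo_rl.models.policy.utils import apply_top_k_only, apply_top_k_top_p

logits = torch.tensor([[[1.0, 4.0, 2.0, 5.0, 3.0]]])  # [batch=1, seq=1, vocab=5]

# Top-k only: keeps the two largest logits (5.0 and 4.0), masks the rest to -inf.
print(apply_top_k_only(logits.clone(), top_k=2))

# Combined: top-k narrows the candidates to three, then top-p keeps the smallest
# set of those whose probability mass reaches 0.8 (here, the logits 5.0 and 4.0).
print(apply_top_k_top_p(logits.clone(), top_k=3, top_p=0.8))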
152 changes: 152 additions & 0 deletions tests/unit/models/policy/test_utils.py
@@ -16,12 +16,164 @@
import unittest.mock
from unittest.mock import MagicMock, patch

import torch
import torch.nn.functional as F

from nemo_rl.models.policy.utils import (
apply_top_k_top_p,
configure_expandable_segments,
get_megatron_checkpoint_dir,
)


def manual_top_k(logits, k):
"""Manual reference implementation for top-k"""
top_k_values, _ = torch.topk(logits, k, dim=-1)
threshold = top_k_values[..., -1:].expand_as(logits)
mask = logits >= threshold
return torch.where(
mask,
logits,
torch.tensor(-float("inf"), device=logits.device, dtype=logits.dtype),
)


def manual_top_p(logits, p):
"""Manual reference implementation for top-p"""
probs = F.softmax(logits, dim=-1)
sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
mask = cumulative_probs <= p
# Always keep at least one
mask[..., 0] = True
keep_indices = torch.zeros_like(logits, dtype=torch.bool)
for b in range(logits.shape[0]):
for s in range(logits.shape[1]):
keep = sorted_indices[b, s][mask[b, s]]
keep_indices[b, s, keep] = True
return torch.where(
keep_indices,
logits,
torch.tensor(-float("inf"), device=logits.device, dtype=logits.dtype),
)


class TestSamplingUtils(unittest.TestCase):
"""Test cases for sampling utility functions."""

def test_top_k_only(self):
"""Test top-k only sampling"""
torch.manual_seed(0)
logits = torch.randn(2, 3, 10)
k = 4
result = apply_top_k_top_p(logits.clone(), top_k=k)
expected = manual_top_k(logits, k)
self.assertTrue(torch.allclose(result, expected))
# Check only k values per row are not -inf
non_inf = (result != -float("inf")).sum(dim=-1)
self.assertTrue(torch.all(non_inf == k))

def test_top_p_only(self):
"""Test top-p only sampling"""
torch.manual_seed(1)
logits = torch.randn(2, 3, 10)
p = 0.7
result = apply_top_k_top_p(logits.clone(), top_p=p)

# Check that we have at least one non-inf value per position
non_inf = (result != -float("inf")).sum(dim=-1)
self.assertTrue(torch.all(non_inf >= 1))

# Check that probability sums to 1.0
result_probs = F.softmax(result, dim=-1)
result_sum = result_probs.sum(dim=-1)
self.assertTrue(torch.allclose(result_sum, torch.ones_like(result_sum)))

def test_top_k_and_top_p(self):
"""Test both top-k and top-p sampling"""
torch.manual_seed(2)
logits = torch.randn(2, 3, 10)
k = 5
p = 0.8
result = apply_top_k_top_p(logits.clone(), top_k=k, top_p=p)

# Check basic properties
non_inf = (result != -float("inf")).sum(dim=-1)
self.assertTrue(torch.all(non_inf >= 1))
self.assertTrue(torch.all(non_inf <= k))

# Check that probability sums to 1.0
result_probs = F.softmax(result, dim=-1)
result_sum = result_probs.sum(dim=-1)
self.assertTrue(torch.allclose(result_sum, torch.ones_like(result_sum)))

def test_no_sampling(self):
"""Test no sampling (should return original logits)"""
torch.manual_seed(3)
logits = torch.randn(2, 3, 10)
result = apply_top_k_top_p(logits.clone())
self.assertTrue(torch.allclose(result, logits))

def test_edge_cases(self):
"""Test edge cases for sampling parameters"""
torch.manual_seed(4)
logits = torch.randn(2, 3, 10)

# k = vocab size (should return original logits)
result = apply_top_k_top_p(logits.clone(), top_k=10)
self.assertTrue(torch.allclose(result, logits))

# k = 1 (should keep only one token)
result = apply_top_k_top_p(logits.clone(), top_k=1)
non_inf = (result != -float("inf")).sum(dim=-1)
self.assertTrue(torch.all(non_inf == 1))

# p = 1.0 (should return original logits)
result = apply_top_k_top_p(logits.clone(), top_p=1.0)
self.assertTrue(torch.allclose(result, logits))

# p = 0.0 (should keep only the highest probability token)
result = apply_top_k_top_p(logits.clone(), top_p=0.0)
non_inf = (result != -float("inf")).sum(dim=-1)
self.assertTrue(torch.all(non_inf == 1))

# k > vocab size (should return original logits)
result = apply_top_k_top_p(logits.clone(), top_k=20)
self.assertTrue(torch.allclose(result, logits))

def test_gradient_flow(self):
"""Test that gradients flow through the sampling operations"""
torch.manual_seed(5)
logits = torch.randn(2, 3, 10, requires_grad=True)
result = apply_top_k_top_p(logits, top_k=3, top_p=0.7)
loss = result.sum()
loss.backward()

# Check that gradients exist and are non-zero
self.assertIsNotNone(logits.grad)
self.assertEqual(logits.grad.shape, logits.shape)
self.assertTrue(torch.any(logits.grad != 0))

def test_large_vocab(self):
"""Test with larger vocabulary size"""
torch.manual_seed(6)
logits = torch.randn(1, 2, 1000) # Large vocab
k = 50
p = 0.9

result = apply_top_k_top_p(logits.clone(), top_k=k, top_p=p)

# Check basic properties
non_inf = (result != -float("inf")).sum(dim=-1)
self.assertTrue(torch.all(non_inf >= 1))
self.assertTrue(torch.all(non_inf <= k))

# Check that probability sums to 1.0
result_probs = F.softmax(result, dim=-1)
result_sum = result_probs.sum(dim=-1)
self.assertTrue(torch.allclose(result_sum, torch.ones_like(result_sum)))


class TestConfigureExpandableSegments(unittest.TestCase):
"""Test cases for configure_expandable_segments function."""
