diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py
index 8c0198784b..705533674b 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker.py
@@ -66,6 +66,7 @@
     ReferenceLogprobOutputSpec,
 )
 from nemo_rl.models.policy.utils import (
+    apply_top_k_top_p,
     configure_dynamo_cache,
     configure_expandable_segments,
     get_gpu_info,
@@ -463,19 +464,40 @@ def create_context_parallel_ctx(
             no_restore_buffers=cp_no_restore_buffers,
         )
 
-    # Refer to nemo impl. Below is original comment.
-    # based on https://github.com/pytorch/torchtitan/blob/cddd7dc809f36fe0ed51cdaaea0671c084d75442/torchtitan/distributed/utils.py#L178
-
-    def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor:
-        # Apply temperature scaling to logits if configured and not using V1 engine.
+    def _post_processing_logits_with_sampling_params(
+        self, logits: torch.Tensor
+    ) -> torch.Tensor:
+        # Apply temperature scaling and top-k/top-p filtering to the logits if configured and not using the V1 engine.
         if "generation" in self.cfg and self.cfg["generation"] is not None:
             # The V1 engine returns raw logits before temperature scaling.
             # The V0 engine returns scaled logits.
             # Therefore, we only divide if we are not using the V1 engine.
             if not is_vllm_v1_engine_enabled():
+                # Convert to float32 to mimic vLLM's behavior.
+                logits = logits.to(torch.float32)
+
+                # Apply temperature scaling.
                 logits.div_(self.cfg["generation"]["temperature"])
+
+                # Apply top-k and top-p filtering.
+                top_k = self.cfg["generation"]["top_k"]
+                top_p = self.cfg["generation"]["top_p"]
+                # Skip if no top-k/top-p filtering is configured.
+                if (top_k is None or top_k == -1) and (top_p is None or top_p == 1.0):
+                    return logits
+
+                if self.tp_size == 1:
+                    logits = apply_top_k_top_p(logits, top_k=top_k, top_p=top_p)
+                else:
+                    raise ValueError(
+                        "Top-k and top-p sampling are not supported with tensor_parallel_size > 1 on the vLLM v0 engine. "
+                        "Please use the vLLM v1 engine instead."
+                    )
+
         return logits
 
+    # Refer to nemo impl. Below is the original comment.
+    # based on https://github.com/pytorch/torchtitan/blob/cddd7dc809f36fe0ed51cdaaea0671c084d75442/torchtitan/distributed/utils.py#L178
     @staticmethod
     @contextlib.contextmanager
     def train_context(cp_context: Optional[Generator[None, None, None]] = None):
@@ -721,8 +743,10 @@ def train(
             else:
                 logits = outputs.logits
 
-            # Apply temperature scaling
-            logits = self._apply_temperature_scaling(logits)
+            # Apply temperature scaling and sampling parameters
+            logits = self._post_processing_logits_with_sampling_params(
+                logits
+            )
 
             if self.cp_size > 1:
                 seq_index_dtensor = (
@@ -1002,8 +1026,8 @@ def get_logprobs(
 
         logits = outputs.logits
 
-        # Apply temperature scaling
-        logits = self._apply_temperature_scaling(logits)
+        # Apply temperature scaling and sampling parameters
+        logits = self._post_processing_logits_with_sampling_params(logits)
 
         if self.cp_size > 1:
             seq_index_tensor = (
diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py
index c09a201268..4f4766f8df 100644
--- a/nemo_rl/models/policy/utils.py
+++ b/nemo_rl/models/policy/utils.py
@@ -14,7 +14,7 @@
 import importlib
 import os
-from typing import Any
+from typing import Any, Optional
 
 import torch
 from transformers import AutoConfig
@@ -22,6 +22,94 @@
 from nemo_rl.distributed.worker_group_utils import get_nsight_config_if_pattern_matches
 
 
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    top_k: Optional[int] = None,
+    top_p: Optional[float] = None,
+) -> torch.Tensor:
+    """Apply top-k and top-p masks to the logits.
+
+    Simplified version of vLLM's implementation for scalar parameters.
+    Based on vLLM's implementation:
+    https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py
+
+    Args:
+        logits: Input logits tensor of shape [batch_size, seq_len, vocab_size].
+        top_k: Top-k sampling parameter. Set to -1 to consider all tokens.
+        top_p: Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens.
+
+    Returns:
+        Filtered logits with the sampling parameters applied.
+    """
+    if top_p is None or top_p == 1.0:
+        if top_k is None or top_k == -1:
+            return logits
+        # Avoid sorting the full vocab when only top-k is requested.
+        return apply_top_k_only(logits, top_k)
+
+    # Apply top-p (requires sorting the logits in ascending order).
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    if top_k is not None and top_k != -1:
+        # Apply top-k first.
+        top_k_index = logits_sort.size(-1) - top_k
+        # Gather the k-th largest logit as the threshold; the index is broadcast across all leading dimensions.
+        index_tensor = torch.full(
+            logits_sort.shape[:-1],
+            top_k_index,
+            device=logits_sort.device,
+            dtype=torch.long,
+        )
+        top_k_threshold = logits_sort.gather(-1, index_tensor.unsqueeze(-1))
+        top_k_mask = logits_sort < top_k_threshold
+        logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    # Apply top-p.
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    top_p_mask = probs_sum <= 1 - top_p
+    # Always keep at least one token.
+    top_p_mask[..., -1] = False
+    logits_sort.masked_fill_(top_p_mask, -float("inf"))
+
+    # Undo the sort: scatter the masked logits back to their original positions.
+    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
+    return logits
+
+
+def apply_top_k_only(
+    logits: torch.Tensor,
+    top_k: int,
+) -> torch.Tensor:
+    """Apply a top-k mask to the logits.
+
+    Simplified version of vLLM's implementation for scalar parameters.
+    This implementation does not involve sorting the entire vocab.
+
+    Args:
+        logits: Input logits tensor of shape [batch_size, seq_len, vocab_size].
+        top_k: Top-k sampling parameter. Set to -1 to consider all tokens.
+
+    Returns:
+        Filtered logits with top-k applied
+    """
+    if top_k >= logits.shape[-1] or top_k == -1:
+        return logits
+
+    # Get top-k values and create mask
+    top_k_values, _ = torch.topk(logits, top_k, dim=-1)
+    threshold = top_k_values[..., -1:].expand_as(logits)
+    mask = logits >= threshold
+
+    # Apply mask: keep top-k values, set others to -inf
+    logits = torch.where(
+        mask,
+        logits,
+        torch.tensor(-float("inf"), device=logits.device, dtype=logits.dtype),
+    )
+    return logits
+
+
 def is_vllm_v1_engine_enabled() -> bool:
     """Check if vLLM V1 engine is enabled.
diff --git a/tests/unit/models/policy/test_utils.py b/tests/unit/models/policy/test_utils.py
index 5712985cd3..21cde239c5 100644
--- a/tests/unit/models/policy/test_utils.py
+++ b/tests/unit/models/policy/test_utils.py
@@ -16,12 +16,133 @@
 import unittest.mock
 from unittest.mock import MagicMock, patch
 
+import torch
+
 from nemo_rl.models.policy.utils import (
+    apply_top_k_top_p,
     configure_expandable_segments,
     get_megatron_checkpoint_dir,
 )
 
 
+class TestSamplingUtils(unittest.TestCase):
+    """Test cases for sampling utility functions."""
+
+    def setUp(self):
+        """Set up deterministic test data"""
+        # Create a deterministic logits tensor of shape [2, 2, 4]
+        self.logits = torch.tensor(
+            [
+                [
+                    [10.0, 5.0, 2.0, 1.0],
+                    [8.0, 7.0, 3.0, 0.5],
+                ],
+                [
+                    [6.0, 9.0, 1.5, 4.0],
+                    [2.5, 1.0, 8.5, 7.5],
+                ],
+            ],
+            dtype=torch.float32,
+        )
+
+    def test_top_k_only(self):
+        """Test top-k only sampling"""
+        k = 2
+        result = apply_top_k_top_p(self.logits.clone(), top_k=k)
+
+        # Expected result: keep top 2 values, mask others
+        expected = torch.tensor(
+            [
+                [
+                    [10.0, 5.0, -float("inf"), -float("inf")],
+                    [8.0, 7.0, -float("inf"), -float("inf")],
+                ],
+                [
+                    [6.0, 9.0, -float("inf"), -float("inf")],
+                    [-float("inf"), -float("inf"), 8.5, 7.5],
+                ],
+            ],
+            dtype=torch.float32,
+        )
+
+        self.assertTrue(torch.allclose(result, expected, equal_nan=True))
+
+    def test_top_p_only(self):
+        """Test top-p only sampling with specific probability threshold"""
+        p = 0.8
+        result = apply_top_k_top_p(self.logits.clone(), top_p=p)
+
+        expected = torch.tensor(
+            [
+                [
+                    [10.0, -float("inf"), -float("inf"), -float("inf")],
+                    [8.0, 7.0, -float("inf"), -float("inf")],
+                ],
+                [
+                    [-float("inf"), 9.0, -float("inf"), -float("inf")],
+                    [-float("inf"), -float("inf"), 8.5, 7.5],
+                ],
+            ],
+            dtype=torch.float32,
+        )
+
+        self.assertTrue(torch.allclose(result, expected, equal_nan=True))
+
+    def test_top_k_and_top_p_combined(self):
+        """Test both top-k and top-p sampling together"""
+        k = 3
+        p = 0.8
+        result = apply_top_k_top_p(self.logits.clone(), top_k=k, top_p=p)
+
+        # Expected: apply top-k=3 first, then top-p=0.8
+        # For our data, top-k=3 keeps top 3 values, then top-p further filters
+        expected = torch.tensor(
+            [
+                [
+                    [10.0, -float("inf"), -float("inf"), -float("inf")],
+                    [8.0, 7.0, -float("inf"), -float("inf")],
+                ],
+                [
+                    [-float("inf"), 9.0, -float("inf"), -float("inf")],
+                    [-float("inf"), -float("inf"), 8.5, 7.5],
+                ],
+            ],
+            dtype=torch.float32,
+        )
+
+        self.assertTrue(torch.allclose(result, expected, equal_nan=True))
+
+    def test_no_sampling(self):
+        """Test no sampling (should return original logits)"""
+        result = apply_top_k_top_p(self.logits.clone())
+        self.assertTrue(torch.allclose(result, self.logits, equal_nan=True))
+
+    def test_edge_cases(self):
+        """Test edge cases for sampling parameters"""
+        # k >= vocab size (should return original logits)
+        result = apply_top_k_top_p(self.logits.clone(), top_k=4)
+        self.assertTrue(torch.allclose(result, self.logits))
+
+        result = apply_top_k_top_p(self.logits.clone(), top_k=10)
+        self.assertTrue(torch.allclose(result, self.logits))
+
+        # p = 1.0 (should return original logits)
+        result = apply_top_k_top_p(self.logits.clone(), top_p=1.0)
+        self.assertTrue(torch.allclose(result, self.logits))
+
+    def test_gradient_flow(self):
+        """Test that gradients flow through the sampling operations"""
+        logits_with_grad = self.logits.clone().requires_grad_(True)
+        result = apply_top_k_top_p(logits_with_grad, top_k=2, top_p=0.7)
+        loss = result.sum()
+        loss.backward()
+
+        # Check that gradients exist and are non-zero for at least some elements
+        self.assertIsNotNone(logits_with_grad.grad)
+        self.assertEqual(logits_with_grad.grad.shape, logits_with_grad.shape)
+        self.assertTrue(torch.any(logits_with_grad.grad != 0))
+
+
 class TestConfigureExpandableSegments(unittest.TestCase):
     """Test cases for configure_expandable_segments function."""
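
A usage sketch (not part of the patch; assumes `nemo_rl` and `torch` are importable). It shows how the new `apply_top_k_top_p` helper masks a [batch, seq, vocab] logits tensor; the input mirrors the first row of the unit-test fixture, so the result can be checked against `test_top_k_and_top_p_combined` above.

    import torch

    from nemo_rl.models.policy.utils import apply_top_k_top_p

    logits = torch.tensor([[[10.0, 5.0, 2.0, 1.0]]])  # [batch=1, seq=1, vocab=4]

    # top_k=3 keeps the three largest logits (10.0, 5.0, 2.0); top_p=0.8 then
    # drops every sorted token whose cumulative probability is at most
    # 1 - 0.8 = 0.2, which removes 5.0 and 2.0 as well.
    filtered = apply_top_k_top_p(logits, top_k=3, top_p=0.8)
    print(filtered)  # tensor([[[10., -inf, -inf, -inf]]])

    # Masked tokens get a -inf log-probability, which is how the filtered
    # logits are consumed when the worker computes token log-probs.
    print(torch.log_softmax(filtered, dim=-1))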
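
A hypothetical consistency check (not in the patch): when only top-k is requested, the sort-free fast path `apply_top_k_only` should agree with a straightforward reference built from `torch.topk`, which is why `apply_top_k_top_p` can skip the full vocab sort in that case.

    import torch

    from nemo_rl.models.policy.utils import apply_top_k_only

    logits = torch.randn(2, 3, 32)  # [batch, seq, vocab]
    k = 5

    fast = apply_top_k_only(logits.clone(), k)

    # Reference: mask everything strictly below the k-th largest value
    # at each position (ties at the threshold are kept by both paths).
    kth_largest = logits.topk(k, dim=-1).values[..., -1:]
    reference = logits.masked_fill(logits < kth_largest, -float("inf"))

    assert torch.equal(fast, reference)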