From b3e0ef28489883b1dc19ce9242f4c14d1e912337 Mon Sep 17 00:00:00 2001
From: Zhanda
Date: Mon, 28 Jul 2025 13:24:35 -0700
Subject: [PATCH 1/5] feat: Add top_k_top_p util functions

---
 nemo_rl/models/policy/utils.py         |  96 +++++++++++++++-
 tests/unit/models/policy/test_utils.py | 152 +++++++++++++++++++++++++
 2 files changed, 247 insertions(+), 1 deletion(-)

diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py
index a61e5e20b7..6299f15494 100644
--- a/nemo_rl/models/policy/utils.py
+++ b/nemo_rl/models/policy/utils.py
@@ -14,7 +14,7 @@
 
 import importlib
 import os
-from typing import Any
+from typing import Any, Optional
 
 import torch
 from transformers import AutoConfig
@@ -22,6 +22,100 @@
 from nemo_rl.distributed.worker_group_utils import get_nsight_config_if_pattern_matches
 
 
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    top_k: Optional[int] = None,
+    top_p: Optional[float] = None,
+) -> torch.Tensor:
+    """Apply top-k and top-p masks to the logits.
+
+    Simplified version of vLLM's implementation for scalar parameters.
+    Based on vLLM's implementation:
+    https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py
+
+    Args:
+        logits: Input logits tensor of shape [batch_size, seq_len, vocab_size]
+        top_k: Top-k sampling parameter.
+        top_p: Top-p (nucleus) sampling parameter.
+
+    Returns:
+        Filtered logits with sampling parameters applied
+    """
+    if top_p is None:
+        if top_k is None:
+            return logits
+        # Avoid sorting vocab for top-k only case
+        return apply_top_k_only(logits, top_k)
+
+    # Apply top-p (requires sorting)
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    if top_k is not None:
+        # Apply top-k first
+        top_k_index = logits_sort.size(-1) - top_k
+        # Get all the top_k values - need to broadcast the index across all dimensions
+        index_tensor = torch.full(
+            logits_sort.shape[:-1],
+            top_k_index,
+            device=logits_sort.device,
+            dtype=torch.long,
+        )
+        top_k_threshold = logits_sort.gather(-1, index_tensor.unsqueeze(-1))
+        top_k_mask = logits_sort < top_k_threshold
+        logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    # Apply top-p
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    top_p_mask = probs_sum <= 1 - top_p
+    # at least one - but for p=0.0, we want exactly one
+    if top_p == 0.0:
+        # Keep only the highest probability token
+        top_p_mask = torch.ones_like(top_p_mask, dtype=torch.bool)
+        top_p_mask[..., 0] = False  # Keep only the first (highest prob) token
+    else:
+        # at least one
+        top_p_mask[..., -1] = False
+    logits_sort.masked_fill_(top_p_mask, -float("inf"))
+
+    # Scatter the filtered logits back to their original (unsorted) positions
+    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
+    return logits
+
+
+def apply_top_k_only(
+    logits: torch.Tensor,
+    top_k: int,
+) -> torch.Tensor:
+    """Apply top-k mask to the logits.
+
+    Simplified version of vLLM's implementation for scalar parameters.
+    This implementation doesn't involve sorting the entire vocab.
+
+    Args:
+        logits: Input logits tensor of shape [batch_size, seq_len, vocab_size]
+        top_k: Top-k sampling parameter.
+
+    Returns:
+        Filtered logits with top-k applied
+    """
+    if top_k >= logits.shape[-1]:
+        return logits
+
+    # Get top-k values and create mask
+    top_k_values, _ = torch.topk(logits, top_k, dim=-1)
+    threshold = top_k_values[..., -1:].expand_as(logits)
+    mask = logits >= threshold
+
+    # Apply mask: keep top-k values, set others to -inf
+    logits = torch.where(
+        mask,
+        logits,
+        torch.tensor(-float("inf"), device=logits.device, dtype=logits.dtype),
+    )
+    return logits
+
+
 def is_vllm_v1_engine_enabled() -> bool:
     """Check if vLLM V1 engine is enabled.
 
diff --git a/tests/unit/models/policy/test_utils.py b/tests/unit/models/policy/test_utils.py
index 5712985cd3..0998167da8 100644
--- a/tests/unit/models/policy/test_utils.py
+++ b/tests/unit/models/policy/test_utils.py
@@ -16,12 +16,164 @@
 import unittest.mock
 from unittest.mock import MagicMock, patch
 
+import torch
+import torch.nn.functional as F
+
 from nemo_rl.models.policy.utils import (
+    apply_top_k_top_p,
     configure_expandable_segments,
     get_megatron_checkpoint_dir,
 )
 
 
+def manual_top_k(logits, k):
+    """Manual reference implementation for top-k"""
+    top_k_values, _ = torch.topk(logits, k, dim=-1)
+    threshold = top_k_values[..., -1:].expand_as(logits)
+    mask = logits >= threshold
+    return torch.where(
+        mask,
+        logits,
+        torch.tensor(-float("inf"), device=logits.device, dtype=logits.dtype),
+    )
+
+
+def manual_top_p(logits, p):
+    """Manual reference implementation for top-p"""
+    probs = F.softmax(logits, dim=-1)
+    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
+    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+    mask = cumulative_probs <= p
+    # Always keep at least one
+    mask[..., 0] = True
+    keep_indices = torch.zeros_like(logits, dtype=torch.bool)
+    for b in range(logits.shape[0]):
+        for s in range(logits.shape[1]):
+            keep = sorted_indices[b, s][mask[b, s]]
+            keep_indices[b, s, keep] = True
+    return torch.where(
+        keep_indices,
+        logits,
+        torch.tensor(-float("inf"), device=logits.device, dtype=logits.dtype),
+    )
+
+
+class TestSamplingUtils(unittest.TestCase):
+    """Test cases for sampling utility functions."""
+
+    def test_top_k_only(self):
+        """Test top-k only sampling"""
+        torch.manual_seed(0)
+        logits = torch.randn(2, 3, 10)
+        k = 4
+        result = apply_top_k_top_p(logits.clone(), top_k=k)
+        expected = manual_top_k(logits, k)
+        self.assertTrue(torch.allclose(result, expected))
+        # Check only k values per row are not -inf
+        non_inf = (result != -float("inf")).sum(dim=-1)
+        self.assertTrue(torch.all(non_inf == k))
+
+    def test_top_p_only(self):
+        """Test top-p only sampling"""
+        torch.manual_seed(1)
+        logits = torch.randn(2, 3, 10)
+        p = 0.7
+        result = apply_top_k_top_p(logits.clone(), top_p=p)
+
+        # Check that we have at least one non-inf value per position
+        non_inf = (result != -float("inf")).sum(dim=-1)
+        self.assertTrue(torch.all(non_inf >= 1))
+
+        # Check that probability sums to 1.0
+        result_probs = F.softmax(result, dim=-1)
+        result_sum = result_probs.sum(dim=-1)
+        self.assertTrue(torch.allclose(result_sum, torch.ones_like(result_sum)))
+
+    def test_top_k_and_top_p(self):
+        """Test both top-k and top-p sampling"""
+        torch.manual_seed(2)
+        logits = torch.randn(2, 3, 10)
+        k = 5
+        p = 0.8
+        result = apply_top_k_top_p(logits.clone(), top_k=k, top_p=p)
+
+        # Check basic properties
+        non_inf = (result != -float("inf")).sum(dim=-1)
+        self.assertTrue(torch.all(non_inf >= 1))
+        self.assertTrue(torch.all(non_inf <= k))
+
+        # Check that probability sums to 1.0
+        result_probs = F.softmax(result, dim=-1)
+        result_sum = result_probs.sum(dim=-1)
+        self.assertTrue(torch.allclose(result_sum, torch.ones_like(result_sum)))
+
+    def test_no_sampling(self):
+        """Test no sampling (should return original logits)"""
+        torch.manual_seed(3)
+        logits = torch.randn(2, 3, 10)
+        result = apply_top_k_top_p(logits.clone())
+        self.assertTrue(torch.allclose(result, logits))
+
+    def test_edge_cases(self):
+        """Test edge cases for sampling parameters"""
+        torch.manual_seed(4)
+        logits = torch.randn(2, 3, 10)
+
+        # k = vocab size (should return original logits)
+        result = apply_top_k_top_p(logits.clone(), top_k=10)
+        self.assertTrue(torch.allclose(result, logits))
+
+        # k = 1 (should keep only one token)
+        result = apply_top_k_top_p(logits.clone(), top_k=1)
+        non_inf = (result != -float("inf")).sum(dim=-1)
+        self.assertTrue(torch.all(non_inf == 1))
+
+        # p = 1.0 (should return original logits)
+        result = apply_top_k_top_p(logits.clone(), top_p=1.0)
+        self.assertTrue(torch.allclose(result, logits))
+
+        # p = 0.0 (should keep only the highest probability token)
+        result = apply_top_k_top_p(logits.clone(), top_p=0.0)
+        non_inf = (result != -float("inf")).sum(dim=-1)
+        self.assertTrue(torch.all(non_inf == 1))
+
+        # k > vocab size (should return original logits)
+        result = apply_top_k_top_p(logits.clone(), top_k=20)
+        self.assertTrue(torch.allclose(result, logits))
+
+    def test_gradient_flow(self):
+        """Test that gradients flow through the sampling operations"""
+        torch.manual_seed(5)
+        logits = torch.randn(2, 3, 10, requires_grad=True)
+        result = apply_top_k_top_p(logits, top_k=3, top_p=0.7)
+        loss = result.sum()
+        loss.backward()
+
+        # Check that gradients exist and are non-zero
+        self.assertIsNotNone(logits.grad)
+        self.assertEqual(logits.grad.shape, logits.shape)
+        self.assertTrue(torch.any(logits.grad != 0))
+
+    def test_large_vocab(self):
+        """Test with larger vocabulary size"""
+        torch.manual_seed(6)
+        logits = torch.randn(1, 2, 1000)  # Large vocab
+        k = 50
+        p = 0.9
+
+        result = apply_top_k_top_p(logits.clone(), top_k=k, top_p=p)
+
+        # Check basic properties
+        non_inf = (result != -float("inf")).sum(dim=-1)
+        self.assertTrue(torch.all(non_inf >= 1))
+        self.assertTrue(torch.all(non_inf <= k))
+
+        # Check that probability sums to 1.0
+        result_probs = F.softmax(result, dim=-1)
+        result_sum = result_probs.sum(dim=-1)
+        self.assertTrue(torch.allclose(result_sum, torch.ones_like(result_sum)))
+
+
 class TestConfigureExpandableSegments(unittest.TestCase):
     """Test cases for configure_expandable_segments function."""
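As a quick illustration of the two helpers this patch adds (a standalone sketch, not taken from the series itself; it assumes nemo_rl is importable and exercises only the functions introduced above):

    import torch

    from nemo_rl.models.policy.utils import apply_top_k_only, apply_top_k_top_p

    logits = torch.randn(2, 3, 10)  # [batch_size, seq_len, vocab_size]

    # Top-k only: exactly k entries per position survive (ties aside),
    # without sorting the full vocab dimension.
    out_k = apply_top_k_only(logits.clone(), top_k=4)
    assert int((out_k != -float("inf")).sum(dim=-1).max()) == 4

    # Top-k combined with top-p: at most k entries survive per position,
    # and the highest-probability token is always kept.
    out_kp = apply_top_k_top_p(logits.clone(), top_k=4, top_p=0.9)
    kept = (out_kp != -float("inf")).sum(dim=-1)
    assert kept.min() >= 1 and kept.max() <= 4
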
From cf703f6d783d2d3530d3432ddd48806f999e097a Mon Sep 17 00:00:00 2001
From: Zhanda
Date: Mon, 28 Jul 2025 13:49:46 -0700
Subject: [PATCH 2/5] feat: Support top-k and top-p sampling for dtensor with vLLM v0

---
 .../models/policy/dtensor_policy_worker.py | 38 ++++++++++++++-----
 1 file changed, 29 insertions(+), 9 deletions(-)

diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py
index b590032408..97817a5c98 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker.py
@@ -61,6 +61,7 @@
     ReferenceLogprobOutputSpec,
 )
 from nemo_rl.models.policy.utils import (
+    apply_top_k_top_p,
     configure_expandable_segments,
     get_gpu_info,
     get_runtime_env_for_policy_worker,
@@ -417,19 +418,36 @@
             no_restore_buffers=cp_no_restore_buffers,
         )
 
-    # Refer to nemo impl. Below is original comment.
-    # based on https://github.com/pytorch/torchtitan/blob/cddd7dc809f36fe0ed51cdaaea0671c084d75442/torchtitan/distributed/utils.py#L178
-
-    def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor:
-        # Apply temperature scaling to logits if configured and not using V1 engine.
+    def _post_processing_logits_with_sampling_params(
+        self, logits: torch.Tensor
+    ) -> torch.Tensor:
+        # Apply temperature scaling and top-k/top-p sampling to logits if configured and not using V1 engine.
         if "generation" in self.cfg and self.cfg["generation"] is not None:
             # The V1 engine returns raw logits before temperature scaling.
             # The V0 engine returns scaled logits.
             # Therefore, we only divide if we are not using the V1 engine.
             if not is_vllm_v1_engine_enabled():
+                # Convert to float32 to mimic vLLM's behavior.
+                logits = logits.to(torch.float32)
+
+                # Apply temperature scaling.
                 logits.div_(self.cfg["generation"]["temperature"])
+
+                # Apply top-k and top-p sampling
+                if self.tp_size == 1:
+                    top_p = self.cfg["generation"]["top_p"]
+                    top_k = self.cfg["generation"]["top_k"]
+                    logits = apply_top_k_top_p(logits, top_k=top_k, top_p=top_p)
+                else:
+                    raise ValueError(
+                        "Top-k and top-p sampling is not supported for tensor_parallel_size > 1 for vLLM v0 engine. "
+                        "Please use vLLM v1 engine instead."
+                    )
+
         return logits
 
+    # Refer to nemo impl. Below is original comment.
+    # based on https://github.com/pytorch/torchtitan/blob/cddd7dc809f36fe0ed51cdaaea0671c084d75442/torchtitan/distributed/utils.py#L178
     @staticmethod
     @contextlib.contextmanager
     def train_context(cp_context: Optional[Generator[None, None, None]] = None):
@@ -666,8 +684,10 @@ def train(
         else:
             logits = outputs.logits
 
-        # Apply temperature scaling
-        logits = self._apply_temperature_scaling(logits)
+        # Apply temperature scaling and sampling parameters
+        logits = self._post_processing_logits_with_sampling_params(
+            logits
+        )
 
         if self.cp_size > 1:
             seq_index_dtensor = (
@@ -947,8 +967,8 @@ def get_logprobs(
 
         logits = outputs.logits
 
-        # Apply temperature scaling
-        logits = self._apply_temperature_scaling(logits)
+        # Apply temperature scaling and sampling parameters
+        logits = self._post_processing_logits_with_sampling_params(logits)
 
         if self.cp_size > 1:
             seq_index_tensor = (
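In isolation, the post-processing order this patch wires into the worker looks like the following (a minimal sketch; `post_process` is a hypothetical stand-in for the new `_post_processing_logits_with_sampling_params` method, and `v1_engine` stands in for `is_vllm_v1_engine_enabled()`):

    import torch

    from nemo_rl.models.policy.utils import apply_top_k_top_p

    def post_process(
        logits: torch.Tensor, generation_cfg: dict, v1_engine: bool, tp_size: int
    ) -> torch.Tensor:
        # The v1 engine returns raw logits, so nothing needs to be undone;
        # only the v0 path has to reproduce vLLM's scaling and filtering.
        if v1_engine:
            return logits
        logits = logits.to(torch.float32)  # match vLLM's float32 math
        logits = logits / generation_cfg["temperature"]  # temperature first
        if tp_size != 1:
            raise ValueError("top-k/top-p on the v0 engine requires tp_size == 1")
        return apply_top_k_top_p(
            logits, top_k=generation_cfg["top_k"], top_p=generation_cfg["top_p"]
        )

    out = post_process(
        torch.randn(1, 4, 32),
        {"temperature": 0.7, "top_k": 8, "top_p": 0.95},
        v1_engine=False,
        tp_size=1,
    )
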
From 2c8b5ec94643502b00f2315b45e1c0ba8b5e825d Mon Sep 17 00:00:00 2001
From: Zhanda
Date: Mon, 28 Jul 2025 14:01:11 -0700
Subject: [PATCH 3/5] fix: Handle trivial top_p and top_k values

---
 nemo_rl/models/policy/dtensor_policy_worker.py | 8 ++++++--
 nemo_rl/models/policy/utils.py                 | 6 +++---
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py
index 97817a5c98..313379f16d 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker.py
@@ -434,9 +434,13 @@ def _post_processing_logits_with_sampling_params(
                 logits.div_(self.cfg["generation"]["temperature"])
 
                 # Apply top-k and top-p sampling
+                top_k = self.cfg["generation"]["top_k"]
+                top_p = self.cfg["generation"]["top_p"]
+                # Skip if no sampling is configured
+                if (top_k is None or top_k == -1) and (top_p is None or top_p == 1.0):
+                    return logits
+
                 if self.tp_size == 1:
-                    top_p = self.cfg["generation"]["top_p"]
-                    top_k = self.cfg["generation"]["top_k"]
                     logits = apply_top_k_top_p(logits, top_k=top_k, top_p=top_p)
                 else:
                     raise ValueError(
diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py
index 6299f15494..98eca0aa0f 100644
--- a/nemo_rl/models/policy/utils.py
+++ b/nemo_rl/models/policy/utils.py
@@ -41,8 +41,8 @@ def apply_top_k_top_p(
     Returns:
         Filtered logits with sampling parameters applied
     """
-    if top_p is None:
-        if top_k is None:
+    if top_p is None or top_p == 1.0:
+        if top_k is None or top_k == -1:
             return logits
         # Avoid sorting vocab for top-k only case
         return apply_top_k_only(logits, top_k)
@@ -50,7 +50,7 @@ def apply_top_k_top_p(
     # Apply top-p (requires sorting)
     logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
 
-    if top_k is not None:
+    if top_k is not None and top_k > 0:
         # Apply top-k first
         top_k_index = logits_sort.size(-1) - top_k
         # Get all the top_k values - need to broadcast the index across all dimensions
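The sentinel convention this fix settles on can be checked directly (an illustrative snippet, not part of the diff; the values follow vLLM's convention, where top_k=-1 and top_p=1.0 mean "disabled"):

    import torch

    from nemo_rl.models.policy.utils import apply_top_k_top_p

    logits = torch.randn(2, 3, 10)
    # Every "disabled" combination must be an exact no-op on the logits.
    for top_k, top_p in [(None, None), (-1, None), (None, 1.0), (-1, 1.0)]:
        out = apply_top_k_top_p(logits.clone(), top_k=top_k, top_p=top_p)
        assert torch.equal(out, logits)
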
From dd62111c80e3a00ccb2220cb504a5c7293fd9d7c Mon Sep 17 00:00:00 2001
From: Zhanda
Date: Mon, 28 Jul 2025 14:03:03 -0700
Subject: [PATCH 4/5] fix: Handle trivial top_k values in apply_top_k_only

---
 nemo_rl/models/policy/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py
index 98eca0aa0f..8d2dde5570 100644
--- a/nemo_rl/models/policy/utils.py
+++ b/nemo_rl/models/policy/utils.py
@@ -99,7 +99,7 @@ def apply_top_k_only(
     Returns:
         Filtered logits with top-k applied
     """
-    if top_k >= logits.shape[-1]:
+    if top_k >= logits.shape[-1] or top_k == -1:
         return logits
 
     # Get top-k values and create mask
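And the same no-op guarantee for the fast path (an illustrative check of the guard added above; without it, top_k=-1 would reach torch.topk and fail):

    import torch

    from nemo_rl.models.policy.utils import apply_top_k_only

    logits = torch.randn(2, 3, 10)
    assert torch.equal(apply_top_k_only(logits.clone(), top_k=10), logits)  # k == vocab size
    assert torch.equal(apply_top_k_only(logits.clone(), top_k=-1), logits)  # sentinel value
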
From 58c1e00d93dace2b0f34a7a962a22004aa5a868c Mon Sep 17 00:00:00 2001
From: Zhanda
Date: Tue, 5 Aug 2025 12:24:39 -0700
Subject: [PATCH 5/5] fix: Improve the test cases

---
 nemo_rl/models/policy/utils.py         |  16 +-
 tests/unit/models/policy/test_utils.py | 213 +++++++++++--------------
 2 files changed, 96 insertions(+), 133 deletions(-)

diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py
index 8d2dde5570..cf0df8556c 100644
--- a/nemo_rl/models/policy/utils.py
+++ b/nemo_rl/models/policy/utils.py
@@ -35,8 +35,8 @@ def apply_top_k_top_p(
 
     Args:
         logits: Input logits tensor of shape [batch_size, seq_len, vocab_size]
-        top_k: Top-k sampling parameter.
-        top_p: Top-p (nucleus) sampling parameter.
+        top_k: Top-k sampling parameter. Set to -1 to consider all tokens.
+        top_p: Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens.
 
     Returns:
         Filtered logits with sampling parameters applied
@@ -50,7 +50,7 @@ def apply_top_k_top_p(
     # Apply top-p (requires sorting)
     logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
 
-    if top_k is not None and top_k > 0:
+    if top_k is not None and top_k != -1:
         # Apply top-k first
         top_k_index = logits_sort.size(-1) - top_k
         # Get all the top_k values - need to broadcast the index across all dimensions
@@ -68,14 +68,8 @@ def apply_top_k_top_p(
     probs_sort = logits_sort.softmax(dim=-1)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
     top_p_mask = probs_sum <= 1 - top_p
-    # at least one - but for p=0.0, we want exactly one
-    if top_p == 0.0:
-        # Keep only the highest probability token
-        top_p_mask = torch.ones_like(top_p_mask, dtype=torch.bool)
-        top_p_mask[..., 0] = False  # Keep only the first (highest prob) token
-    else:
-        # at least one
-        top_p_mask[..., -1] = False
+    # at least one
+    top_p_mask[..., -1] = False
     logits_sort.masked_fill_(top_p_mask, -float("inf"))
 
     # Scatter the filtered logits back to their original (unsorted) positions
diff --git a/tests/unit/models/policy/test_utils.py b/tests/unit/models/policy/test_utils.py
index 0998167da8..21cde239c5 100644
--- a/tests/unit/models/policy/test_utils.py
+++ b/tests/unit/models/policy/test_utils.py
@@ -17,7 +17,6 @@
 from unittest.mock import MagicMock, patch
 
 import torch
-import torch.nn.functional as F
 
 from nemo_rl.models.policy.utils import (
     apply_top_k_top_p,
@@ -26,152 +25,122 @@
 )
 
 
-def manual_top_k(logits, k):
-    """Manual reference implementation for top-k"""
-    top_k_values, _ = torch.topk(logits, k, dim=-1)
-    threshold = top_k_values[..., -1:].expand_as(logits)
-    mask = logits >= threshold
-    return torch.where(
-        mask,
-        logits,
-        torch.tensor(-float("inf"), device=logits.device, dtype=logits.dtype),
-    )
-
-
-def manual_top_p(logits, p):
-    """Manual reference implementation for top-p"""
-    probs = F.softmax(logits, dim=-1)
-    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
-    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
-    mask = cumulative_probs <= p
-    # Always keep at least one
-    mask[..., 0] = True
-    keep_indices = torch.zeros_like(logits, dtype=torch.bool)
-    for b in range(logits.shape[0]):
-        for s in range(logits.shape[1]):
-            keep = sorted_indices[b, s][mask[b, s]]
-            keep_indices[b, s, keep] = True
-    return torch.where(
-        keep_indices,
-        logits,
-        torch.tensor(-float("inf"), device=logits.device, dtype=logits.dtype),
-    )
-
-
 class TestSamplingUtils(unittest.TestCase):
     """Test cases for sampling utility functions."""
 
+    def setUp(self):
+        """Set up deterministic test data"""
+        # Create a deterministic logits tensor of shape [2, 2, 4]
+        self.logits = torch.tensor(
+            [
+                [
+                    [10.0, 5.0, 2.0, 1.0],
+                    [8.0, 7.0, 3.0, 0.5],
+                ],
+                [
+                    [6.0, 9.0, 1.5, 4.0],
+                    [2.5, 1.0, 8.5, 7.5],
+                ],
+            ],
+            dtype=torch.float32,
+        )
+
     def test_top_k_only(self):
         """Test top-k only sampling"""
-        torch.manual_seed(0)
-        logits = torch.randn(2, 3, 10)
-        k = 4
-        result = apply_top_k_top_p(logits.clone(), top_k=k)
-        expected = manual_top_k(logits, k)
-        self.assertTrue(torch.allclose(result, expected))
-        # Check only k values per row are not -inf
-        non_inf = (result != -float("inf")).sum(dim=-1)
-        self.assertTrue(torch.all(non_inf == k))
+        k = 2
+        result = apply_top_k_top_p(self.logits.clone(), top_k=k)
+
+        # Expected result: keep top 2 values, mask others
+        expected = torch.tensor(
+            [
+                [
+                    [10.0, 5.0, -float("inf"), -float("inf")],
+                    [8.0, 7.0, -float("inf"), -float("inf")],
+                ],
+                [
+                    [6.0, 9.0, -float("inf"), -float("inf")],
+                    [-float("inf"), -float("inf"), 8.5, 7.5],
+                ],
+            ],
+            dtype=torch.float32,
+        )
+
+        self.assertTrue(torch.allclose(result, expected, equal_nan=True))
 
     def test_top_p_only(self):
-        """Test top-p only sampling"""
-        torch.manual_seed(1)
-        logits = torch.randn(2, 3, 10)
-        p = 0.7
-        result = apply_top_k_top_p(logits.clone(), top_p=p)
-
-        # Check that we have at least one non-inf value per position
-        non_inf = (result != -float("inf")).sum(dim=-1)
-        self.assertTrue(torch.all(non_inf >= 1))
-
-        # Check that probability sums to 1.0
-        result_probs = F.softmax(result, dim=-1)
-        result_sum = result_probs.sum(dim=-1)
-        self.assertTrue(torch.allclose(result_sum, torch.ones_like(result_sum)))
+        """Test top-p only sampling with specific probability threshold"""
+        p = 0.8
+        result = apply_top_k_top_p(self.logits.clone(), top_p=p)
+
+        expected = torch.tensor(
+            [
+                [
+                    [10.0, -float("inf"), -float("inf"), -float("inf")],
+                    [8.0, 7.0, -float("inf"), -float("inf")],
+                ],
+                [
+                    [-float("inf"), 9.0, -float("inf"), -float("inf")],
+                    [-float("inf"), -float("inf"), 8.5, 7.5],
+                ],
+            ],
+            dtype=torch.float32,
+        )
+
+        self.assertTrue(torch.allclose(result, expected, equal_nan=True))
 
-    def test_top_k_and_top_p(self):
-        """Test both top-k and top-p sampling"""
-        torch.manual_seed(2)
-        logits = torch.randn(2, 3, 10)
-        k = 5
-        p = 0.8
-        result = apply_top_k_top_p(logits.clone(), top_k=k, top_p=p)
-
-        # Check basic properties
-        non_inf = (result != -float("inf")).sum(dim=-1)
-        self.assertTrue(torch.all(non_inf >= 1))
-        self.assertTrue(torch.all(non_inf <= k))
-
-        # Check that probability sums to 1.0
-        result_probs = F.softmax(result, dim=-1)
-        result_sum = result_probs.sum(dim=-1)
-        self.assertTrue(torch.allclose(result_sum, torch.ones_like(result_sum)))
+    def test_top_k_and_top_p_combined(self):
+        """Test both top-k and top-p sampling together"""
+        k = 3
+        p = 0.8
+        result = apply_top_k_top_p(self.logits.clone(), top_k=k, top_p=p)
+
+        # Expected: apply top-k=3 first, then top-p=0.8
+        # For our data, top-k=3 keeps top 3 values, then top-p further filters
+        expected = torch.tensor(
+            [
+                [
+                    [10.0, -float("inf"), -float("inf"), -float("inf")],
+                    [8.0, 7.0, -float("inf"), -float("inf")],
+                ],
+                [
+                    [-float("inf"), 9.0, -float("inf"), -float("inf")],
+                    [-float("inf"), -float("inf"), 8.5, 7.5],
+                ],
+            ],
+            dtype=torch.float32,
+        )
+
+        self.assertTrue(torch.allclose(result, expected, equal_nan=True))
 
     def test_no_sampling(self):
         """Test no sampling (should return original logits)"""
-        torch.manual_seed(3)
-        logits = torch.randn(2, 3, 10)
-        result = apply_top_k_top_p(logits.clone())
-        self.assertTrue(torch.allclose(result, logits))
+        result = apply_top_k_top_p(self.logits.clone())
+        self.assertTrue(torch.allclose(result, self.logits, equal_nan=True))
 
     def test_edge_cases(self):
         """Test edge cases for sampling parameters"""
-        torch.manual_seed(4)
-        logits = torch.randn(2, 3, 10)
+        # k >= vocab size (should return original logits)
+        result = apply_top_k_top_p(self.logits.clone(), top_k=4)
+        self.assertTrue(torch.allclose(result, self.logits))
 
-        # k = vocab size (should return original logits)
-        result = apply_top_k_top_p(logits.clone(), top_k=10)
-        self.assertTrue(torch.allclose(result, logits))
-
-        # k = 1 (should keep only one token)
-        result = apply_top_k_top_p(logits.clone(), top_k=1)
-        non_inf = (result != -float("inf")).sum(dim=-1)
-        self.assertTrue(torch.all(non_inf == 1))
+        result = apply_top_k_top_p(self.logits.clone(), top_k=10)
+        self.assertTrue(torch.allclose(result, self.logits))
 
         # p = 1.0 (should return original logits)
-        result = apply_top_k_top_p(logits.clone(), top_p=1.0)
-        self.assertTrue(torch.allclose(result, logits))
-
-        # p = 0.0 (should keep only the highest probability token)
-        result = apply_top_k_top_p(logits.clone(), top_p=0.0)
-        non_inf = (result != -float("inf")).sum(dim=-1)
-        self.assertTrue(torch.all(non_inf == 1))
-
-        # k > vocab size (should return original logits)
-        result = apply_top_k_top_p(logits.clone(), top_k=20)
-        self.assertTrue(torch.allclose(result, logits))
+        result = apply_top_k_top_p(self.logits.clone(), top_p=1.0)
+        self.assertTrue(torch.allclose(result, self.logits))
 
     def test_gradient_flow(self):
         """Test that gradients flow through the sampling operations"""
-        torch.manual_seed(5)
-        logits = torch.randn(2, 3, 10, requires_grad=True)
-        result = apply_top_k_top_p(logits, top_k=3, top_p=0.7)
+        logits_with_grad = self.logits.clone().requires_grad_(True)
+        result = apply_top_k_top_p(logits_with_grad, top_k=2, top_p=0.7)
         loss = result.sum()
         loss.backward()
 
-        # Check that gradients exist and are non-zero
-        self.assertIsNotNone(logits.grad)
-        self.assertEqual(logits.grad.shape, logits.shape)
-        self.assertTrue(torch.any(logits.grad != 0))
-
-    def test_large_vocab(self):
-        """Test with larger vocabulary size"""
-        torch.manual_seed(6)
-        logits = torch.randn(1, 2, 1000)  # Large vocab
-        k = 50
-        p = 0.9
-
-        result = apply_top_k_top_p(logits.clone(), top_k=k, top_p=p)
-
-        # Check basic properties
-        non_inf = (result != -float("inf")).sum(dim=-1)
-        self.assertTrue(torch.all(non_inf >= 1))
-        self.assertTrue(torch.all(non_inf <= k))
-
-        # Check that probability sums to 1.0
-        result_probs = F.softmax(result, dim=-1)
-        result_sum = result_probs.sum(dim=-1)
-        self.assertTrue(torch.allclose(result_sum, torch.ones_like(result_sum)))
+        # Check that gradients exist and are non-zero for at least some elements
+        self.assertIsNotNone(logits_with_grad.grad)
+        self.assertEqual(logits_with_grad.grad.shape, logits_with_grad.shape)
+        self.assertTrue(torch.any(logits_with_grad.grad != 0))
 
 
 class TestConfigureExpandableSegments(unittest.TestCase):
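Taken together, the series leaves apply_top_k_top_p with easily checkable final semantics: top_k=-1 and top_p=1.0 disable filtering, top-k is applied before top-p, and at least one token always survives. A worked example on the same deterministic data the new tests use (a standalone sketch; the numbers follow from softmax([10, 5, 2, 1]) putting roughly 0.993 of the probability mass on the first token):

    import torch

    from nemo_rl.models.policy.utils import apply_top_k_top_p

    logits = torch.tensor([[[10.0, 5.0, 2.0, 1.0]]])

    # top_k=2 keeps the two largest logits and masks the rest.
    out = apply_top_k_top_p(logits.clone(), top_k=2)
    assert torch.equal(
        out, torch.tensor([[[10.0, 5.0, -float("inf"), -float("inf")]]])
    )

    # top_p=0.8: the top token alone already covers ~99.3% of the
    # probability mass, so it is the only survivor.
    out = apply_top_k_top_p(logits.clone(), top_p=0.8)
    assert int((out != -float("inf")).sum()) == 1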