From ce6c0f35e7abfc35830a36d762f81bebe3fef2e4 Mon Sep 17 00:00:00 2001 From: Zhanda Date: Thu, 12 Feb 2026 12:28:32 -0800 Subject: [PATCH 01/20] Reimplement top-p and top-k from #1578 (not cp loss_function.py common.py megatron_policy_worker.py) Signed-off-by: Zhanda --- ....2-1b-instruct-1n8g-megatron-sampling.yaml | 37 + ...3.2-1b-instruct-1n8g-megatron-temp0.6.yaml | 35 + nemo_rl/distributed/model_utils.py | 716 ++++++++++++++++-- .../models/generation/vllm/vllm_generation.py | 30 +- nemo_rl/models/policy/utils.py | 286 +++++-- .../policy/workers/dtensor_policy_worker.py | 63 +- .../workers/dtensor_policy_worker_v2.py | 46 ++ ...a3.2-1b-instruct-1n8g-megatron-sampling.sh | 44 ++ ...ma3.2-1b-instruct-1n8g-megatron-temp0.6.sh | 44 ++ tests/test_suites/nightly.txt | 2 + tests/unit/distributed/test_model_utils.py | 301 ++++++++ .../models/generation/test_vllm_generation.py | 37 - .../generation/test_vllm_logprobs_mode.py | 4 +- 13 files changed, 1463 insertions(+), 182 deletions(-) create mode 100644 examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.yaml create mode 100644 examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.yaml create mode 100755 tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.sh create mode 100755 tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.sh diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.yaml new file mode 100644 index 0000000000..34d76a7421 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.yaml @@ -0,0 +1,37 @@ +defaults: ../../grpo_math_1B.yaml +grpo: + max_num_steps: 500 +checkpointing: + enabled: false + checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron + save_period: 100 +policy: + model_name: meta-llama/Llama-3.2-1B-Instruct + tokenizer: + name: meta-llama/Llama-3.2-1B-Instruct + optimizer: null + megatron_cfg: + enabled: true + scheduler: + lr_warmup_iters: 50 + dtensor_cfg: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + max_new_tokens: 512 + vllm_cfg: + max_model_len: 512 + temperature: 0.8 + top_p: 0.9 + top_k: 50 +data: + max_input_seq_length: 512 +logger: + log_dir: logs/grpo-llama3.2-1b-instruct-1n8g-megatron + wandb_enabled: true + tensorboard_enabled: true + wandb: + project: nemo-rl + name: grpo-llama3.2-1b-instruct-1n8g-megatron +cluster: + gpus_per_node: 8 diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.yaml new file mode 100644 index 0000000000..82b5639f73 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.yaml @@ -0,0 +1,35 @@ +defaults: ../../grpo_math_1B.yaml +grpo: + max_num_steps: 500 +checkpointing: + enabled: false + checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron + save_period: 100 +policy: + model_name: meta-llama/Llama-3.2-1B-Instruct + tokenizer: + name: meta-llama/Llama-3.2-1B-Instruct + optimizer: null + megatron_cfg: + enabled: true + scheduler: + lr_warmup_iters: 50 + dtensor_cfg: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + max_new_tokens: 512 + vllm_cfg: + max_model_len: 512 + temperature: 0.6 +data: + max_input_seq_length: 512 +logger: + log_dir: logs/grpo-llama3.2-1b-instruct-1n8g-megatron + wandb_enabled: true + tensorboard_enabled: true + wandb: + project: nemo-rl + name: grpo-llama3.2-1b-instruct-1n8g-megatron +cluster: + gpus_per_node: 8 diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 277e7a6209..9020814bb2 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -17,6 +17,12 @@ import torch from torch.distributed.tensor import DTensor, distribute_tensor +from nemo_rl.models.policy.utils import ( + TrainingSamplingParams, + apply_top_k_top_p, + need_top_k_or_top_p_filtering, +) + @torch.no_grad() def _compute_distributed_log_softmax( @@ -262,6 +268,345 @@ def backward( return grad_input, None, None, None, None, None, None +class DistributedLogprobWithSampling(torch.autograd.Function): + """Custom autograd function for computing log probabilities with top-k/top-p sampling. + + This function materializes the full vocabulary by converting from vocab-parallel to + batch-sequence-parallel layout, applies filtering, and computes log probabilities. + """ + + @staticmethod + def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class + ctx: Any, + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + tp_group: torch.distributed.ProcessGroup, + top_k: int | None, + top_p: float, + inference_only: bool = False, + ) -> torch.Tensor: + """Forward pass for sampling-based logprob computation. + + Args: + vocab_parallel_logits: [B, S, V_local] logits sharded by vocab + target: [B, S] target token ids (already shifted) + tp_group: Tensor parallel process group + top_k: Top-k filtering parameter (None or -1 to disable) + top_p: Top-p filtering parameter (1.0 to disable) + inference_only: If True, don't save tensors for backward + + Returns: + Log probabilities [B, S] + """ + world_size = torch.distributed.get_world_size(tp_group) + rank = torch.distributed.get_rank(tp_group) + B, S, V_local = vocab_parallel_logits.shape + BS = B * S + + if BS % world_size != 0: + raise ValueError( + f"B*S={BS} must be divisible by tensor parallel size {world_size} when using top-p/top-k sampling. " + "Please set policy.make_sequence_length_divisible_by to tensor parallel size." + ) + BS_local = BS // world_size + + # Reshape to 2D for all_to_all + reshaped_vocab_parallel_logits = vocab_parallel_logits.view(BS, V_local) + + # Flatten target: [B, S] -> [BS] + target_flat = target.flatten() # [BS] + + # Extract local portion + start_idx = rank * BS_local + end_idx = (rank + 1) * BS_local + target_local = target_flat[start_idx:end_idx] # [BS_local] + + # All-to-all to get batch-sequence parallel logits + seq_parallel_logits = all_to_all_vp2sq(reshaped_vocab_parallel_logits, tp_group) + + # Apply top-k and top-p filtering locally (returns keep_mask for gradient) + logits, keep_mask = apply_top_k_top_p( + seq_parallel_logits, top_k=top_k, top_p=top_p + ) + + # Compute log softmax + log_probs = torch.nn.functional.log_softmax( + logits.to(dtype=torch.float32), dim=-1 + ) + + # Gather log probs for target tokens + token_logprobs = torch.gather( + log_probs, -1, target_local.unsqueeze(-1) + ).squeeze(-1) + + # All-gather across TP to get full sequence [BS] + gathered_list = [torch.empty_like(token_logprobs) for _ in range(world_size)] + torch.distributed.all_gather(gathered_list, token_logprobs, group=tp_group) + token_logprobs = torch.cat(gathered_list, dim=0) # [BS] + + # Reshape back to [B, S] + token_logprobs = token_logprobs.view(B, S) + + if not inference_only: + # Save softmax and mask for backward + softmax_output = log_probs.exp() + ctx.save_for_backward(softmax_output, target_local, keep_mask) + ctx.tp_group = tp_group + ctx.world_size = world_size + ctx.rank = rank + ctx.BS_local = BS_local + ctx.B = B + ctx.S = S + + return token_logprobs + + @staticmethod + def backward( + ctx: Any, + *grad_outputs: torch.Tensor, + ) -> tuple[torch.Tensor, None, None, None, None, None]: + """Backward pass for sampling-based logprob computation.""" + grad_output = grad_outputs[0] # [B, S] + softmax_output, target_local, keep_mask = ctx.saved_tensors + tp_group = ctx.tp_group + world_size = ctx.world_size + rank = ctx.rank + BS_local = ctx.BS_local + B = ctx.B + S = ctx.S + + # Flatten and extract local portion + grad_output_flat = grad_output.flatten() # [BS] + start_idx = rank * BS_local + end_idx = (rank + 1) * BS_local + grad_output_local = grad_output_flat[start_idx:end_idx] # [BS_local] + + # Compute gradient + V = softmax_output.shape[-1] + is_chosen = torch.nn.functional.one_hot(target_local, num_classes=V) + grad_logits_local = is_chosen.float().sub_(softmax_output) + grad_logits_local.mul_(grad_output_local.unsqueeze(-1)) + + # Apply keep_mask to gradients - filtered tokens don't get gradients + if keep_mask is not None: + grad_logits_local.mul_(keep_mask) + + # Convert back to vocab-parallel: [BS_local, V] -> [BS, V_local] + grad_vocab_parallel = all_to_all_sq2vp(grad_logits_local, tp_group) + grad_vocab_parallel = grad_vocab_parallel.view(B, S, V // world_size) + + return grad_vocab_parallel, None, None, None, None, None + + +class ChunkedDistributedLogprobWithSampling(torch.autograd.Function): + """Chunked version of DistributedLogprobWithSampling for memory efficiency. + + Uses delayed rematerialization to avoid storing large intermediate tensors. + """ + + @staticmethod + def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class + ctx: Any, + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + tp_group: torch.distributed.ProcessGroup, + top_k: int | None, + top_p: float, + chunk_size: int, + inference_only: bool = False, + ) -> torch.Tensor: + """Forward pass with chunked processing. + + Args: + vocab_parallel_logits: [B, S, V_local] logits sharded by vocab + target: [B, S] target token ids (already shifted) + tp_group: Tensor parallel process group + top_k: Top-k filtering parameter (None or -1 to disable) + top_p: Top-p filtering parameter (1.0 to disable) + chunk_size: Chunk size for memory efficiency (in sequence dimension) + inference_only: If True, don't save tensors for backward + + Returns: + Log probabilities [B, S] + """ + world_size = torch.distributed.get_world_size(tp_group) + rank = torch.distributed.get_rank(tp_group) + B, S, V_local = vocab_parallel_logits.shape + BS = B * S + + if BS % world_size != 0: + raise ValueError( + f"B*S={BS} must be divisible by tensor parallel size {world_size} when using top-p/top-k sampling. " + "Please set policy.make_sequence_length_divisible_by to tensor parallel size." + ) + + # Convert chunk_size from sequence dimension to batch-sequence dimension + effective_chunk_size = chunk_size * B + reshaped_vocab_parallel_logits = vocab_parallel_logits.view(BS, V_local) + + # Make sure the effective chunk size is divisible by the world size + # This ensure all the chunks (including the last one) meet the world size requirement. + if effective_chunk_size % world_size != 0: + raise ValueError( + f"Effective chunk size {effective_chunk_size} = chunk_size {chunk_size} * B {B} must be divisible " + f"by the tensor parallel size {world_size}." + ) + + # Flatten target: [B, S] -> [BS] + target_flat = target.flatten() # [BS] + + # Process in chunks + num_chunks = (BS + effective_chunk_size - 1) // effective_chunk_size + all_token_logprobs = [] + + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * effective_chunk_size + chunk_end = min(BS, (chunk_idx + 1) * effective_chunk_size) + current_chunk_size = chunk_end - chunk_start + local_chunk_size = current_chunk_size // world_size + + # Slice the chunk + vocab_parallel_logits_chunk = reshaped_vocab_parallel_logits[ + chunk_start:chunk_end, : + ] + + # Extract target chunk for this rank + target_chunk = target_flat[chunk_start:chunk_end] + target_local = target_chunk[ + rank * local_chunk_size : (rank + 1) * local_chunk_size + ] + + # All-to-all to get batch-sequence parallel logits + seq_parallel_logits_chunk = all_to_all_vp2sq( + vocab_parallel_logits_chunk, tp_group + ) + + # Apply top-k and top-p filtering locally + logits_chunk, _ = apply_top_k_top_p( + seq_parallel_logits_chunk, top_k=top_k, top_p=top_p + ) + + # Compute log softmax + log_probs_chunk = torch.nn.functional.log_softmax( + logits_chunk.to(dtype=torch.float32), dim=-1 + ) + + # Gather log probs for target tokens in this chunk + token_logprobs_chunk = torch.gather( + log_probs_chunk, -1, target_local.unsqueeze(-1) + ).squeeze(-1) + + # All-gather across TP to get full chunk [current_chunk_size] + gathered_list = [ + torch.empty_like(token_logprobs_chunk) for _ in range(world_size) + ] + torch.distributed.all_gather( + gathered_list, token_logprobs_chunk, group=tp_group + ) + token_logprobs_chunk = torch.cat(gathered_list, dim=0) + + all_token_logprobs.append(token_logprobs_chunk) + + # Concatenate all chunks and reshape + token_logprobs = torch.cat(all_token_logprobs, dim=0) # [BS] + token_logprobs = token_logprobs.view(B, S) + + if not inference_only: + ctx.save_for_backward(vocab_parallel_logits) + ctx.target = target + ctx.tp_group = tp_group + ctx.top_k = top_k + ctx.top_p = top_p + ctx.chunk_size = chunk_size + + return token_logprobs + + @staticmethod + def backward( + ctx: Any, + *grad_outputs: torch.Tensor, + ) -> tuple[torch.Tensor, None, None, None, None, None, None]: + """Backward pass with chunked rematerialization.""" + grad_output = grad_outputs[0] # [B, S] + (vocab_parallel_logits,) = ctx.saved_tensors + target = ctx.target + tp_group = ctx.tp_group + top_k = ctx.top_k + top_p = ctx.top_p + chunk_size = ctx.chunk_size + + world_size = torch.distributed.get_world_size(tp_group) + rank = torch.distributed.get_rank(tp_group) + B, S, V_local = vocab_parallel_logits.shape + BS = B * S + + effective_chunk_size = chunk_size * B + reshaped_vocab_parallel_logits = vocab_parallel_logits.view(BS, V_local) + target_flat = target.flatten() + + num_chunks = (BS + effective_chunk_size - 1) // effective_chunk_size + grad_output_flat = grad_output.flatten() + + all_grad_chunks = [] + + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * effective_chunk_size + chunk_end = min(BS, (chunk_idx + 1) * effective_chunk_size) + current_chunk_size = chunk_end - chunk_start + local_chunk_size = current_chunk_size // world_size + + # Rematerialize forward pass for this chunk + vocab_parallel_logits_chunk = reshaped_vocab_parallel_logits[ + chunk_start:chunk_end, : + ] + + target_chunk = target_flat[chunk_start:chunk_end] + target_local = target_chunk[ + rank * local_chunk_size : (rank + 1) * local_chunk_size + ] + + # Rematerialize all-to-all + seq_parallel_logits_chunk = all_to_all_vp2sq( + vocab_parallel_logits_chunk, tp_group + ) + + # Rematerialize filtering + logits_chunk, keep_mask = apply_top_k_top_p( + seq_parallel_logits_chunk, top_k=top_k, top_p=top_p + ) + + # Rematerialize softmax + log_probs_chunk = torch.nn.functional.log_softmax( + logits_chunk.to(dtype=torch.float32), dim=-1 + ) + softmax_chunk = log_probs_chunk.exp() + + # Extract local portion of grad_output + grad_chunk = grad_output_flat[chunk_start:chunk_end] + grad_local = grad_chunk[ + rank * local_chunk_size : (rank + 1) * local_chunk_size + ] + + # Compute gradient: (one_hot - softmax) * grad_output + V = softmax_chunk.shape[-1] + is_chosen = torch.nn.functional.one_hot(target_local, num_classes=V) + grad_logits_local = is_chosen.float().sub_(softmax_chunk) + grad_logits_local.mul_(grad_local.unsqueeze(-1)) + + # Apply keep_mask + if keep_mask is not None: + grad_logits_local.mul_(keep_mask) + + # Convert back to vocab-parallel + grad_vocab_parallel_chunk = all_to_all_sq2vp(grad_logits_local, tp_group) + all_grad_chunks.append(grad_vocab_parallel_chunk) + + grad_vocab_parallel = torch.cat(all_grad_chunks, dim=0) + grad_vocab_parallel = grad_vocab_parallel.view(B, S, V_local) + + return grad_vocab_parallel, None, None, None, None, None, None + + class ChunkedDistributedGatherLogprob(torch.autograd.Function): """Compute distributed log-softmax once and gather logprobs at given global indices. @@ -388,6 +733,7 @@ def dtensor_from_parallel_logits_to_logprobs( inference_only: bool = False, seq_index: Optional[torch.Tensor] = None, chunk_size: Optional[int] = None, + sampling_params: Optional[TrainingSamplingParams] = None, ) -> torch.Tensor: """Get log probabilities from TP+CP sharded vocab logits. @@ -403,6 +749,7 @@ def dtensor_from_parallel_logits_to_logprobs( seq_index (Optional[torch.Tensor]): Sequence index tensor with shape [seq_len]. It is only provided for cp sharded logits. It represents how tensor is sharded across the sequence dimension. chunk_size (Optional[int]): Sequence dimension chunk size for computing the log probabilities. + sampling_params (TrainingSamplingParams, optional): Sampling parameters for Top-k/Top-p filtering and temperature scaling. Returns: torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. @@ -434,25 +781,51 @@ def dtensor_from_parallel_logits_to_logprobs( else: target = target.roll(shifts=-1, dims=-1) - if chunk_size is not None: - logprobs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore - vocab_parallel_logits, - target, - vocab_start_index, - vocab_end_index, - chunk_size, - tp_group, - inference_only, - ).contiguous() + use_sampling = sampling_params is not None and need_top_k_or_top_p_filtering( + sampling_params.top_k, sampling_params.top_p + ) + + if use_sampling: + assert sampling_params is not None # Type narrowing + if chunk_size is not None: + logprobs: torch.Tensor = ChunkedDistributedLogprobWithSampling.apply( # type: ignore + vocab_parallel_logits, + target, + tp_group, + sampling_params.top_k, + sampling_params.top_p, + chunk_size, + inference_only, + ).contiguous() + else: + logprobs: torch.Tensor = DistributedLogprobWithSampling.apply( # type: ignore + vocab_parallel_logits, + target, + tp_group, + sampling_params.top_k, + sampling_params.top_p, + inference_only, + ).contiguous() else: - logprobs: torch.Tensor = DistributedLogprob.apply( # type: ignore - vocab_parallel_logits, - target, - vocab_start_index, - vocab_end_index, - tp_group, - inference_only, - ).contiguous() + if chunk_size is not None: + logprobs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + chunk_size, + tp_group, + inference_only, + ).contiguous() + else: + logprobs: torch.Tensor = DistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + tp_group, + inference_only, + ).contiguous() if cp_size > 1: # logprobs is sharded on the sequence dimension. @@ -473,6 +846,7 @@ def from_parallel_logits_to_logprobs( inference_only: bool = False, cp_group: Optional[torch.distributed.ProcessGroup] = None, chunk_size: Optional[int] = None, + sampling_params: Optional[TrainingSamplingParams] = None, ) -> torch.Tensor: """Get log probabilities from TP+CP sharded vocab logits. @@ -487,6 +861,7 @@ def from_parallel_logits_to_logprobs( inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False. cp_group (torch.distributed.ProcessGroup, optional): Context parallelism process group. Defaults to None. chunk_size (int, optional): Sequence dimension chunk size for computing the log probabilities. + sampling_params (TrainingSamplingParams, optional): Sampling parameters for Top-k/Top-p filtering and temperature scaling. Returns: torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. @@ -507,25 +882,51 @@ def from_parallel_logits_to_logprobs( cp_rank = torch.distributed.get_rank(cp_group) target = _get_tokens_on_this_cp_rank(target, cp_rank, cp_size, seq_dim=1) - if chunk_size is not None: - logprobs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore - vocab_parallel_logits, - target, - vocab_start_index, - vocab_end_index, - chunk_size, - tp_group, - inference_only, - ).contiguous() + use_sampling = sampling_params is not None and need_top_k_or_top_p_filtering( + sampling_params.top_k, sampling_params.top_p + ) + + if use_sampling: + assert sampling_params is not None # Type narrowing + if chunk_size is not None: + logprobs: torch.Tensor = ChunkedDistributedLogprobWithSampling.apply( # type: ignore + vocab_parallel_logits, + target, + tp_group, + sampling_params.top_k, + sampling_params.top_p, + chunk_size, + inference_only, + ).contiguous() + else: + logprobs: torch.Tensor = DistributedLogprobWithSampling.apply( # type: ignore + vocab_parallel_logits, + target, + tp_group, + sampling_params.top_k, + sampling_params.top_p, + inference_only, + ).contiguous() else: - logprobs: torch.Tensor = DistributedLogprob.apply( # type: ignore - vocab_parallel_logits, - target, - vocab_start_index, - vocab_end_index, - tp_group, - inference_only, - ).contiguous() + if chunk_size is not None: + logprobs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + chunk_size, + tp_group, + inference_only, + ).contiguous() + else: + logprobs: torch.Tensor = DistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + tp_group, + inference_only, + ).contiguous() if cp_size > 1: # we need to gather the logits by context parallelism @@ -550,6 +951,7 @@ def from_parallel_logits_to_logprobs_packed_sequences( inference_only: bool = False, cp_group: Optional[torch.distributed.ProcessGroup] = None, chunk_size: Optional[int] = None, + sampling_params: Optional[TrainingSamplingParams] = None, ) -> torch.Tensor: """Get log probabilities from TP sharded vocab logits for packed sequences. @@ -600,25 +1002,51 @@ def from_parallel_logits_to_logprobs_packed_sequences( vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0) # Apply distributed log probability computation - if chunk_size is not None: - probs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore - vocab_parallel_logits, - rolled_targets, - vocab_start_index, - vocab_end_index, - chunk_size, - group, - inference_only, - ).contiguous() + use_sampling = sampling_params is not None and need_top_k_or_top_p_filtering( + sampling_params.top_k, sampling_params.top_p + ) + + if use_sampling: + assert sampling_params is not None # Type narrowing + if chunk_size is not None: + probs: torch.Tensor = ChunkedDistributedLogprobWithSampling.apply( # type: ignore + vocab_parallel_logits, + rolled_targets, + group, + sampling_params.top_k, + sampling_params.top_p, + chunk_size, + inference_only, + ).contiguous() + else: + probs: torch.Tensor = DistributedLogprobWithSampling.apply( # type: ignore + vocab_parallel_logits, + rolled_targets, + group, + sampling_params.top_k, + sampling_params.top_p, + inference_only, + ).contiguous() else: - probs: torch.Tensor = DistributedLogprob.apply( # type: ignore - vocab_parallel_logits, - rolled_targets, - vocab_start_index, - vocab_end_index, - group, - inference_only, - ).contiguous() + if chunk_size is not None: + probs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + rolled_targets, + vocab_start_index, + vocab_end_index, + chunk_size, + group, + inference_only, + ).contiguous() + else: + probs: torch.Tensor = DistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + rolled_targets, + vocab_start_index, + vocab_end_index, + group, + inference_only, + ).contiguous() # Remove batch dimension for filtering probs = probs.squeeze(0) @@ -778,6 +1206,7 @@ def get_logprobs_from_vocab_parallel_logits( input_ids: torch.Tensor | DTensor, seq_index: Optional[torch.Tensor] = None, chunk_size: Optional[int] = None, + sampling_params: Optional[TrainingSamplingParams] = None, ): """Computes log probabilities from vocabulary-parallel logits. @@ -792,6 +1221,7 @@ def get_logprobs_from_vocab_parallel_logits( seq_index (Optional[torch.Tensor]): Sequence index for the input IDs, with shape [sequence_length]. chunk_size (Optional[int]): Sequence dimension chunk size for computing log probabilities. + sampling_params (TrainingSamplingParams, optional): Sampling parameters for Top-k/Top-p filtering and temperature scaling. Returns: torch.Tensor: Log probabilities for the given input IDs. @@ -820,6 +1250,7 @@ def get_logprobs_from_vocab_parallel_logits( inference_only=not torch.is_grad_enabled(), seq_index=seq_index, chunk_size=chunk_size, + sampling_params=sampling_params, ) @@ -1272,3 +1703,180 @@ def backward( del softmax_output, log_probs, logits, H_local return grad_input, None, None, None + + +def compute_logprobs_from_logits( + next_token_logits: torch.Tensor, + input_ids: torch.Tensor, + vocab_parallel_rank: Optional[int], + vocab_parallel_group: Optional[torch.distributed.ProcessGroup], + context_parallel_group: Optional[torch.distributed.ProcessGroup], + seq_index: Optional[torch.Tensor], + sampling_params: Optional[TrainingSamplingParams], +) -> torch.Tensor: + """Compute token log-probabilities from logits, handling parallel and non-parallel cases. + + This function handles three cases: + 1. Vocab parallel (Megatron-style): uses from_parallel_logits_to_logprobs + 2. DTensor: uses get_logprobs_from_vocab_parallel_logits + 3. Non-parallel: applies top-k/top-p filtering, log_softmax, and gather + + Args: + next_token_logits: Logits tensor of shape [batch_size, seq_len, vocab_size] + input_ids: Input token IDs of shape [batch_size, seq_len] + vocab_parallel_rank: Rank in the vocab parallel group (required if vocab_parallel_group is provided) + vocab_parallel_group: Process group for vocab parallelism + context_parallel_group: Process group for context parallelism + seq_index: Sequence index tensor for DTensor path + sampling_params: Sampling parameters for top-k/top-p filtering + + Returns: + Token log-probabilities of shape [batch_size, seq_len - 1] + """ + next_token_logits = next_token_logits.to(torch.float32) + + if vocab_parallel_group is not None: + assert vocab_parallel_rank is not None, ( + "vocab_parallel_rank must be provided when vocab_parallel_group is provided" + ) + token_logprobs = from_parallel_logits_to_logprobs( + next_token_logits, + input_ids, + vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], + vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], + tp_group=vocab_parallel_group, + inference_only=False, + cp_group=context_parallel_group, + sampling_params=sampling_params, + ) + # slice off to the correct length to remove potential CP padding + token_logprobs = token_logprobs[:, : input_ids.shape[1] - 1] + elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): + token_logprobs = get_logprobs_from_vocab_parallel_logits( + next_token_logits, + input_ids, + seq_index=seq_index, + sampling_params=sampling_params, + ) + else: + next_token_logits_wo_last = next_token_logits[ + :, :-1 + ] # Remove last position's logits + # Apply top-k and top-p filtering + next_token_logits_wo_last, _ = apply_top_k_top_p( + next_token_logits_wo_last, + top_k=sampling_params.top_k if sampling_params is not None else None, + top_p=sampling_params.top_p if sampling_params is not None else 1.0, + ) + next_token_logprobs = torch.nn.functional.log_softmax( + next_token_logits_wo_last, dim=-1 + ) + next_tokens = input_ids[:, 1:].cuda() # Skip first token + token_logprobs = next_token_logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) + + return token_logprobs + + +def all_to_all_vp2sq( + vocab_parallel_logits: torch.Tensor, + tp_group: torch.distributed.ProcessGroup, +) -> torch.Tensor: + """Convert vocab-parallel logits to batch-sequence-parallel logits via all-to-all. + + Note: This partitions the flattened B*S dimension, not just S. The input vocab_parallel_logits + need to be 2D tensor. + + Transforms [BS, V_local] -> [BS_local, V] where: + - V_local = V / tp_size (vocab is sharded) + - BS_local = BS / tp_size (batch-sequence will be sharded) + - Requires BS to be divisible by tp_size + + Args: + vocab_parallel_logits: [BS, V_local] tensor with vocab dimension sharded + tp_group: Tensor parallel process group + + Returns: + Batch-sequence-parallel logits [BS_local, V] with batch-sequence dimension sharded + """ + if vocab_parallel_logits.ndim != 2: + raise ValueError( + "For all_to_all_vp2sq, vocab_parallel_logits must be a 2D tensor, " + f"got {vocab_parallel_logits.ndim}D tensor with shape {vocab_parallel_logits.shape}" + ) + + world_size = torch.distributed.get_world_size(tp_group) + BS, V_local = vocab_parallel_logits.shape + + if BS % world_size != 0: + raise ValueError( + f"BS={BS} must be divisible by tensor parallel size {world_size}. " + f"Set policy.make_sequence_length_divisible_by to ensure divisibility." + ) + + BS_local = BS // world_size + + # Flatten and perform all-to-all: exchanges B*S chunks for vocab slices + input_flat = vocab_parallel_logits.flatten() + output_flat = torch.empty_like(input_flat) + torch.distributed.all_to_all_single(output_flat, input_flat, group=tp_group) + + # Rearrange output: merge vocab slices from all ranks into full vocabulary + # Equivalent to: "(w bs v) -> bs (w v)", w=world_size, bs=BS_local, v=V_local + output_tensor = output_flat.view(world_size, BS_local, V_local) + output_tensor = output_tensor.permute(1, 0, 2) + output_tensor = output_tensor.reshape(BS_local, world_size * V_local) + + return output_tensor + + +def all_to_all_sq2vp( + seq_parallel_logits: torch.Tensor, + tp_group: torch.distributed.ProcessGroup, +) -> torch.Tensor: + """Convert batch-sequence-parallel logits to vocab-parallel logits via all-to-all. + + Inverse operation of all_to_all_vp2sq. + + Transforms [BS_local, V] -> [BS, V_local] where: + - BS_local = BS / tp_size (batch-sequence is sharded) + - V_local = V / tp_size (vocab will be sharded) + + Args: + seq_parallel_logits: [BS_local, V] tensor with batch-sequence dimension sharded + tp_group: Tensor parallel process group + + Returns: + Vocab-parallel logits [BS, V_local] with vocab dimension sharded + """ + if seq_parallel_logits.ndim != 2: + raise ValueError( + "For all_to_all_sq2vp, seq_parallel_logits must be a 2D tensor, " + f"got {seq_parallel_logits.ndim}D tensor with shape {seq_parallel_logits.shape}" + ) + + world_size = torch.distributed.get_world_size(tp_group) + BS_local, V = seq_parallel_logits.shape + + if V % world_size != 0: + raise ValueError( + f"Vocabulary size {V} must be divisible by tensor parallel size {world_size}" + ) + + V_local = V // world_size + + # Rearrange input: split vocab into chunks for sending to different ranks + # Equivalent to: "bs (w v) -> (w bs v)", w=world_size, bs=BS_local, v=V_local + input_reshaped = seq_parallel_logits.view(BS_local, world_size, V_local) + input_permuted = input_reshaped.permute(1, 0, 2).contiguous() + input_flat = input_permuted.flatten() + + # Perform all-to-all: exchanges vocab slices for B*S chunks + output_flat = torch.empty_like(input_flat) + torch.distributed.all_to_all_single(output_flat, input_flat, group=tp_group) + + # Reshape output: merge B*S slices from all ranks into full batch-sequence dimension + output_tensor = output_flat.reshape(world_size * BS_local, V_local) + + return output_tensor diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 6138dfdb43..dfee1190c3 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -42,12 +42,6 @@ compute_spec_decode_metrics, ) -# Global thresholds for top_k and top_p validation. -# While top-k/p are not supported, these values allow for token filtering while the logprobs should be compatible. -# See https://github.com/NVIDIA-NeMo/RL/issues/69 and https://github.com/NVIDIA-NeMo/RL/issues/237 for more details. -TOP_K_THRESHOLD = 8000 # Allow top_k >= 8000 (effectively no filtering) -TOP_P_THRESHOLD = 0.99 # Allow top_p >= 0.99 (close to 1.0) - class VllmGeneration(GenerationInterface): def __init__( @@ -92,30 +86,16 @@ def __init__( ) # Validate sampling parameters early to avoid resource allocation with unsupported configs. - # The vLLM sampler patch only supports temperature scaling and does not handle top_p/top_k correctly. - # However, we allow values above certain thresholds for token filtering purposes. - top_k = self.cfg["top_k"] - if top_k is not None and top_k != -1 and top_k < TOP_K_THRESHOLD: + top_k: int | None = self.cfg.get("top_k", None) + if top_k is not None and top_k != -1 and top_k < 1: raise ValueError( - ( - f"top_k sampling with values < {TOP_K_THRESHOLD} is not supported because the vLLM V1 engine " - "does not return logprobs after top_k filtering. Values >= {TOP_K_THRESHOLD} are allowed " - "for token filtering purposes. If you understand the implications and still want to use " - f"a lower top_k value, please manually comment out this check. Got top_k={top_k}. " - "See https://github.com/NVIDIA-NeMo/RL/issues/69 for more details." - ) + f"top_k valid values: i) None or -1: no filtering. ii) >= 1: top-k filtering. Got top_k={top_k}." ) top_p: float = self.cfg.get("top_p", 1.0) - if top_p < TOP_P_THRESHOLD: + if top_p <= 0: raise ValueError( - ( - f"top_p sampling with values < {TOP_P_THRESHOLD} is not supported because the vLLM V1 engine " - "does not return logprobs after top_p filtering. Values >= {TOP_P_THRESHOLD} are allowed " - "for token filtering purposes. If you understand the implications and still want to use " - f"a lower top_p value, please manually comment out this check. Got top_p={top_p}. " - "See https://github.com/NVIDIA-NeMo/RL/issues/69 for more details." - ) + f"top_p valid values: i) 1.0: no filtering. ii) (0, 1]: top-p filtering. Got top_p={top_p}." ) # Ensure all required VllmConfig fields are present diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index 6061b0f143..97f011b704 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -15,6 +15,7 @@ import gc import os import traceback +from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Optional, cast @@ -75,6 +76,15 @@ } +# Default chunk size for top-k/top-p filtering. +# The sort operation in top-p filtering is memory intensive because it creates +# intermediate tensors of shape [bsz, seq_len, vocab_size] for both sorted values +# and indices. For large vocab sizes (e.g., 152K) and long sequences (e.g., 32K), +# this can cause OOM. Chunking along the sequence dimension reduces peak memory. +# Different chunk sizes have minor performance differences. +TOP_K_TOP_P_CHUNK_SIZE: int = 256 + + class IPCProtocol(Enum): """IPC protocol constants for ZMQ weight streaming.""" @@ -82,68 +92,44 @@ class IPCProtocol(Enum): ACK = "ack" -def apply_top_k_top_p( - logits: torch.Tensor, - top_k: Optional[int] = None, - top_p: Optional[float] = None, -) -> torch.Tensor: - """Apply top-k and top-p masks to the logits. +def need_top_k_filtering(top_k: int | None) -> bool: + """Check if top-k filtering is needed.""" + return top_k is not None and top_k != -1 - Simplified version of VLLM's implementation for scalar parameters. - Based on VLLM's implementation: - https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py - SPDX-License-Identifier: Apache-2.0 - Copyright contributors to the vLLM project +def need_top_p_filtering(top_p: float | None) -> bool: + """Check if top-p filtering is needed.""" + return top_p is not None and top_p != 1.0 - Args: - logits: Input logits tensor of shape [batch_size, seq_len, vocab_size] - top_k: Top-k sampling parameter. Set to -1 to consider all tokens. - top_p: Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens. - Returns: - Filtered logits with sampling parameters applied - """ - if top_p is None or top_p == 1.0: - if top_k is None or top_k == -1: - return logits - # Avoid sorting vocab for top-k only case - return apply_top_k_only(logits, top_k) - - # Apply top-p (requires sorting) - logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - - if top_k is not None and top_k != -1: - # Apply top-k first - top_k_index = logits_sort.size(-1) - top_k - # Get all the top_k values - need to broadcast the index across all dimensions - index_tensor = torch.full( - logits_sort.shape[:-1], - top_k_index, - device=logits_sort.device, - dtype=torch.long, - ) - top_k_threshold = logits_sort.gather(-1, index_tensor.unsqueeze(-1)) - top_k_mask = logits_sort < top_k_threshold - logits_sort.masked_fill_(top_k_mask, -float("inf")) +def need_top_k_or_top_p_filtering(top_k: int | None, top_p: float | None) -> bool: + """Check if top-k or top-p filtering is needed.""" + return need_top_k_filtering(top_k) or need_top_p_filtering(top_p) - # Apply top-p - probs_sort = logits_sort.softmax(dim=-1) - probs_sum = torch.cumsum(probs_sort, dim=-1) - top_p_mask = probs_sum <= 1 - top_p - # at least one - top_p_mask[..., -1] = False - logits_sort.masked_fill_(top_p_mask, -float("inf")) - # Re-sort the probabilities - logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) - return logits +@dataclass +class TrainingSamplingParams: + """Training-specific sampling parameters to match generation parameters. + Used to ensure consistency between training and inference by applying the same sampling strategy during + logprob computation. Not directly using vLLM's SamplingParams class to avoid dependency on vLLM in this env. -def apply_top_k_only( + Attributes: + top_k: Top-k filtering parameter (None or -1 to disable) + top_p: Top-p filtering parameter (1.0 to disable) + temperature: Temperature for scaling logits (default: 1.0) + """ + + top_k: int | None = None + top_p: float = 1.0 + temperature: float = 1.0 + + +@torch.no_grad() +def _apply_top_k_only_fn( logits: torch.Tensor, - top_k: int, -) -> torch.Tensor: + top_k: int | None, +) -> tuple[torch.Tensor, torch.Tensor | None]: """Apply top-k mask to the logits. Simplified version of VLLM's implementation for scalar parameters. @@ -155,27 +141,207 @@ def apply_top_k_only( Copyright contributors to the vLLM project Args: - logits: Input logits tensor of shape [batch_size, seq_len, vocab_size] + logits: Input logits tensor of shape [*, vocab_size]. top_k: Top-k sampling parameter. Returns: - Filtered logits with top-k applied + filtered_logits: Filtered logits tensor with the same shape as input logits. + keep_mask: Mask tensor with the same shape as input logits, where 1 (True) indicates tokens to be + kept, 0 (False) indicates tokens to be masked. None if top-k filtering is not needed. """ - if top_k >= logits.shape[-1] or top_k == -1: - return logits + if not need_top_k_filtering(top_k): + return logits, None # Get top-k values and create mask + assert top_k is not None # Type narrowing top_k_values, _ = torch.topk(logits, top_k, dim=-1) threshold = top_k_values[..., -1:].expand_as(logits) - mask = logits >= threshold + keep_mask = logits >= threshold # Apply mask: keep top-k values, set others to -inf logits = torch.where( - mask, + keep_mask, logits, torch.tensor(-float("inf"), device=logits.device, dtype=logits.dtype), ) - return logits + return logits, keep_mask + + +@torch.no_grad() +def _apply_top_k_top_p_fn( + logits: torch.Tensor, + top_k: int | None, + top_p: float, + chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Apply top-k and top-p masks to the logits with chunking for memory efficiency. + + The sort operation in top-p filtering is memory intensive because it creates + intermediate tensors of shape [num_tokens, vocab_size] for both sorted values + and indices. For large vocab sizes (e.g., 152K) and many tokens, this can cause OOM. + This function flattens the input to 2D and processes in chunks along the token + dimension (controlled by chunk_size) to reduce peak memory. + + Based on VLLM's implementation: + https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py + SPDX-License-Identifier: Apache-2.0 + Copyright contributors to the vLLM project + + Args: + logits: Input logits tensor of shape [*, vocab_size] (e.g., [batch_size, seq_len, vocab_size] + or [batch_size, vocab_size]). Internally flattened to [num_tokens, vocab_size] for processing. + top_k: Top-k sampling parameter. Set to -1 or None to consider all tokens. + top_p: Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens + chunk_size: Number of tokens to process per chunk for memory efficiency. Defaults to TOP_K_TOP_P_CHUNK_SIZE. + + Returns: + filtered_logits: Filtered logits tensor with the same shape as input logits. + keep_mask: Mask tensor with the same shape as input logits, where 1 (True) indicates + tokens to be kept, 0 (False) indicates tokens to be masked. + """ + if not need_top_p_filtering(top_p): + if not need_top_k_filtering(top_k): + return logits, None + # Avoid sorting vocab for top-k only case + filtered_logits, top_k_keep_mask = _apply_top_k_only_fn(logits, top_k) + return filtered_logits, top_k_keep_mask + + # Save original shape and flatten to 2D for consistent chunking + original_shape = logits.shape + vocab_size = logits.shape[-1] + logits = logits.view(-1, vocab_size) # [*, vocab_size] -> [num_tokens, vocab_size] + num_tokens = logits.shape[0] + + chunk_size = chunk_size if chunk_size is not None else num_tokens + + # Pre-allocate output tensors + filtered_logits = torch.empty_like(logits) + keep_mask = torch.empty( + num_tokens, vocab_size, dtype=torch.bool, device=logits.device + ) + + for start_idx in range(0, num_tokens, chunk_size): + end_idx = min(start_idx + chunk_size, num_tokens) + chunk_logits = logits[start_idx:end_idx, :] + + # Sort this chunk + logits_sort, logits_idx = chunk_logits.sort(dim=-1, descending=False) + top_k_keep_mask_chunk = None + + if need_top_k_filtering(top_k): + assert top_k is not None # Type narrowing + # Apply top-k first + top_k_index = logits_sort.size(-1) - top_k + index_tensor = torch.full( + logits_sort.shape[:-1], + top_k_index, + device=logits_sort.device, + dtype=torch.long, + ) + top_k_threshold = logits_sort.gather(-1, index_tensor.unsqueeze(-1)) + top_k_keep_mask_chunk = logits_sort >= top_k_threshold + logits_sort.masked_fill_(~top_k_keep_mask_chunk, -float("inf")) + + # Apply top-p + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = torch.cumsum(probs_sort, dim=-1) + top_p_keep_mask_chunk = probs_sum > 1 - top_p + # at least one + top_p_keep_mask_chunk[..., -1] = True + logits_sort.masked_fill_(~top_p_keep_mask_chunk, -float("inf")) + + # Scatter back to original order + chunk_filtered = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) + if top_k_keep_mask_chunk is not None: + chunk_mask = torch.logical_and(top_k_keep_mask_chunk, top_p_keep_mask_chunk) + else: + chunk_mask = top_p_keep_mask_chunk + chunk_mask = chunk_mask.scatter(dim=-1, index=logits_idx, src=chunk_mask) + + # Store results + filtered_logits[start_idx:end_idx, :] = chunk_filtered + keep_mask[start_idx:end_idx, :] = chunk_mask + + # Restore original shape + filtered_logits = filtered_logits.view(original_shape) + keep_mask = keep_mask.view(original_shape) + + return filtered_logits, keep_mask + + +class ApplyTopKTopP(torch.autograd.Function): + """Autograd function for top-k and top-p filtering with proper gradient handling.""" + + @staticmethod + def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class + ctx, + logits: torch.Tensor, + top_k: Optional[int], + top_p: float, + chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Apply top-k/top-p filtering and save masks for backward. + + Args: + logits: Input logits tensor of shape [*, vocab_size]. + top_k: Top-k sampling parameter. Set to -1 or None to consider all tokens. + top_p: Top-p sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens. + chunk_size: Number of tokens to process per chunk. Defaults to TOP_K_TOP_P_CHUNK_SIZE. + """ + filtered_logits, keep_mask = _apply_top_k_top_p_fn( + logits, top_k, top_p, chunk_size + ) + + # Save masks for backward pass + ctx.save_for_backward(keep_mask) + + return filtered_logits, keep_mask + + @staticmethod + def backward(ctx, *grad_outputs: torch.Tensor): + """Backward pass: mask out gradients for filtered tokens.""" + grad_filtered_logits = grad_outputs[0] + (keep_mask,) = ctx.saved_tensors + + # Apply masks to gradients - masked out tokens should not receive gradients + if keep_mask is not None: + grad_filtered_logits = grad_filtered_logits.masked_fill(~keep_mask, 0.0) + + return grad_filtered_logits, None, None, None + + +def apply_top_k_top_p( + logits: torch.Tensor, + top_k: int | None, + top_p: float, + chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Apply top-k and top-p masks to the logits with proper gradient handling. + + Simplified version of VLLM's implementation for scalar parameters. + + When top_p < 1.0, sorting is required which is memory intensive for large vocab sizes. + Processing is done in chunks (controlled by chunk_size) to reduce peak memory. + + Based on VLLM's implementation: + https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py + SPDX-License-Identifier: Apache-2.0 + Copyright contributors to the vLLM project + + Args: + logits: Input logits tensor of shape [*, vocab_size]. + top_k: Top-k sampling parameter. Set to -1 to consider all tokens. + top_p: Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens. + chunk_size: Number of tokens to process per chunk. Defaults to TOP_K_TOP_P_CHUNK_SIZE. + + Returns: + filtered_logits: Filtered logits tensor with the same shape as input logits. + keep_mask: Mask tensor with the same shape as input logits, where 1 (True) indicates tokens to be + kept, 0 (False) indicates tokens to be masked. + """ + if not need_top_k_or_top_p_filtering(top_k, top_p): + return logits, None + return ApplyTopKTopP.apply(logits, top_k, top_p, chunk_size) def resolve_model_class(model_name: str) -> Any: diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index 661254da23..08a7eaa83a 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -71,8 +71,12 @@ ScoreOutputSpec, ) from nemo_rl.models.policy.utils import ( + TrainingSamplingParams, + apply_top_k_top_p, configure_dynamo_cache, get_runtime_env_for_policy_worker, + need_top_k_filtering, + need_top_p_filtering, resolve_model_class, ) from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker @@ -172,6 +176,16 @@ def __init__( if "generation" in config and config["generation"] is not None: self.is_generation_colocated = config["generation"]["colocated"]["enabled"] + if "generation" in self.cfg and self.cfg["generation"] is not None: + generation_cfg = self.cfg["generation"] + self.sampling_params = TrainingSamplingParams( + top_k=generation_cfg.get("top_k", None), + top_p=generation_cfg.get("top_p", 1.0), + temperature=generation_cfg.get("temperature", 1.0), + ) + else: + self.sampling_params = None + # Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator. # See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details. if not self.is_generation_colocated: @@ -477,8 +491,21 @@ def create_context_parallel_ctx( # based on https://github.com/pytorch/torchtitan/blob/cddd7dc809f36fe0ed51cdaaea0671c084d75442/torchtitan/distributed/utils.py#L178 def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor: - if "generation" in self.cfg and self.cfg["generation"] is not None: - logits.div_(self.cfg["generation"]["temperature"]) + if self.sampling_params is not None and self.sampling_params.temperature != 1.0: + logits.div_(self.sampling_params.temperature) + return logits + + def _apply_top_k_top_p_filtering(self, logits: torch.Tensor) -> torch.Tensor: + """Apply top-k and top-p filtering to the logits locally when TP is disabled.""" + if self.sampling_params is not None and ( + need_top_k_filtering(self.sampling_params.top_k) + or need_top_p_filtering(self.sampling_params.top_p) + ): + logits, _ = apply_top_k_top_p( + logits, + top_k=self.sampling_params.top_k, + top_p=self.sampling_params.top_p, + ) return logits @staticmethod @@ -1087,6 +1114,7 @@ def get_logprobs( input_ids_dtensor, seq_index_tensor, chunk_size=logprob_chunk_size, + sampling_params=self.sampling_params, ) assert token_logprobs.shape[1] == seq_len - 1 @@ -1096,6 +1124,7 @@ def get_logprobs( logits, input_ids, chunk_size=logprob_chunk_size, + sampling_params=self.sampling_params, ) else: if logprob_chunk_size is not None: @@ -1113,6 +1142,10 @@ def get_logprobs( chunk_logits = logits[ :, chunk_start:chunk_end, : ].to(torch.float32) + # Apply top-k and top-p filtering + chunk_logits = self._apply_top_k_top_p_filtering( + chunk_logits + ) log_probs = torch.nn.functional.log_softmax( chunk_logits, dim=-1 ) @@ -1120,7 +1153,9 @@ def get_logprobs( log_probs = torch.cat(chunked_log_probs, dim=1) del chunked_log_probs else: + # Apply top-k and top-p filtering logits = logits.to(torch.float32) + logits = self._apply_top_k_top_p_filtering(logits) log_probs = torch.nn.functional.log_softmax( logits, dim=-1 ) @@ -1615,8 +1650,10 @@ def get_topk_logits( def use_reference_model(self) -> Generator[None, None, None]: """Context manager that temporarily swaps the reference model and active model. - On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references - On exit: Restores original references and re-flips cuda/cpu + On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references. + Also disables top-k/top-p filtering since the reference policy's distribution + is different from the current policy, making filtered logprobs incompatible. + On exit: Restores original references and re-flips cuda/cpu, restores sampling_params. """ with torch.no_grad(): try: @@ -1630,11 +1667,29 @@ def use_reference_model(self) -> Generator[None, None, None]: val = to_local_if_dtensor(v) val.copy_(self.reference_model_state_dict[k]) + # Temporarily disable top-k/top-p filtering for reference policy logprobs. + # The reference policy has different weights, so its top-k/top-p set is + # inherently different from the current policy. Using filtered logprobs + # would cause -inf mismatches that cannot be resolved by masking. + # Note: We keep temperature scaling since it was applied to prev_logprobs. + saved_sampling_params = self.sampling_params + if saved_sampling_params is not None: + self.sampling_params = TrainingSamplingParams( + top_k=None, # Disable top-k + top_p=1.0, # Disable top-p + temperature=saved_sampling_params.temperature, # Keep temperature + ) + else: + self.sampling_params = None + # - self.model is the original reference_model, now on CUDA # - curr_state_dict is the train model, now on CPU yield finally: + # Restore sampling_params + self.sampling_params = saved_sampling_params + # Restore train model state_dict for k, v in self.model.state_dict().items(): val = to_local_if_dtensor(v) diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index f5adf24fac..b9ca7c39c8 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -68,7 +68,10 @@ ScoreOutputSpec, ) from nemo_rl.models.policy.utils import ( + TrainingSamplingParams, + apply_top_k_top_p, get_runtime_env_for_policy_worker, + need_top_k_or_top_p_filtering, ) from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker from nemo_rl.models.policy.workers.patches import ( @@ -239,6 +242,16 @@ def __init__( # Initialize checkpoint manager self.checkpoint_manager: Optional[AutomodelCheckpointManager] = None + if "generation" in self.cfg and self.cfg["generation"] is not None: + generation_cfg = self.cfg["generation"] + self.sampling_params = TrainingSamplingParams( + top_k=generation_cfg.get("top_k", None), + top_p=generation_cfg.get("top_p", 1.0), + temperature=generation_cfg.get("temperature", 1.0), + ) + else: + self.sampling_params = None + # Validate configuration and prepare runtime settings runtime_config = validate_and_prepare_config( config=config, @@ -323,6 +336,24 @@ def __init__( _runtime_is_reward_model, # Duplicate, already set as _is_reward_model ) = runtime_config + def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor: + if self.sampling_params is not None and self.sampling_params.temperature != 1.0: + logits.div_(self.sampling_params.temperature) + return logits + + def _apply_top_k_top_p_filtering(self, logits: torch.Tensor) -> torch.Tensor: + """Apply top-k and top-p filtering to the logits locally when TP is disabled.""" + sampling_params = self.sampling_params + if sampling_params is not None and need_top_k_or_top_p_filtering( + sampling_params.top_k, sampling_params.top_p + ): + logits, _ = apply_top_k_top_p( + logits, + top_k=sampling_params.top_k, + top_p=sampling_params.top_p, + ) + return logits + @wrap_with_nvtx_name("dtensor_policy_worker_v2/train") def train( self, @@ -796,9 +827,24 @@ def use_reference_model(self) -> Generator[None, None, None]: # - self.model is the original reference_model, now on CUDA # - curr_state_dict is the train model, now on CPU + + # Save and adjust sampling_params for reference model + saved_sampling_params = self.sampling_params + if saved_sampling_params is not None: + self.sampling_params = TrainingSamplingParams( + top_k=None, + top_p=1.0, + temperature=saved_sampling_params.temperature, + ) + else: + self.sampling_params = None + yield finally: + # Restore sampling_params + self.sampling_params = saved_sampling_params + # Restore train model state_dict for k, v in self.model.state_dict().items(): val = to_local_if_dtensor(v) diff --git a/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.sh b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.sh new file mode 100755 index 0000000000..1cdfad6fd3 --- /dev/null +++ b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.sh @@ -0,0 +1,44 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=500 +MAX_STEPS=500 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=180 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/token_mult_prob_error"], ignore_top_p=0.01) < 1.05' \ + 'data["train/token_mult_prob_error"]["500"] < 1.1' \ + 'data["train/reward"]["500"] > 0.1' \ + 'mean(data["timing/train/total_step_time"], -6, -1) < 12.5' + + # Clean up checkpoint directory after successful run to save space. + rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.sh b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.sh new file mode 100755 index 0000000000..21dd483626 --- /dev/null +++ b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.sh @@ -0,0 +1,44 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=500 +MAX_STEPS=500 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=180 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/token_mult_prob_error"], ignore_top_p=0.01) < 1.05' \ + 'data["train/token_mult_prob_error"]["500"] < 1.1' \ + 'data["train/reward"]["500"] > 0.1' \ + 'mean(data["timing/train/total_step_time"], -6, -1) < 10.5' + + # Clean up checkpoint directory after successful run to save space. + rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 5bfb3aac58..bd05ad6def 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -23,6 +23,8 @@ tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh # Megatron tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.sh tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh +tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.sh +tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.sh # Functional 32b run tests/test_suites/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8-actckpt.v3.sh diff --git a/tests/unit/distributed/test_model_utils.py b/tests/unit/distributed/test_model_utils.py index 8637ad22fe..3b7f7ecc8f 100644 --- a/tests/unit/distributed/test_model_utils.py +++ b/tests/unit/distributed/test_model_utils.py @@ -20,7 +20,9 @@ from nemo_rl.distributed.model_utils import ( ChunkedDistributedGatherLogprob, ChunkedDistributedLogprob, + ChunkedDistributedLogprobWithSampling, DistributedLogprob, + DistributedLogprobWithSampling, _compute_distributed_log_softmax, _get_tokens_on_this_cp_rank, allgather_cp_sharded_tensor, @@ -34,6 +36,9 @@ ) from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup +from nemo_rl.models.policy.utils import ( + apply_top_k_top_p, +) @ray.remote(num_gpus=1) @@ -956,3 +961,299 @@ def test_distributed_logprob_all_tests( finally: cluster.shutdown() + + +@ray.remote(num_gpus=1) +class SamplingParamsTestActor: + def __init__(self, tp_size, sharding): + self.tp_size = tp_size + self.sharding = sharding + self.env_vars = dict(os.environ) + torch.distributed.init_process_group(backend="nccl") + self.tp_group = torch.distributed.new_group(ranks=list(range(tp_size))) + + def test_top_k_top_p_filtering_forward_backward(self, top_k, top_p): + """Test top-k and top-p filtering logic including backward pass.""" + batch_size = 2 + seq_len = 4 + vocab_size = 100 + + torch.manual_seed(42) + logits = torch.randn( + batch_size, seq_len, vocab_size, device="cuda", requires_grad=True + ) + + filtered_logits, keep_mask = apply_top_k_top_p(logits, top_k=top_k, top_p=top_p) + + # Test 1: Verify top-k filtering + if top_k is not None: + for b in range(batch_size): + for s in range(seq_len): + topk_vals, topk_indices = torch.topk(logits[b, s], k=top_k) + topk_mask = torch.zeros( + vocab_size, dtype=torch.bool, device=logits.device + ) + topk_mask[topk_indices] = True + assert torch.all(torch.isinf(filtered_logits[b, s][~topk_mask])), ( + "Values outside top-k should be -inf" + ) + if top_p == 1.0: + assert not torch.any( + torch.isinf(filtered_logits[b, s][topk_mask]) + ), "Top-k values should not be -inf when top_p=1.0" + non_inf_count = (~torch.isinf(filtered_logits[b, s])).sum().item() + assert non_inf_count <= top_k, ( + f"Non-inf count {non_inf_count} exceeds top_k {top_k}" + ) + + # Test 2: Verify top-p filtering + if top_p < 1.0: + for b in range(batch_size): + for s in range(seq_len): + if top_k is not None: + topk_vals, topk_indices = torch.topk(logits[b, s], k=top_k) + temp_logits = torch.full_like(logits[b, s], float("-inf")) + temp_logits[topk_indices] = topk_vals + else: + temp_logits = logits[b, s] + probs = torch.nn.functional.softmax(temp_logits, dim=-1) + sorted_probs, sorted_indices = torch.sort(probs, descending=True) + cumsum_probs = torch.cumsum(sorted_probs, dim=0) + cutoff_idx = torch.where(cumsum_probs > top_p)[0] + if len(cutoff_idx) > 0: + cutoff_idx = cutoff_idx[0].item() + 1 + else: + cutoff_idx = len(sorted_probs) + kept_indices = sorted_indices[:cutoff_idx] + for idx in kept_indices: + if not torch.isinf(filtered_logits[b, s, idx]): + continue + raise AssertionError(f"Index {idx} in top-p should not be -inf") + + # Test 3: No filtering case + if top_k is None and top_p >= 1.0: + torch.testing.assert_close( + filtered_logits, logits.detach(), rtol=1e-5, atol=1e-5 + ) + + # Test 4: Valid probabilities + probs = torch.nn.functional.softmax(filtered_logits, dim=-1) + assert torch.all(probs >= 0) and torch.all(probs <= 1), "Invalid probabilities" + assert torch.allclose( + probs.sum(dim=-1), torch.ones(batch_size, seq_len, device="cuda"), atol=1e-5 + ), "Probabilities don't sum to 1" + + # Test 5: Verify keep_mask alignment with filtered logits + if keep_mask is not None: + non_inf_mask = ~torch.isinf(filtered_logits.detach()) + assert torch.equal(keep_mask, non_inf_mask), ( + f"keep_mask doesn't match non-inf positions in filtered_logits! " + f"Mismatch count: {(keep_mask != non_inf_mask).sum().item()} out of {keep_mask.numel()}" + ) + + # Test 6: Backward pass + torch.manual_seed(44) + output_grad = torch.randn_like(filtered_logits) + non_inf_mask = ~torch.isinf(filtered_logits.detach()) + expected_grad = output_grad * non_inf_mask.float() + filtered_logits.backward(output_grad) + actual_grad = logits.grad + torch.testing.assert_close(actual_grad, expected_grad, rtol=1e-5, atol=1e-5) + + return {"success": True, "error": None, "top_k": top_k, "top_p": top_p} + + def test_distributed_logprob_with_sampling(self, top_k, top_p, chunk_size): + """Test DistributedLogprobWithSampling and ChunkedDistributedLogprobWithSampling.""" + tp_group = self.tp_group + tp_rank = torch.distributed.get_rank(tp_group) + + batch_size = 4 + seq_len = 16 + vocab_size = 256 + vocab_part_size = vocab_size // self.tp_size + vocab_start_index = tp_rank * vocab_part_size + vocab_end_index = (tp_rank + 1) * vocab_part_size + + torch.manual_seed(42) + full_logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda") + vocab_parallel_logits = ( + full_logits[:, :, vocab_start_index:vocab_end_index] + .clone() + .requires_grad_(True) + ) + + torch.manual_seed(43) + target = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") + + # === Expected computation using full logits === + expected_logits_filtered, _ = apply_top_k_top_p( + full_logits.clone(), top_k=top_k, top_p=top_p + ) + expected_log_probs = torch.nn.functional.log_softmax( + expected_logits_filtered, dim=-1 + ) + expected_target_logprobs = torch.gather( + expected_log_probs, -1, target.unsqueeze(-1) + ).squeeze(-1) + + # === Actual computation using distributed function === + if chunk_size is None: + actual_logprobs = DistributedLogprobWithSampling.apply( + vocab_parallel_logits, + target, + tp_group, + top_k, + top_p, + False, + ) + else: + actual_logprobs = ChunkedDistributedLogprobWithSampling.apply( + vocab_parallel_logits, + target, + tp_group, + top_k, + top_p, + chunk_size, + False, + ) + + # === Forward pass validation === + torch.testing.assert_close( + actual_logprobs, expected_target_logprobs, rtol=1e-4, atol=1e-4 + ) + + # === Backward pass validation === + torch.manual_seed(44) + output_grad = torch.randn_like(actual_logprobs) + + expected_logits_filtered_grad = full_logits.clone().requires_grad_(True) + expected_logits_filtered_after_filter, _ = apply_top_k_top_p( + expected_logits_filtered_grad, top_k=top_k, top_p=top_p + ) + expected_log_probs_grad = torch.nn.functional.log_softmax( + expected_logits_filtered_after_filter, dim=-1 + ) + expected_target_logprobs_grad = torch.gather( + expected_log_probs_grad, -1, target.unsqueeze(-1) + ).squeeze(-1) + expected_target_logprobs_grad.backward(output_grad) + expected_grad = expected_logits_filtered_grad.grad[ + :, :, vocab_start_index:vocab_end_index + ].clone() + + actual_logprobs.backward(output_grad) + actual_grad = vocab_parallel_logits.grad.clone() + torch.testing.assert_close(actual_grad, expected_grad, rtol=1e-4, atol=1e-4) + + return { + "success": True, + "error": None, + "top_k": top_k, + "top_p": top_p, + "chunk_size": chunk_size, + } + + +SAMPLING_PARAMS_TEST_ACTOR_FQN = ( + f"{SamplingParamsTestActor.__module__}.SamplingParamsTestActor" +) + + +@pytest.fixture +def register_sampling_params_test_actor(): + """Register the SamplingParamsTestActor for use in tests.""" + original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get( + SAMPLING_PARAMS_TEST_ACTOR_FQN + ) + ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN] = PY_EXECUTABLES.SYSTEM + yield SAMPLING_PARAMS_TEST_ACTOR_FQN + if SAMPLING_PARAMS_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY: + if original_registry_value is None: + del ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN] + else: + ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN] = ( + original_registry_value + ) + + +@pytest.mark.parametrize("tp_size", [1, 2]) +@pytest.mark.parametrize( + "top_k,top_p", + [ + (None, 1.0), # No filtering + (10, 1.0), # Only top-k + (None, 0.9), # Only top-p + (10, 0.9), # Both top-k and top-p + ], +) +def test_sampling_params_top_k_top_p( + register_sampling_params_test_actor, tp_size, top_k, top_p +): + """Test top-k and top-p filtering logic.""" + if not torch.cuda.is_available() or torch.cuda.device_count() < tp_size: + pytest.skip( + f"Not enough GPUs available. Need {tp_size}, got {torch.cuda.device_count()}" + ) + cluster = RayVirtualCluster(bundle_ct_per_node_list=[tp_size], use_gpus=True) + try: + actor_fqn = register_sampling_params_test_actor + sharding = NamedSharding(layout=list(range(tp_size)), names=["tp"]) + builder = RayWorkerBuilder(actor_fqn, tp_size, sharding) + worker_group = RayWorkerGroup( + cluster=cluster, + remote_worker_builder=builder, + workers_per_node=None, + sharding_annotations=sharding, + ) + futures = worker_group.run_all_workers_single_data( + "test_top_k_top_p_filtering_forward_backward", top_k=top_k, top_p=top_p + ) + results = ray.get(futures) + for i, result in enumerate(results): + assert result["success"], f"Worker {i} failed: {result['error']}" + worker_group.shutdown(force=True) + finally: + cluster.shutdown() + + +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize( + "top_k,top_p", + [ + (10, 1.0), # Only top-k + (None, 0.9), # Only top-p + (10, 0.9), # Both top-k and top-p + ], +) +@pytest.mark.parametrize("chunk_size", [None, 4]) +def test_sampling_params_distributed_logprob( + register_sampling_params_test_actor, tp_size, top_k, top_p, chunk_size +): + """Test DistributedLogprobWithSampling and ChunkedDistributedLogprobWithSampling.""" + if not torch.cuda.is_available() or torch.cuda.device_count() < tp_size: + pytest.skip( + f"Not enough GPUs available. Need {tp_size}, got {torch.cuda.device_count()}" + ) + cluster = RayVirtualCluster(bundle_ct_per_node_list=[tp_size], use_gpus=True) + try: + actor_fqn = register_sampling_params_test_actor + sharding = NamedSharding(layout=list(range(tp_size)), names=["tp"]) + builder = RayWorkerBuilder(actor_fqn, tp_size, sharding) + worker_group = RayWorkerGroup( + cluster=cluster, + remote_worker_builder=builder, + workers_per_node=None, + sharding_annotations=sharding, + ) + futures = worker_group.run_all_workers_single_data( + "test_distributed_logprob_with_sampling", + top_k=top_k, + top_p=top_p, + chunk_size=chunk_size, + ) + results = ray.get(futures) + for i, result in enumerate(results): + assert result["success"], f"Worker {i} failed: {result['error']}" + worker_group.shutdown(force=True) + finally: + cluster.shutdown() diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 5436159122..b65cb2d483 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -382,43 +382,6 @@ def test_vllm_missing_required_config_key(cluster): print(f"Successfully caught missing config key with error: {error_message}") -def test_vllm_top_p_top_k_validation(cluster): - """Test that top_p and top_k validation works correctly with threshold-based logic.""" - # Test that values above thresholds are allowed - config_above_thresholds = deepcopy(basic_vllm_test_config) - config_above_thresholds["top_p"] = 0.99 # Above TOP_P_THRESHOLD - config_above_thresholds["top_k"] = 8000 # Above TOP_K_THRESHOLD - - # Should not raise an error - try: - VllmGeneration(cluster, config_above_thresholds) - print("Successfully initialized with top_p=0.99 and top_k=8000") - except Exception as e: - pytest.fail(f"Should not raise error with values above thresholds: {e}") - - # Test that values below thresholds are rejected - config_below_thresholds = deepcopy(basic_vllm_test_config) - config_below_thresholds["top_p"] = 0.9 # Below TOP_P_THRESHOLD - - with pytest.raises(ValueError) as excinfo: - VllmGeneration(cluster, config_below_thresholds) - - error_message = str(excinfo.value) - assert "top_p sampling with values < 0.99 is not supported" in error_message - print(f"Successfully caught low top_p value with error: {error_message}") - - # Test that low top_k values are rejected - config_low_top_k = deepcopy(basic_vllm_test_config) - config_low_top_k["top_k"] = 7999 # Below TOP_K_THRESHOLD - - with pytest.raises(ValueError) as excinfo: - VllmGeneration(cluster, config_low_top_k) - - error_message = str(excinfo.value) - assert "top_k sampling with values < 8000 is not supported" in error_message - print(f"Successfully caught low top_k value with error: {error_message}") - - def test_vllm_policy_generation(policy, test_input_data, tokenizer): """Test vLLM policy generation capabilities.""" # Test generation diff --git a/tests/unit/models/generation/test_vllm_logprobs_mode.py b/tests/unit/models/generation/test_vllm_logprobs_mode.py index f43f9e323c..fb3bd591c4 100644 --- a/tests/unit/models/generation/test_vllm_logprobs_mode.py +++ b/tests/unit/models/generation/test_vllm_logprobs_mode.py @@ -167,7 +167,7 @@ def test_processed_logprobs_matches_manual_computation(): scaled_logits_batched = scaled_logits.unsqueeze(0).unsqueeze( 0 ) # [1, 1, vocab_size] - filtered_logits_batched = apply_top_k_top_p( + filtered_logits_batched, _ = apply_top_k_top_p( scaled_logits_batched, top_k=top_k, top_p=top_p ) filtered_logits = filtered_logits_batched.squeeze(0).squeeze(0) # [vocab_size] @@ -257,7 +257,7 @@ def test_apply_top_k_top_p_matches_vllm_upstream(top_k, top_p, test_name): print(f"Testing: {test_name}") # Our implementation: expects [batch, seq, vocab], takes scalar k/p - our_result = apply_top_k_top_p(logits_3d.clone(), top_k=top_k, top_p=top_p) + our_result, _ = apply_top_k_top_p(logits_3d.clone(), top_k=top_k, top_p=top_p) # vLLM upstream: expects [batch, vocab], takes tensor k/p with shape [batch] # Process each sequence position separately (vLLM doesn't batch over seq_len) From e3219c1b0454fe0e3f1d9c016d17c5c4eae7f146 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 2 Mar 2026 22:20:36 -0800 Subject: [PATCH 02/20] apply previous change in loss_function.py Signed-off-by: Yuki Huang Co-authored-by: Zhanda --- nemo_rl/algorithms/loss/loss_functions.py | 21 +++- nemo_rl/algorithms/loss/utils.py | 77 +++++++++++- nemo_rl/algorithms/loss/wrapper.py | 2 +- nemo_rl/distributed/model_utils.py | 112 ++++++------------ nemo_rl/models/automodel/train.py | 12 +- .../workers/dtensor_policy_worker_v2.py | 1 + 6 files changed, 138 insertions(+), 87 deletions(-) diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index 812d2917b6..18504dbaa9 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -204,6 +204,9 @@ def __call__( generation_logprobs = data["generation_logprobs"][:, 1:] if self.reference_policy_kl_penalty != 0: reference_policy_logprobs = data["reference_policy_logprobs"][:, 1:] + curr_logprobs_unfiltered = data.get( + "curr_logprobs_unfiltered", curr_logprobs + ) mask = token_mask * sample_mask.unsqueeze(-1) @@ -273,27 +276,39 @@ def __call__( # Calculate KL regularization. if self.reference_policy_kl_penalty != 0: + # When top-k/top-p filtering is enabled, we need special handling for KL: + # - reference_policy_logprobs is computed **without** filtering (see use_reference_model) + # - curr_logprobs is computed **with** filtering (for actor loss compatibility) + # - For KL, we need curr_logprobs **without** filtering to be consistent with ref logprobs + # - For importance weights, we also use unfiltered curr_logprobs_for_kl since we're + # reweighting samples from π_gen_filtered to π_curr_unfiltered + + # On-policy KL approximation if self.use_on_policy_kl_approximation: # See: docs/guides/grpo.md#on-policy-kl-approximation kl_importance_weights = torch.exp( - curr_logprobs - generation_logprobs + curr_logprobs_unfiltered - generation_logprobs ).detach() kl_importance_weights = torch.nan_to_num( kl_importance_weights, nan=0.0, posinf=0.0, neginf=0.0 ) else: - kl_importance_weights = torch.ones_like(curr_logprobs) + kl_importance_weights = torch.ones_like(curr_logprobs_unfiltered) + + # Compute KL loss kl = ( kl_importance_weights * self.reference_policy_kl_penalty * calculate_kl( - logprobs=curr_logprobs, + logprobs=curr_logprobs_unfiltered, logprobs_reference=reference_policy_logprobs, kl_type=self.reference_policy_kl_type, input_clamp_value=self.kl_input_clamp_value, output_clamp_value=self.kl_output_clamp_value, ) ) + + # Reduce KL loss if self.loss_type == LossType.TOKEN_LEVEL: kl = masked_mean( kl, mask, global_normalization_factor=global_valid_toks diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py index 359641ae09..d266d44224 100644 --- a/nemo_rl/algorithms/loss/utils.py +++ b/nemo_rl/algorithms/loss/utils.py @@ -22,6 +22,42 @@ get_distillation_topk_logprobs_from_logits, get_next_token_logprobs_from_logits, ) +from nemo_rl.models.policy.utils import ( + TrainingSamplingParams, + need_top_k_or_top_p_filtering, +) + + +def _mask_out_neg_inf_logprobs( + logprobs: torch.Tensor, mask: torch.Tensor, logprobs_name: str +) -> torch.Tensor: + """Mask out negative infinity log probabilities. + + Handling sampling mask mismatch: + vLLM samples token X from top-k/p filtered distribution -> generation_logprobs[X] is always finite (e.g., -5.41) + during training: policy computes logprobs with same top-k/p settings, but the distribution can be slightly different + token X may fall outside the training policy's top-k/p set -> curr_logprobs[X] = -inf, prev_logprobs[X] = -inf + Detect positions with -inf in any logprobs (generation_logprobs is always finite for valid tokens) + + Args: + logprobs: Log probabilities. + mask: Mask. + + Returns: + Masked log probabilities. + """ + is_neginf = torch.isinf(logprobs) + neginf_count = (is_neginf & mask.bool()).sum().item() + if neginf_count > 0: + print( + f"[WARNING]: {neginf_count}/{int(mask.sum().item())} valid tokens have -inf in {logprobs_name} " + "(policy top-k/top-p mismatch). Masking out these positions." + ) + + mask = mask * (~is_neginf).float() + logprobs = torch.where(mask.bool(), logprobs, 0.0) + + return logprobs def prepare_loss_input( @@ -31,21 +67,25 @@ def prepare_loss_input( vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, -) -> dict[str, Any]: + sampling_params: Optional[TrainingSamplingParams] = None, +) -> tuple[dict[str, Any], BatchedDataDict[Any]]: """Prepare loss input for a loss function. Args: logits: Logits from the model. - data: Microbatch data. + data: Microbatch data. Will be updated if sampling_params is not None. loss_fn: Loss function. vocab_parallel_rank: Vocab parallel rank. vocab_parallel_group: Vocab parallel group. context_parallel_group: Context parallel group. + sampling_params: Sampling parameters. + Notes: vocab_parallel_rank, vocab_parallel_group, context_parallel_group are only used for megatron policy worker. + sampling_params is only used for LossInputType.LOGPROB, and currently only supported for ClippedPGLossFn. Returns: - Loss input. + tuple(loss_input, maybe_updated_data) """ if loss_fn.input_type == LossInputType.LOGIT: loss_input = {"logits": logits} @@ -58,8 +98,37 @@ def prepare_loss_input( vocab_parallel_rank=vocab_parallel_rank, vocab_parallel_group=vocab_parallel_group, context_parallel_group=context_parallel_group, + sampling_params=sampling_params, ) + # handle top-k/top-p filtering for logprobs, only used for ClippedPGLossFn now + if sampling_params is not None and need_top_k_or_top_p_filtering( + sampling_params.top_k, sampling_params.top_p + ): + # mask out negative infinity logprobs + mask = data["token_mask"] * data["sample_mask"].unsqueeze(-1) + logprobs = _mask_out_neg_inf_logprobs( + logprobs, mask[:, 1:], "curr_logprobs" + ) + data["prev_logprobs"] = _mask_out_neg_inf_logprobs( + data["prev_logprobs"], mask, "prev_logprobs" + ) + + # currently only used for ClippedPGLossFn + if ( + hasattr(loss_fn, "reference_policy_kl_penalty") + and loss_fn.reference_policy_kl_penalty != 0 + ): + data["curr_logprobs_unfiltered"] = get_next_token_logprobs_from_logits( + input_ids=data["input_ids"], + next_token_logits=logits, + seq_index=data.get("seq_index", None), + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, + sampling_params=None, # no filtering + ) + loss_input = {"next_token_logprobs": logprobs} elif loss_fn.input_type == LossInputType.DISTILLATION: @@ -86,4 +155,4 @@ def prepare_loss_input( else: raise ValueError(f"Unknown loss function input type: {loss_fn.input_type}") - return loss_input + return loss_input, data diff --git a/nemo_rl/algorithms/loss/wrapper.py b/nemo_rl/algorithms/loss/wrapper.py index 39e8b12814..5e26d9b3e7 100644 --- a/nemo_rl/algorithms/loss/wrapper.py +++ b/nemo_rl/algorithms/loss/wrapper.py @@ -109,7 +109,7 @@ def __call__( ) # prepare data for loss function - loss_input = self.prepare_fn( + loss_input, unpadded_seq_data = self.prepare_fn( logits=next_token_logits_slice, data=unpadded_seq_data, loss_fn=self.loss_fn, diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 9020814bb2..259418a96e 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -1261,8 +1261,27 @@ def get_next_token_logprobs_from_logits( vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, -): - """Computes next token log probabilities from logits.""" + sampling_params: Optional[TrainingSamplingParams] = None, +) -> torch.Tensor: + """Compute token log-probabilities from logits, handling parallel and non-parallel cases. + + This function handles three cases: + 1. Vocab parallel (Megatron-style): uses from_parallel_logits_to_logprobs + 2. DTensor: uses get_logprobs_from_vocab_parallel_logits + 3. Non-parallel: applies top-k/top-p filtering, log_softmax, and gather + + Args: + input_ids: Input token IDs of shape [batch_size, seq_len] + next_token_logits: Logits tensor of shape [batch_size, seq_len, vocab_size] + seq_index: Sequence index tensor for DTensor path + vocab_parallel_rank: Rank in the vocab parallel group (required if vocab_parallel_group is provided) + vocab_parallel_group: Process group for vocab parallelism + context_parallel_group: Process group for context parallelism + sampling_params: Sampling parameters for top-k/top-p filtering + + Returns: + Token log-probabilities of shape [batch_size, seq_len - 1] + """ next_token_logits = next_token_logits.to(torch.float32) if vocab_parallel_group is not None: @@ -1277,16 +1296,29 @@ def get_next_token_logprobs_from_logits( tp_group=vocab_parallel_group, inference_only=False, cp_group=context_parallel_group, + sampling_params=sampling_params, ) # slice off to the correct length to remove potential CP padding logprobs = logprobs[:, : input_ids.shape[1] - 1] + elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): logprobs = get_logprobs_from_vocab_parallel_logits( - next_token_logits, input_ids, seq_index=seq_index + next_token_logits, + input_ids, + seq_index=seq_index, + sampling_params=sampling_params, ) + else: # Remove last position's logits next_token_logits_wo_last = next_token_logits[:, :-1] + # Apply top-k and top-p filtering + next_token_logits_wo_last, _ = apply_top_k_top_p( + next_token_logits_wo_last, + top_k=sampling_params.top_k if sampling_params is not None else None, + top_p=sampling_params.top_p if sampling_params is not None else 1.0, + ) + # Compute logprobs next_token_logprobs = torch.nn.functional.log_softmax( next_token_logits_wo_last, dim=-1 ) @@ -1705,80 +1737,6 @@ def backward( return grad_input, None, None, None -def compute_logprobs_from_logits( - next_token_logits: torch.Tensor, - input_ids: torch.Tensor, - vocab_parallel_rank: Optional[int], - vocab_parallel_group: Optional[torch.distributed.ProcessGroup], - context_parallel_group: Optional[torch.distributed.ProcessGroup], - seq_index: Optional[torch.Tensor], - sampling_params: Optional[TrainingSamplingParams], -) -> torch.Tensor: - """Compute token log-probabilities from logits, handling parallel and non-parallel cases. - - This function handles three cases: - 1. Vocab parallel (Megatron-style): uses from_parallel_logits_to_logprobs - 2. DTensor: uses get_logprobs_from_vocab_parallel_logits - 3. Non-parallel: applies top-k/top-p filtering, log_softmax, and gather - - Args: - next_token_logits: Logits tensor of shape [batch_size, seq_len, vocab_size] - input_ids: Input token IDs of shape [batch_size, seq_len] - vocab_parallel_rank: Rank in the vocab parallel group (required if vocab_parallel_group is provided) - vocab_parallel_group: Process group for vocab parallelism - context_parallel_group: Process group for context parallelism - seq_index: Sequence index tensor for DTensor path - sampling_params: Sampling parameters for top-k/top-p filtering - - Returns: - Token log-probabilities of shape [batch_size, seq_len - 1] - """ - next_token_logits = next_token_logits.to(torch.float32) - - if vocab_parallel_group is not None: - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" - ) - token_logprobs = from_parallel_logits_to_logprobs( - next_token_logits, - input_ids, - vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], - vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - tp_group=vocab_parallel_group, - inference_only=False, - cp_group=context_parallel_group, - sampling_params=sampling_params, - ) - # slice off to the correct length to remove potential CP padding - token_logprobs = token_logprobs[:, : input_ids.shape[1] - 1] - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): - token_logprobs = get_logprobs_from_vocab_parallel_logits( - next_token_logits, - input_ids, - seq_index=seq_index, - sampling_params=sampling_params, - ) - else: - next_token_logits_wo_last = next_token_logits[ - :, :-1 - ] # Remove last position's logits - # Apply top-k and top-p filtering - next_token_logits_wo_last, _ = apply_top_k_top_p( - next_token_logits_wo_last, - top_k=sampling_params.top_k if sampling_params is not None else None, - top_p=sampling_params.top_p if sampling_params is not None else 1.0, - ) - next_token_logprobs = torch.nn.functional.log_softmax( - next_token_logits_wo_last, dim=-1 - ) - next_tokens = input_ids[:, 1:].cuda() # Skip first token - token_logprobs = next_token_logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) - - return token_logprobs - - def all_to_all_vp2sq( vocab_parallel_logits: torch.Tensor, tp_group: torch.distributed.ProcessGroup, diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index 32e3386ab6..12431f5340 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -24,6 +24,7 @@ """ from collections import defaultdict +from functools import partial from typing import Any, Callable, Iterator, Optional, Tuple, Union import torch @@ -41,6 +42,7 @@ ) from nemo_rl.models.automodel.data import ProcessedInputs, ProcessedMicrobatch from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy.utils import TrainingSamplingParams # Union type for any post-processing function PostProcessingFunction = Union[ @@ -460,6 +462,7 @@ def __init__( cp_size: int, dp_size: int, enable_seq_packing: bool = False, + sampling_params: Optional[TrainingSamplingParams] = None, ): """Initialize LossPostProcessor. @@ -472,6 +475,7 @@ def __init__( cp_size: Context parallel size dp_size: Data parallel size enable_seq_packing: Whether sequence packing is enabled + sampling_params: Sampling parameters """ self.loss_fn: LossFunction = loss_fn self.cfg: PolicyConfig = cfg @@ -481,6 +485,7 @@ def __init__( self.cp_size = cp_size self.dp_size = dp_size self.enable_seq_packing = enable_seq_packing + self.sampling_params = sampling_params def __call__( self, @@ -514,10 +519,13 @@ def __call__( ) # Wrap loss function for sequence packing if needed + wrapped_prepare_loss_input = partial( + prepare_loss_input, sampling_params=self.sampling_params + ) if self.enable_seq_packing: loss_fn = SequencePackingLossWrapper( loss_fn=self.loss_fn, - prepare_fn=prepare_loss_input, + prepare_fn=wrapped_prepare_loss_input, cu_seqlens_q=processed_inputs.flash_attn_kwargs.cu_seqlens_q, cu_seqlens_q_padded=processed_inputs.flash_attn_kwargs.cu_seqlens_q, ) @@ -528,7 +536,7 @@ def __call__( global_valid_toks, ) else: - loss_input = prepare_loss_input(logits, mb, self.loss_fn) + loss_input, mb = wrapped_prepare_loss_input(logits, mb, self.loss_fn) loss, loss_metrics = self.loss_fn( data=mb, global_valid_seqs=global_valid_seqs, diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index b9ca7c39c8..ed5d8e4acf 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -398,6 +398,7 @@ def train( cp_size=self.cp_size, dp_size=self.dp_size, enable_seq_packing=self.enable_seq_packing, + sampling_params=self.sampling_params, ) # Create train context factory From e087ee1e733c639c148666d41b900edf57bcd717 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 3 Mar 2026 00:16:33 -0800 Subject: [PATCH 03/20] split to logits_sampling_utils.py Signed-off-by: Yuki Huang --- nemo_rl/algorithms/logits_sampling_utils.py | 278 ++++++++++++++++++ nemo_rl/algorithms/loss/utils.py | 10 +- nemo_rl/distributed/model_utils.py | 2 +- nemo_rl/models/automodel/train.py | 2 +- nemo_rl/models/policy/utils.py | 262 ----------------- .../policy/workers/dtensor_policy_worker.py | 10 +- .../workers/dtensor_policy_worker_v2.py | 12 +- tests/unit/distributed/test_model_utils.py | 4 +- .../generation/test_vllm_logprobs_mode.py | 4 +- 9 files changed, 300 insertions(+), 284 deletions(-) create mode 100644 nemo_rl/algorithms/logits_sampling_utils.py diff --git a/nemo_rl/algorithms/logits_sampling_utils.py b/nemo_rl/algorithms/logits_sampling_utils.py new file mode 100644 index 0000000000..54e9ceb992 --- /dev/null +++ b/nemo_rl/algorithms/logits_sampling_utils.py @@ -0,0 +1,278 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional + +import torch + +# Default chunk size for top-k/top-p filtering. +# The sort operation in top-p filtering is memory intensive because it creates +# intermediate tensors of shape [bsz, seq_len, vocab_size] for both sorted values +# and indices. For large vocab sizes (e.g., 152K) and long sequences (e.g., 32K), +# this can cause OOM. Chunking along the sequence dimension reduces peak memory. +# Different chunk sizes have minor performance differences. +TOP_K_TOP_P_CHUNK_SIZE: int = 256 + + +def need_top_k_filtering(top_k: int | None) -> bool: + """Check if top-k filtering is needed.""" + return top_k is not None and top_k != -1 + + +def need_top_p_filtering(top_p: float | None) -> bool: + """Check if top-p filtering is needed.""" + return top_p is not None and top_p != 1.0 + + +def need_top_k_or_top_p_filtering(top_k: int | None, top_p: float | None) -> bool: + """Check if top-k or top-p filtering is needed.""" + return need_top_k_filtering(top_k) or need_top_p_filtering(top_p) + + +@dataclass +class TrainingSamplingParams: + """Training-specific sampling parameters to match generation parameters. + + Used to ensure consistency between training and inference by applying the same sampling strategy during + logprob computation. Not directly using vLLM's SamplingParams class to avoid dependency on vLLM in this env. + + Attributes: + top_k: Top-k filtering parameter (None or -1 to disable) + top_p: Top-p filtering parameter (1.0 to disable) + temperature: Temperature for scaling logits (default: 1.0) + """ + + top_k: int | None = None + top_p: float = 1.0 + temperature: float = 1.0 + + +@torch.no_grad() +def _apply_top_k_only_fn( + logits: torch.Tensor, + top_k: int | None, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Apply top-k mask to the logits. + + Simplified version of VLLM's implementation for scalar parameters. + This implementation doesn't involve sorting the entire vocab. + + Based on VLLM's implementation: + https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py + SPDX-License-Identifier: Apache-2.0 + Copyright contributors to the vLLM project + + Args: + logits: Input logits tensor of shape [*, vocab_size]. + top_k: Top-k sampling parameter. + + Returns: + filtered_logits: Filtered logits tensor with the same shape as input logits. + keep_mask: Mask tensor with the same shape as input logits, where 1 (True) indicates tokens to be + kept, 0 (False) indicates tokens to be masked. None if top-k filtering is not needed. + """ + if not need_top_k_filtering(top_k): + return logits, None + + # Get top-k values and create mask + assert top_k is not None # Type narrowing + top_k_values, _ = torch.topk(logits, top_k, dim=-1) + threshold = top_k_values[..., -1:].expand_as(logits) + keep_mask = logits >= threshold + + # Apply mask: keep top-k values, set others to -inf + logits = torch.where( + keep_mask, + logits, + torch.tensor(-float("inf"), device=logits.device, dtype=logits.dtype), + ) + return logits, keep_mask + + +@torch.no_grad() +def _apply_top_k_top_p_fn( + logits: torch.Tensor, + top_k: int | None, + top_p: float, + chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Apply top-k and top-p masks to the logits with chunking for memory efficiency. + + The sort operation in top-p filtering is memory intensive because it creates + intermediate tensors of shape [num_tokens, vocab_size] for both sorted values + and indices. For large vocab sizes (e.g., 152K) and many tokens, this can cause OOM. + This function flattens the input to 2D and processes in chunks along the token + dimension (controlled by chunk_size) to reduce peak memory. + + Based on VLLM's implementation: + https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py + SPDX-License-Identifier: Apache-2.0 + Copyright contributors to the vLLM project + + Args: + logits: Input logits tensor of shape [*, vocab_size] (e.g., [batch_size, seq_len, vocab_size] + or [batch_size, vocab_size]). Internally flattened to [num_tokens, vocab_size] for processing. + top_k: Top-k sampling parameter. Set to -1 or None to consider all tokens. + top_p: Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens + chunk_size: Number of tokens to process per chunk for memory efficiency. Defaults to TOP_K_TOP_P_CHUNK_SIZE. + + Returns: + filtered_logits: Filtered logits tensor with the same shape as input logits. + keep_mask: Mask tensor with the same shape as input logits, where 1 (True) indicates + tokens to be kept, 0 (False) indicates tokens to be masked. + """ + if not need_top_p_filtering(top_p): + if not need_top_k_filtering(top_k): + return logits, None + # Avoid sorting vocab for top-k only case + filtered_logits, top_k_keep_mask = _apply_top_k_only_fn(logits, top_k) + return filtered_logits, top_k_keep_mask + + # Save original shape and flatten to 2D for consistent chunking + original_shape = logits.shape + vocab_size = logits.shape[-1] + logits = logits.view(-1, vocab_size) # [*, vocab_size] -> [num_tokens, vocab_size] + num_tokens = logits.shape[0] + + chunk_size = chunk_size if chunk_size is not None else num_tokens + + # Pre-allocate output tensors + filtered_logits = torch.empty_like(logits) + keep_mask = torch.empty( + num_tokens, vocab_size, dtype=torch.bool, device=logits.device + ) + + for start_idx in range(0, num_tokens, chunk_size): + end_idx = min(start_idx + chunk_size, num_tokens) + chunk_logits = logits[start_idx:end_idx, :] + + # Sort this chunk + logits_sort, logits_idx = chunk_logits.sort(dim=-1, descending=False) + top_k_keep_mask_chunk = None + + if need_top_k_filtering(top_k): + assert top_k is not None # Type narrowing + # Apply top-k first + top_k_index = logits_sort.size(-1) - top_k + index_tensor = torch.full( + logits_sort.shape[:-1], + top_k_index, + device=logits_sort.device, + dtype=torch.long, + ) + top_k_threshold = logits_sort.gather(-1, index_tensor.unsqueeze(-1)) + top_k_keep_mask_chunk = logits_sort >= top_k_threshold + logits_sort.masked_fill_(~top_k_keep_mask_chunk, -float("inf")) + + # Apply top-p + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = torch.cumsum(probs_sort, dim=-1) + top_p_keep_mask_chunk = probs_sum > 1 - top_p + # at least one + top_p_keep_mask_chunk[..., -1] = True + logits_sort.masked_fill_(~top_p_keep_mask_chunk, -float("inf")) + + # Scatter back to original order + chunk_filtered = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) + if top_k_keep_mask_chunk is not None: + chunk_mask = torch.logical_and(top_k_keep_mask_chunk, top_p_keep_mask_chunk) + else: + chunk_mask = top_p_keep_mask_chunk + chunk_mask = chunk_mask.scatter(dim=-1, index=logits_idx, src=chunk_mask) + + # Store results + filtered_logits[start_idx:end_idx, :] = chunk_filtered + keep_mask[start_idx:end_idx, :] = chunk_mask + + # Restore original shape + filtered_logits = filtered_logits.view(original_shape) + keep_mask = keep_mask.view(original_shape) + + return filtered_logits, keep_mask + + +class ApplyTopKTopP(torch.autograd.Function): + """Autograd function for top-k and top-p filtering with proper gradient handling.""" + + @staticmethod + def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class + ctx, + logits: torch.Tensor, + top_k: Optional[int], + top_p: float, + chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Apply top-k/top-p filtering and save masks for backward. + + Args: + logits: Input logits tensor of shape [*, vocab_size]. + top_k: Top-k sampling parameter. Set to -1 or None to consider all tokens. + top_p: Top-p sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens. + chunk_size: Number of tokens to process per chunk. Defaults to TOP_K_TOP_P_CHUNK_SIZE. + """ + filtered_logits, keep_mask = _apply_top_k_top_p_fn( + logits, top_k, top_p, chunk_size + ) + + # Save masks for backward pass + ctx.save_for_backward(keep_mask) + + return filtered_logits, keep_mask + + @staticmethod + def backward(ctx, *grad_outputs: torch.Tensor): + """Backward pass: mask out gradients for filtered tokens.""" + grad_filtered_logits = grad_outputs[0] + (keep_mask,) = ctx.saved_tensors + + # Apply masks to gradients - masked out tokens should not receive gradients + if keep_mask is not None: + grad_filtered_logits = grad_filtered_logits.masked_fill(~keep_mask, 0.0) + + return grad_filtered_logits, None, None, None + + +def apply_top_k_top_p( + logits: torch.Tensor, + top_k: int | None, + top_p: float, + chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Apply top-k and top-p masks to the logits with proper gradient handling. + + Simplified version of VLLM's implementation for scalar parameters. + + When top_p < 1.0, sorting is required which is memory intensive for large vocab sizes. + Processing is done in chunks (controlled by chunk_size) to reduce peak memory. + + Based on VLLM's implementation: + https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py + SPDX-License-Identifier: Apache-2.0 + Copyright contributors to the vLLM project + + Args: + logits: Input logits tensor of shape [*, vocab_size]. + top_k: Top-k sampling parameter. Set to -1 to consider all tokens. + top_p: Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens. + chunk_size: Number of tokens to process per chunk. Defaults to TOP_K_TOP_P_CHUNK_SIZE. + + Returns: + filtered_logits: Filtered logits tensor with the same shape as input logits. + keep_mask: Mask tensor with the same shape as input logits, where 1 (True) indicates tokens to be + kept, 0 (False) indicates tokens to be masked. + """ + if not need_top_k_or_top_p_filtering(top_k, top_p): + return logits, None + return ApplyTopKTopP.apply(logits, top_k, top_p, chunk_size) diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py index d266d44224..7d65440ed8 100644 --- a/nemo_rl/algorithms/loss/utils.py +++ b/nemo_rl/algorithms/loss/utils.py @@ -16,16 +16,16 @@ import torch +from nemo_rl.algorithms.logits_sampling_utils import ( + TrainingSamplingParams, + need_top_k_or_top_p_filtering, +) from nemo_rl.algorithms.loss.interfaces import LossFunction, LossInputType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( get_distillation_topk_logprobs_from_logits, get_next_token_logprobs_from_logits, ) -from nemo_rl.models.policy.utils import ( - TrainingSamplingParams, - need_top_k_or_top_p_filtering, -) def _mask_out_neg_inf_logprobs( @@ -114,7 +114,7 @@ def prepare_loss_input( data["prev_logprobs"], mask, "prev_logprobs" ) - # currently only used for ClippedPGLossFn + # compute unfiltered logprobs for reference policy KL penalty if ( hasattr(loss_fn, "reference_policy_kl_penalty") and loss_fn.reference_policy_kl_penalty != 0 diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 259418a96e..c71de218a0 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -17,7 +17,7 @@ import torch from torch.distributed.tensor import DTensor, distribute_tensor -from nemo_rl.models.policy.utils import ( +from nemo_rl.algorithms.logits_sampling_utils import ( TrainingSamplingParams, apply_top_k_top_p, need_top_k_or_top_p_filtering, diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index 12431f5340..5bee096798 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -32,6 +32,7 @@ from torch import nn from torch.distributed.tensor import DTensor, Shard +from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams from nemo_rl.algorithms.loss import SequencePackingLossWrapper, prepare_loss_input from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -42,7 +43,6 @@ ) from nemo_rl.models.automodel.data import ProcessedInputs, ProcessedMicrobatch from nemo_rl.models.policy import PolicyConfig -from nemo_rl.models.policy.utils import TrainingSamplingParams # Union type for any post-processing function PostProcessingFunction = Union[ diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index 97f011b704..bbd2e6d2f6 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -15,7 +15,6 @@ import gc import os import traceback -from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Optional, cast @@ -76,15 +75,6 @@ } -# Default chunk size for top-k/top-p filtering. -# The sort operation in top-p filtering is memory intensive because it creates -# intermediate tensors of shape [bsz, seq_len, vocab_size] for both sorted values -# and indices. For large vocab sizes (e.g., 152K) and long sequences (e.g., 32K), -# this can cause OOM. Chunking along the sequence dimension reduces peak memory. -# Different chunk sizes have minor performance differences. -TOP_K_TOP_P_CHUNK_SIZE: int = 256 - - class IPCProtocol(Enum): """IPC protocol constants for ZMQ weight streaming.""" @@ -92,258 +82,6 @@ class IPCProtocol(Enum): ACK = "ack" -def need_top_k_filtering(top_k: int | None) -> bool: - """Check if top-k filtering is needed.""" - return top_k is not None and top_k != -1 - - -def need_top_p_filtering(top_p: float | None) -> bool: - """Check if top-p filtering is needed.""" - return top_p is not None and top_p != 1.0 - - -def need_top_k_or_top_p_filtering(top_k: int | None, top_p: float | None) -> bool: - """Check if top-k or top-p filtering is needed.""" - return need_top_k_filtering(top_k) or need_top_p_filtering(top_p) - - -@dataclass -class TrainingSamplingParams: - """Training-specific sampling parameters to match generation parameters. - - Used to ensure consistency between training and inference by applying the same sampling strategy during - logprob computation. Not directly using vLLM's SamplingParams class to avoid dependency on vLLM in this env. - - Attributes: - top_k: Top-k filtering parameter (None or -1 to disable) - top_p: Top-p filtering parameter (1.0 to disable) - temperature: Temperature for scaling logits (default: 1.0) - """ - - top_k: int | None = None - top_p: float = 1.0 - temperature: float = 1.0 - - -@torch.no_grad() -def _apply_top_k_only_fn( - logits: torch.Tensor, - top_k: int | None, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """Apply top-k mask to the logits. - - Simplified version of VLLM's implementation for scalar parameters. - This implementation doesn't involve sorting the entire vocab. - - Based on VLLM's implementation: - https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py - SPDX-License-Identifier: Apache-2.0 - Copyright contributors to the vLLM project - - Args: - logits: Input logits tensor of shape [*, vocab_size]. - top_k: Top-k sampling parameter. - - Returns: - filtered_logits: Filtered logits tensor with the same shape as input logits. - keep_mask: Mask tensor with the same shape as input logits, where 1 (True) indicates tokens to be - kept, 0 (False) indicates tokens to be masked. None if top-k filtering is not needed. - """ - if not need_top_k_filtering(top_k): - return logits, None - - # Get top-k values and create mask - assert top_k is not None # Type narrowing - top_k_values, _ = torch.topk(logits, top_k, dim=-1) - threshold = top_k_values[..., -1:].expand_as(logits) - keep_mask = logits >= threshold - - # Apply mask: keep top-k values, set others to -inf - logits = torch.where( - keep_mask, - logits, - torch.tensor(-float("inf"), device=logits.device, dtype=logits.dtype), - ) - return logits, keep_mask - - -@torch.no_grad() -def _apply_top_k_top_p_fn( - logits: torch.Tensor, - top_k: int | None, - top_p: float, - chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """Apply top-k and top-p masks to the logits with chunking for memory efficiency. - - The sort operation in top-p filtering is memory intensive because it creates - intermediate tensors of shape [num_tokens, vocab_size] for both sorted values - and indices. For large vocab sizes (e.g., 152K) and many tokens, this can cause OOM. - This function flattens the input to 2D and processes in chunks along the token - dimension (controlled by chunk_size) to reduce peak memory. - - Based on VLLM's implementation: - https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py - SPDX-License-Identifier: Apache-2.0 - Copyright contributors to the vLLM project - - Args: - logits: Input logits tensor of shape [*, vocab_size] (e.g., [batch_size, seq_len, vocab_size] - or [batch_size, vocab_size]). Internally flattened to [num_tokens, vocab_size] for processing. - top_k: Top-k sampling parameter. Set to -1 or None to consider all tokens. - top_p: Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens - chunk_size: Number of tokens to process per chunk for memory efficiency. Defaults to TOP_K_TOP_P_CHUNK_SIZE. - - Returns: - filtered_logits: Filtered logits tensor with the same shape as input logits. - keep_mask: Mask tensor with the same shape as input logits, where 1 (True) indicates - tokens to be kept, 0 (False) indicates tokens to be masked. - """ - if not need_top_p_filtering(top_p): - if not need_top_k_filtering(top_k): - return logits, None - # Avoid sorting vocab for top-k only case - filtered_logits, top_k_keep_mask = _apply_top_k_only_fn(logits, top_k) - return filtered_logits, top_k_keep_mask - - # Save original shape and flatten to 2D for consistent chunking - original_shape = logits.shape - vocab_size = logits.shape[-1] - logits = logits.view(-1, vocab_size) # [*, vocab_size] -> [num_tokens, vocab_size] - num_tokens = logits.shape[0] - - chunk_size = chunk_size if chunk_size is not None else num_tokens - - # Pre-allocate output tensors - filtered_logits = torch.empty_like(logits) - keep_mask = torch.empty( - num_tokens, vocab_size, dtype=torch.bool, device=logits.device - ) - - for start_idx in range(0, num_tokens, chunk_size): - end_idx = min(start_idx + chunk_size, num_tokens) - chunk_logits = logits[start_idx:end_idx, :] - - # Sort this chunk - logits_sort, logits_idx = chunk_logits.sort(dim=-1, descending=False) - top_k_keep_mask_chunk = None - - if need_top_k_filtering(top_k): - assert top_k is not None # Type narrowing - # Apply top-k first - top_k_index = logits_sort.size(-1) - top_k - index_tensor = torch.full( - logits_sort.shape[:-1], - top_k_index, - device=logits_sort.device, - dtype=torch.long, - ) - top_k_threshold = logits_sort.gather(-1, index_tensor.unsqueeze(-1)) - top_k_keep_mask_chunk = logits_sort >= top_k_threshold - logits_sort.masked_fill_(~top_k_keep_mask_chunk, -float("inf")) - - # Apply top-p - probs_sort = logits_sort.softmax(dim=-1) - probs_sum = torch.cumsum(probs_sort, dim=-1) - top_p_keep_mask_chunk = probs_sum > 1 - top_p - # at least one - top_p_keep_mask_chunk[..., -1] = True - logits_sort.masked_fill_(~top_p_keep_mask_chunk, -float("inf")) - - # Scatter back to original order - chunk_filtered = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) - if top_k_keep_mask_chunk is not None: - chunk_mask = torch.logical_and(top_k_keep_mask_chunk, top_p_keep_mask_chunk) - else: - chunk_mask = top_p_keep_mask_chunk - chunk_mask = chunk_mask.scatter(dim=-1, index=logits_idx, src=chunk_mask) - - # Store results - filtered_logits[start_idx:end_idx, :] = chunk_filtered - keep_mask[start_idx:end_idx, :] = chunk_mask - - # Restore original shape - filtered_logits = filtered_logits.view(original_shape) - keep_mask = keep_mask.view(original_shape) - - return filtered_logits, keep_mask - - -class ApplyTopKTopP(torch.autograd.Function): - """Autograd function for top-k and top-p filtering with proper gradient handling.""" - - @staticmethod - def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class - ctx, - logits: torch.Tensor, - top_k: Optional[int], - top_p: float, - chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - """Apply top-k/top-p filtering and save masks for backward. - - Args: - logits: Input logits tensor of shape [*, vocab_size]. - top_k: Top-k sampling parameter. Set to -1 or None to consider all tokens. - top_p: Top-p sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens. - chunk_size: Number of tokens to process per chunk. Defaults to TOP_K_TOP_P_CHUNK_SIZE. - """ - filtered_logits, keep_mask = _apply_top_k_top_p_fn( - logits, top_k, top_p, chunk_size - ) - - # Save masks for backward pass - ctx.save_for_backward(keep_mask) - - return filtered_logits, keep_mask - - @staticmethod - def backward(ctx, *grad_outputs: torch.Tensor): - """Backward pass: mask out gradients for filtered tokens.""" - grad_filtered_logits = grad_outputs[0] - (keep_mask,) = ctx.saved_tensors - - # Apply masks to gradients - masked out tokens should not receive gradients - if keep_mask is not None: - grad_filtered_logits = grad_filtered_logits.masked_fill(~keep_mask, 0.0) - - return grad_filtered_logits, None, None, None - - -def apply_top_k_top_p( - logits: torch.Tensor, - top_k: int | None, - top_p: float, - chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """Apply top-k and top-p masks to the logits with proper gradient handling. - - Simplified version of VLLM's implementation for scalar parameters. - - When top_p < 1.0, sorting is required which is memory intensive for large vocab sizes. - Processing is done in chunks (controlled by chunk_size) to reduce peak memory. - - Based on VLLM's implementation: - https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py - SPDX-License-Identifier: Apache-2.0 - Copyright contributors to the vLLM project - - Args: - logits: Input logits tensor of shape [*, vocab_size]. - top_k: Top-k sampling parameter. Set to -1 to consider all tokens. - top_p: Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens. - chunk_size: Number of tokens to process per chunk. Defaults to TOP_K_TOP_P_CHUNK_SIZE. - - Returns: - filtered_logits: Filtered logits tensor with the same shape as input logits. - keep_mask: Mask tensor with the same shape as input logits, where 1 (True) indicates tokens to be - kept, 0 (False) indicates tokens to be masked. - """ - if not need_top_k_or_top_p_filtering(top_k, top_p): - return logits, None - return ApplyTopKTopP.apply(logits, top_k, top_p, chunk_size) - - def resolve_model_class(model_name: str) -> Any: """Resolve the appropriate model class for a given model name.""" if NEMO_AUTOMODEL_AVAILABLE: diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index 08a7eaa83a..df563456ef 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -46,6 +46,12 @@ ) from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM +from nemo_rl.algorithms.logits_sampling_utils import ( + TrainingSamplingParams, + apply_top_k_top_p, + need_top_k_filtering, + need_top_p_filtering, +) from nemo_rl.algorithms.loss import SequencePackingLossWrapper, prepare_loss_input from nemo_rl.algorithms.loss.interfaces import LossFunction, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -71,12 +77,8 @@ ScoreOutputSpec, ) from nemo_rl.models.policy.utils import ( - TrainingSamplingParams, - apply_top_k_top_p, configure_dynamo_cache, get_runtime_env_for_policy_worker, - need_top_k_filtering, - need_top_p_filtering, resolve_model_class, ) from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index ed5d8e4acf..261682f873 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -39,6 +39,11 @@ AutoTokenizer, ) +from nemo_rl.algorithms.logits_sampling_utils import ( + TrainingSamplingParams, + apply_top_k_top_p, + need_top_k_or_top_p_filtering, +) from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.automodel.data import ( @@ -67,12 +72,7 @@ LogprobOutputSpec, ScoreOutputSpec, ) -from nemo_rl.models.policy.utils import ( - TrainingSamplingParams, - apply_top_k_top_p, - get_runtime_env_for_policy_worker, - need_top_k_or_top_p_filtering, -) +from nemo_rl.models.policy.utils import get_runtime_env_for_policy_worker from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker from nemo_rl.models.policy.workers.patches import ( apply_torch_aten_alias_tensor_patch, diff --git a/tests/unit/distributed/test_model_utils.py b/tests/unit/distributed/test_model_utils.py index 3b7f7ecc8f..aad98d1da4 100644 --- a/tests/unit/distributed/test_model_utils.py +++ b/tests/unit/distributed/test_model_utils.py @@ -17,6 +17,7 @@ import ray import torch +from nemo_rl.algorithms.logits_sampling_utils import apply_top_k_top_p from nemo_rl.distributed.model_utils import ( ChunkedDistributedGatherLogprob, ChunkedDistributedLogprob, @@ -36,9 +37,6 @@ ) from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup -from nemo_rl.models.policy.utils import ( - apply_top_k_top_p, -) @ray.remote(num_gpus=1) diff --git a/tests/unit/models/generation/test_vllm_logprobs_mode.py b/tests/unit/models/generation/test_vllm_logprobs_mode.py index fb3bd591c4..1cf0310157 100644 --- a/tests/unit/models/generation/test_vllm_logprobs_mode.py +++ b/tests/unit/models/generation/test_vllm_logprobs_mode.py @@ -59,7 +59,7 @@ def test_processed_logprobs_matches_manual_computation(): from transformers import AutoModelForCausalLM, AutoTokenizer from vllm import LLM, SamplingParams - from nemo_rl.models.policy.utils import apply_top_k_top_p + from nemo_rl.algorithms.logits_sampling_utils import apply_top_k_top_p # Use a small model for fast testing model_name = "facebook/opt-125m" @@ -243,7 +243,7 @@ def test_apply_top_k_top_p_matches_vllm_upstream(top_k, top_p, test_name): apply_top_k_top_p as vllm_apply_top_k_top_p, ) - from nemo_rl.models.policy.utils import apply_top_k_top_p + from nemo_rl.algorithms.logits_sampling_utils import apply_top_k_top_p # Test configuration batch_size = 4 From 961f946d1249718b969f8cdbe9243e8daa0b3d97 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 3 Mar 2026 00:21:23 -0800 Subject: [PATCH 04/20] move TrainingSamplingParams to setup Signed-off-by: Yuki Huang --- nemo_rl/models/automodel/config.py | 5 +++++ nemo_rl/models/automodel/setup.py | 12 ++++++++++++ .../policy/workers/dtensor_policy_worker_v2.py | 1 + 3 files changed, 18 insertions(+) diff --git a/nemo_rl/models/automodel/config.py b/nemo_rl/models/automodel/config.py index 0e434268b7..37826e95b7 100644 --- a/nemo_rl/models/automodel/config.py +++ b/nemo_rl/models/automodel/config.py @@ -19,6 +19,8 @@ import torch from nemo_automodel.components._peft.lora import PeftConfig +from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams + class RuntimeConfig(NamedTuple): """Runtime configuration for model training and inference. @@ -48,6 +50,9 @@ class RuntimeConfig(NamedTuple): # Generation configuration is_generation_colocated: Optional[bool] + # Sampling parameters + sampling_params: Optional[TrainingSamplingParams] + # Reward model flag is_reward_model: bool diff --git a/nemo_rl/models/automodel/setup.py b/nemo_rl/models/automodel/setup.py index 4347fca5a3..a5c50433db 100644 --- a/nemo_rl/models/automodel/setup.py +++ b/nemo_rl/models/automodel/setup.py @@ -36,6 +36,7 @@ from transformers import AutoConfig, AutoProcessor, AutoTokenizer, PreTrainedModel from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM +from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams from nemo_rl.models.automodel.config import ModelAndOptimizerState, RuntimeConfig from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.utils import configure_dynamo_cache, resolve_model_class @@ -90,6 +91,16 @@ def validate_and_prepare_config( raise ValueError(f"Unknown precision: {precision}") dtype = STRING_TO_DTYPE[precision] + # Sampling parameters configuration + sampling_params = None + if "generation" in config and config["generation"] is not None: + generation_cfg = config["generation"] + sampling_params = TrainingSamplingParams( + top_k=generation_cfg.get("top_k", None), + top_p=generation_cfg.get("top_p", 1.0), + temperature=generation_cfg.get("temperature", 1.0), + ) + # Get other configuration values cpu_offload = config["dtensor_cfg"]["cpu_offload"] offload_optimizer_for_logprob = config.get("offload_optimizer_for_logprob", False) @@ -193,6 +204,7 @@ def validate_and_prepare_config( cpu_offload=cpu_offload, offload_optimizer_for_logprob=offload_optimizer_for_logprob, is_generation_colocated=is_generation_colocated, + sampling_params=sampling_params, is_reward_model=is_reward_model, ) diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 261682f873..a8d951338d 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -333,6 +333,7 @@ def __init__( self.cpu_offload, self.offload_optimizer_for_logprob, self.is_generation_colocated, + self.sampling_params, _runtime_is_reward_model, # Duplicate, already set as _is_reward_model ) = runtime_config From e4658a94b14fd3f3819d5eb62ea63f2313dfcc74 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 3 Mar 2026 01:12:47 -0800 Subject: [PATCH 05/20] add topp topk in LogprobsPostProcessor Signed-off-by: Yuki Huang --- nemo_rl/algorithms/loss/loss_functions.py | 2 +- nemo_rl/models/automodel/train.py | 60 ++++++++++++++++--- nemo_rl/models/megatron/train.py | 3 + .../workers/dtensor_policy_worker_v2.py | 41 ++----------- 4 files changed, 60 insertions(+), 46 deletions(-) diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index 18504dbaa9..7244d868c7 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -278,7 +278,7 @@ def __call__( if self.reference_policy_kl_penalty != 0: # When top-k/top-p filtering is enabled, we need special handling for KL: # - reference_policy_logprobs is computed **without** filtering (see use_reference_model) - # - curr_logprobs is computed **with** filtering (for actor loss compatibility) + # - curr_logprobs/prev_logprobs are computed **with** filtering (for actor loss compatibility) # - For KL, we need curr_logprobs **without** filtering to be consistent with ref logprobs # - For importance weights, we also use unfiltered curr_logprobs_for_kl since we're # reweighting samples from π_gen_filtered to π_curr_unfiltered diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index 5bee096798..4ecf3ae4f5 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -32,7 +32,11 @@ from torch import nn from torch.distributed.tensor import DTensor, Shard -from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams +from nemo_rl.algorithms.logits_sampling_utils import ( + TrainingSamplingParams, + apply_top_k_top_p, + need_top_k_or_top_p_filtering, +) from nemo_rl.algorithms.loss import SequencePackingLossWrapper, prepare_loss_input from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -124,20 +128,42 @@ def extract_logits( def apply_temperature_scaling( - logits: torch.Tensor, - cfg: PolicyConfig, + logits: torch.Tensor, sampling_params: Optional[TrainingSamplingParams] ) -> torch.Tensor: """Apply temperature scaling to logits. Args: logits: Logits tensor to scale - cfg: Configuration dictionary containing generation settings + sampling_params: Sampling parameters Returns: torch.Tensor: Temperature-scaled logits """ - if "generation" in cfg and cfg["generation"] is not None: - logits.div_(cfg["generation"]["temperature"]) + if sampling_params is not None and sampling_params.temperature != 1.0: + logits.div_(sampling_params.temperature) + return logits + + +def apply_top_k_top_p_filtering_for_local_logits( + logits: torch.Tensor, sampling_params: Optional[TrainingSamplingParams] +) -> torch.Tensor: + """Apply top-k and top-p filtering to the non-distributed logits. + + Args: + logits: Logits tensor to filter + sampling_params: Sampling parameters + + Returns: + torch.Tensor: Filtered logits + """ + if sampling_params is not None and need_top_k_or_top_p_filtering( + sampling_params.top_k, sampling_params.top_p + ): + logits, _ = apply_top_k_top_p( + logits, + top_k=sampling_params.top_k, + top_p=sampling_params.top_p, + ) return logits @@ -233,7 +259,7 @@ def prepare_data_for_cp( def forward_with_post_processing_fn( model: nn.Module, - cfg: PolicyConfig, + sampling_params: TrainingSamplingParams, post_processing_fn: PostProcessingFunction, processed_mb: ProcessedMicrobatch, is_reward_model: bool = False, @@ -253,7 +279,7 @@ def forward_with_post_processing_fn( Args: model: The model to run forward pass on - cfg: Configuration dictionary + sampling_params: Sampling parameters post_processing_fn: Post-processing function to apply to the logits processed_mb: Pre-fetched ProcessedMicrobatch containing data and processed inputs is_reward_model: Whether this is a reward model @@ -290,7 +316,10 @@ def forward_with_post_processing_fn( post_processing_fn, (LossPostProcessor, LogprobsPostProcessor, TopkLogitsPostProcessor), ): - logits = apply_temperature_scaling(logits, cfg) + # Temperature scaling is element-wise, directly applying it here. + # Other sampling parameters like top-k and top-p need the logits from whole vocabulary, + # so applying them when gathering logits from vocab parallel (called in LossPostProcessor and LogprobsPostProcessor). + logits = apply_temperature_scaling(logits, sampling_params) # Apply the post-processing function directly based on type if isinstance(post_processing_fn, LossPostProcessor): @@ -558,6 +587,7 @@ def __init__( tp_mesh: Any, cp_size: int, enable_seq_packing: bool = False, + sampling_params: Optional[TrainingSamplingParams] = None, ): """Initialize LogprobsPostProcessor. @@ -568,6 +598,7 @@ def __init__( tp_mesh: Tensor parallel mesh cp_size: Context parallel size enable_seq_packing: Whether sequence packing is enabled + sampling_params: Sampling parameters """ self.cfg = cfg self.device_mesh = device_mesh @@ -575,6 +606,7 @@ def __init__( self.tp_mesh = tp_mesh self.cp_size = cp_size self.enable_seq_packing = enable_seq_packing + self.sampling_params = sampling_params self.logprob_chunk_size = cfg.get("logprob_chunk_size", None) def __call__( @@ -627,17 +659,21 @@ def __call__( input_ids_dtensor, seq_index_tensor, chunk_size=self.logprob_chunk_size, + sampling_params=self.sampling_params, # top-k and top-p filtering ) assert token_logprobs.shape[1] == seq_len - 1 else: if isinstance(logits, DTensor): + # DTensor path with TP sharding token_logprobs = get_logprobs_from_vocab_parallel_logits( logits, processed_inputs.input_ids, chunk_size=self.logprob_chunk_size, + sampling_params=self.sampling_params, # top-k and top-p filtering ) else: + # Non-DTensor path (no TP sharding) token_logprobs = self._compute_local_logprobs( logits, processed_inputs.input_ids ) @@ -703,12 +739,18 @@ def _compute_local_logprobs( (chunk_idx + 1) * self.logprob_chunk_size, ) chunk_logits = logits[:, chunk_start:chunk_end, :].to(torch.float32) + chunk_logits = apply_top_k_top_p_filtering_for_local_logits( + chunk_logits, self.sampling_params + ) log_probs = torch.nn.functional.log_softmax(chunk_logits, dim=-1) chunked_log_probs.append(log_probs) log_probs = torch.cat(chunked_log_probs, dim=1) del chunked_log_probs else: logits = logits.to(torch.float32) + logits = apply_top_k_top_p_filtering_for_local_logits( + logits, self.sampling_params + ) log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # Extract logprobs for each token in the sequence by gathering the logprob diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py index 5535c9025c..19e02bde27 100644 --- a/nemo_rl/models/megatron/train.py +++ b/nemo_rl/models/megatron/train.py @@ -184,6 +184,9 @@ def forward_with_post_processing_fn( post_processing_fn, (LossPostProcessor, LogprobsPostProcessor, TopkLogitsPostProcessor), ): + # Temperature scaling is element-wise, directly applying it here. + # Other sampling parameters like top-k and top-p need the logits from whole vocabulary, + # so applying them when gathering logits from vocab parallel (called in LossPostProcessor and LogprobsPostProcessor). apply_temperature_scaling(output_tensor, cfg) # Use type checking to dispatch to the correct post-processing method diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index a8d951338d..280ad12dca 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -39,11 +39,7 @@ AutoTokenizer, ) -from nemo_rl.algorithms.logits_sampling_utils import ( - TrainingSamplingParams, - apply_top_k_top_p, - need_top_k_or_top_p_filtering, -) +from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.automodel.data import ( @@ -242,16 +238,6 @@ def __init__( # Initialize checkpoint manager self.checkpoint_manager: Optional[AutomodelCheckpointManager] = None - if "generation" in self.cfg and self.cfg["generation"] is not None: - generation_cfg = self.cfg["generation"] - self.sampling_params = TrainingSamplingParams( - top_k=generation_cfg.get("top_k", None), - top_p=generation_cfg.get("top_p", 1.0), - temperature=generation_cfg.get("temperature", 1.0), - ) - else: - self.sampling_params = None - # Validate configuration and prepare runtime settings runtime_config = validate_and_prepare_config( config=config, @@ -337,24 +323,6 @@ def __init__( _runtime_is_reward_model, # Duplicate, already set as _is_reward_model ) = runtime_config - def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor: - if self.sampling_params is not None and self.sampling_params.temperature != 1.0: - logits.div_(self.sampling_params.temperature) - return logits - - def _apply_top_k_top_p_filtering(self, logits: torch.Tensor) -> torch.Tensor: - """Apply top-k and top-p filtering to the logits locally when TP is disabled.""" - sampling_params = self.sampling_params - if sampling_params is not None and need_top_k_or_top_p_filtering( - sampling_params.top_k, sampling_params.top_p - ): - logits, _ = apply_top_k_top_p( - logits, - top_k=sampling_params.top_k, - top_p=sampling_params.top_p, - ) - return logits - @wrap_with_nvtx_name("dtensor_policy_worker_v2/train") def train( self, @@ -574,6 +542,7 @@ def get_logprobs( tp_mesh=self.tp_mesh, cp_size=self.cp_size, enable_seq_packing=self.enable_seq_packing, + sampling_params=self.sampling_params, ) with torch.no_grad(): @@ -602,7 +571,7 @@ def get_logprobs( # Use forward_with_post_processing_fn for forward pass and post-processing token_logprobs, _metrics, _ = forward_with_post_processing_fn( model=self.model, - cfg=self.cfg, + sampling_params=self.sampling_params, post_processing_fn=logprobs_post_processor, processed_mb=processed_mb, is_reward_model=False, @@ -671,7 +640,7 @@ def score(self, data: BatchedDataDict) -> BatchedDataDict[ScoreOutputSpec]: # Use forward_with_post_processing_fn for forward pass and post-processing rm_scores, _metrics, _ = forward_with_post_processing_fn( model=self.model, - cfg=self.cfg, + sampling_params=self.sampling_params, post_processing_fn=score_post_processor, processed_mb=processed_mb, is_reward_model=True, @@ -761,7 +730,7 @@ def get_topk_logits( # Use forward_with_post_processing_fn for forward pass and post-processing (vals, idx), _metrics, _ = forward_with_post_processing_fn( model=self.model, - cfg=self.cfg, + sampling_params=self.sampling_params, post_processing_fn=topk_post_processor, processed_mb=processed_mb, is_reward_model=False, From 83899a71bf138b85304b8e184e6f0e1c3504f400 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 3 Mar 2026 01:53:30 -0800 Subject: [PATCH 06/20] update megatron Signed-off-by: Yuki Huang --- nemo_rl/models/automodel/train.py | 9 ++-- nemo_rl/models/megatron/config.py | 3 ++ nemo_rl/models/megatron/setup.py | 13 +++++- nemo_rl/models/megatron/train.py | 40 ++++++++++-------- .../workers/dtensor_policy_worker_v2.py | 17 +++++--- .../policy/workers/megatron_policy_worker.py | 41 +++++++++++++++---- 6 files changed, 89 insertions(+), 34 deletions(-) diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index 4ecf3ae4f5..42fb1d1203 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -547,14 +547,15 @@ def __call__( logits, self.device_mesh, self.cp_mesh, sequence_dim ) - # Wrap loss function for sequence packing if needed - wrapped_prepare_loss_input = partial( + # Wrap prepare_loss_input with sampling_params + prepare_loss_input_wrapped = partial( prepare_loss_input, sampling_params=self.sampling_params ) + # Wrap loss function for sequence packing if needed if self.enable_seq_packing: loss_fn = SequencePackingLossWrapper( loss_fn=self.loss_fn, - prepare_fn=wrapped_prepare_loss_input, + prepare_fn=prepare_loss_input_wrapped, cu_seqlens_q=processed_inputs.flash_attn_kwargs.cu_seqlens_q, cu_seqlens_q_padded=processed_inputs.flash_attn_kwargs.cu_seqlens_q, ) @@ -565,7 +566,7 @@ def __call__( global_valid_toks, ) else: - loss_input, mb = wrapped_prepare_loss_input(logits, mb, self.loss_fn) + loss_input, mb = prepare_loss_input_wrapped(logits, mb, self.loss_fn) loss, loss_metrics = self.loss_fn( data=mb, global_valid_seqs=global_valid_seqs, diff --git a/nemo_rl/models/megatron/config.py b/nemo_rl/models/megatron/config.py index 5838e9d430..7a65ed1924 100644 --- a/nemo_rl/models/megatron/config.py +++ b/nemo_rl/models/megatron/config.py @@ -21,6 +21,8 @@ from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler from megatron.core.transformer import MegatronModule +from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams + class MegatronGenerationConfig(TypedDict): # Total GPU memory (in GB) allocated for KV cache buffers @@ -55,6 +57,7 @@ class RuntimeConfig(NamedTuple): optimizer_cpu_offload: bool offload_optimizer_for_logprob: bool is_generation_colocated: Optional[bool] + sampling_params: Optional[TrainingSamplingParams] final_padded_vocab_size: int diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py index 485dcd733b..7c0c4fa25a 100644 --- a/nemo_rl/models/megatron/setup.py +++ b/nemo_rl/models/megatron/setup.py @@ -66,6 +66,7 @@ except ImportError: HAVE_FSDP2 = False +from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams from nemo_rl.distributed.named_sharding import NamedSharding from nemo_rl.models.megatron.community_import import import_model_from_hf_name from nemo_rl.models.megatron.config import ModelAndOptimizerState, RuntimeConfig @@ -194,7 +195,6 @@ def validate_and_set_config( hf_model_name, pretrained_path, weights_path, - tokenizer, ): # Handle generation colocation is_generation_colocated = None @@ -218,6 +218,16 @@ def validate_and_set_config( optimizer_cpu_offload = config["megatron_cfg"]["optimizer"]["optimizer_cpu_offload"] offload_optimizer_for_logprob = config["offload_optimizer_for_logprob"] + # Sampling parameters configuration + sampling_params = None + if "generation" in config and config["generation"] is not None: + generation_cfg = config["generation"] + sampling_params = TrainingSamplingParams( + top_k=generation_cfg.get("top_k", None), + top_p=generation_cfg.get("top_p", 1.0), + temperature=generation_cfg.get("temperature", 1.0), + ) + # Reward models are not yet supported with Megatron. if "reward_model_cfg" in config and config["reward_model_cfg"]["enabled"]: raise NotImplementedError( @@ -242,6 +252,7 @@ def validate_and_set_config( optimizer_cpu_offload, offload_optimizer_for_logprob, is_generation_colocated, + sampling_params, final_padded_vocab_size, ) diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py index 19e02bde27..618b7803f7 100644 --- a/nemo_rl/models/megatron/train.py +++ b/nemo_rl/models/megatron/train.py @@ -29,6 +29,7 @@ from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.utils import StragglerDetector +from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams from nemo_rl.algorithms.loss import ( SequencePackingLossWrapper, prepare_loss_input, @@ -56,7 +57,6 @@ def model_forward( model: GPTModel, data_dict: BatchedDataDict[Any], - cfg: PolicyConfig, input_ids_cp_sharded: torch.Tensor, position_ids: torch.Tensor, attention_mask: torch.Tensor, @@ -106,27 +106,26 @@ def model_forward( def apply_temperature_scaling( - logits: torch.Tensor, - cfg: PolicyConfig, + logits: torch.Tensor, sampling_params: Optional[TrainingSamplingParams] ) -> torch.Tensor: """Apply temperature scaling to logits. Args: logits: Logits tensor to scale - cfg: Policy configuration containing generation settings + sampling_params: Sampling parameters Returns: torch.Tensor: Temperature-scaled logits """ - if "generation" in cfg and cfg["generation"] is not None: - logits.div_(cfg["generation"]["temperature"]) + if sampling_params is not None and sampling_params.temperature != 1.0: + logits.div_(sampling_params.temperature) return logits def forward_with_post_processing_fn( data_iterator: Iterator[ProcessedMicrobatch], model: GPTModel, - cfg: PolicyConfig, + sampling_params: TrainingSamplingParams, post_processing_fn: PostProcessingFunction, defer_fp32_logits: Optional[bool] = False, global_valid_seqs: Optional[torch.Tensor] = None, @@ -142,7 +141,7 @@ def forward_with_post_processing_fn( Args: data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed) model: The model to run forward pass on - cfg: Policy configuration dictionary + sampling_params: Sampling parameters post_processing_fn: Post-processing function to post-process the logits defer_fp32_logits: Whether to defer FP32 conversion of logits global_valid_seqs: Global valid sequence count for loss normalization @@ -169,7 +168,6 @@ def forward_with_post_processing_fn( output_tensor = model_forward( model=model, data_dict=data_dict, - cfg=cfg, input_ids_cp_sharded=input_ids_cp_sharded, position_ids=position_ids, attention_mask=attention_mask, @@ -187,7 +185,7 @@ def forward_with_post_processing_fn( # Temperature scaling is element-wise, directly applying it here. # Other sampling parameters like top-k and top-p need the logits from whole vocabulary, # so applying them when gathering logits from vocab parallel (called in LossPostProcessor and LogprobsPostProcessor). - apply_temperature_scaling(output_tensor, cfg) + apply_temperature_scaling(output_tensor, sampling_params) # Use type checking to dispatch to the correct post-processing method if isinstance(post_processing_fn, LossPostProcessor): @@ -218,7 +216,7 @@ def forward_with_post_processing_fn( def megatron_forward_backward( model: GPTModel, - cfg: PolicyConfig, + sampling_params: TrainingSamplingParams, data_iterator: Iterator[ProcessedMicrobatch], num_microbatches: int, seq_length: int, @@ -238,7 +236,7 @@ def megatron_forward_backward( Args: model: The model to train - cfg: Policy configuration dictionary + sampling_params: Sampling parameters data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed) num_microbatches: Number of microbatches to process seq_length: Sequence length @@ -255,7 +253,7 @@ def megatron_forward_backward( """ forward_step = partial( forward_with_post_processing_fn, - cfg=cfg, + sampling_params=sampling_params, post_processing_fn=post_processing_fn, defer_fp32_logits=defer_fp32_logits, global_valid_seqs=global_valid_seqs, @@ -282,11 +280,13 @@ def __init__( cfg: PolicyConfig, num_microbatches: int = 1, cp_normalize: bool = True, + sampling_params: Optional[TrainingSamplingParams] = None, ): self.loss_fn = loss_fn self.cfg = cfg self.num_microbatches = num_microbatches self.cp_normalize = cp_normalize + self.sampling_params = sampling_params def __call__( self, @@ -310,12 +310,17 @@ def __call__( Returns: Callable: Function that takes output tensor and returns (loss, metrics) tuple """ + # wrap prepare_loss_input with sampling_params + prepare_loss_input_wrapped = partial( + prepare_loss_input, sampling_params=self.sampling_params + ) + # wrap loss function with loss input preparation pack_sequences = self.cfg["sequence_packing"]["enabled"] if pack_sequences and packed_seq_params is not None: loss_fn_wrapped = SequencePackingLossWrapper( loss_fn=self.loss_fn, - prepare_fn=prepare_loss_input, + prepare_fn=prepare_loss_input_wrapped, cu_seqlens_q=packed_seq_params.cu_seqlens_q, cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded, vocab_parallel_rank=get_tensor_model_parallel_rank(), @@ -326,7 +331,7 @@ def __call__( loss_fn_wrapped = partial( wrap_loss_fn_with_input_preparation, loss_fn=self.loss_fn, - prepare_fn=prepare_loss_input, + prepare_fn=prepare_loss_input_wrapped, vocab_parallel_rank=get_tensor_model_parallel_rank(), vocab_parallel_group=get_tensor_model_parallel_group(), context_parallel_group=get_context_parallel_group(), @@ -365,8 +370,9 @@ def _counteract_mcore_loss_averaging(*args, **kwargs): class LogprobsPostProcessor: - def __init__(self, cfg: PolicyConfig): + def __init__(self, cfg: PolicyConfig, sampling_params: TrainingSamplingParams): self.cfg = cfg + self.sampling_params = sampling_params def __call__( self, @@ -406,6 +412,7 @@ def processor_fn_inner(output_tensor): inference_only=True, cp_group=get_context_parallel_group(), chunk_size=logprob_chunk_size, + sampling_params=self.sampling_params, ) else: token_logprobs = from_parallel_logits_to_logprobs( @@ -416,6 +423,7 @@ def processor_fn_inner(output_tensor): tp_group=tp_grp, inference_only=True, chunk_size=logprob_chunk_size, + sampling_params=self.sampling_params, ) # Prepend 0 logprob for first token to maintain same sequence length as input diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 280ad12dca..4ccc1c378f 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -781,8 +781,10 @@ def get_topk_logits( def use_reference_model(self) -> Generator[None, None, None]: """Context manager that temporarily swaps the reference model and active model. - On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references - On exit: Restores original references and re-flips cuda/cpu + On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references. + Also disables top-k/top-p filtering since the reference policy's distribution + is different from the current policy, making filtered logprobs incompatible. + On exit: Restores original references and re-flips cuda/cpu, restores sampling_params. """ with torch.no_grad(): try: @@ -796,10 +798,11 @@ def use_reference_model(self) -> Generator[None, None, None]: val = to_local_if_dtensor(v) val.copy_(self.reference_model_state_dict[k]) - # - self.model is the original reference_model, now on CUDA - # - curr_state_dict is the train model, now on CPU - - # Save and adjust sampling_params for reference model + # Temporarily disable top-k/top-p filtering for reference policy logprobs. + # The reference policy has different weights, so its top-k/top-p set is + # inherently different from the current policy. Using filtered logprobs + # would cause -inf mismatches that cannot be resolved by masking. + # Note: We keep temperature scaling since it was applied to prev_logprobs. saved_sampling_params = self.sampling_params if saved_sampling_params is not None: self.sampling_params = TrainingSamplingParams( @@ -810,6 +813,8 @@ def use_reference_model(self) -> Generator[None, None, None]: else: self.sampling_params = None + # - self.model is the original reference_model, now on CUDA + # - curr_state_dict is the train model, now on CPU yield finally: diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index b640290f2f..63958e4554 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -48,6 +48,7 @@ from megatron.core.rerun_state_machine import get_rerun_state_machine from transformers import PreTrainedTokenizerBase +from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.named_sharding import NamedSharding @@ -157,7 +158,6 @@ def __init__( hf_model_name, pretrained_path, weights_path, - tokenizer, ) self.megatron_cfg = runtime_config.megatron_cfg @@ -167,6 +167,7 @@ def __init__( runtime_config.offload_optimizer_for_logprob ) self.is_generation_colocated = runtime_config.is_generation_colocated + self.sampling_params = runtime_config.sampling_params self.final_padded_vocab_size = runtime_config.final_padded_vocab_size self.defer_fp32_logits = self.cfg["megatron_cfg"].get( @@ -317,6 +318,7 @@ def train( loss_fn=loss_fn, cfg=self.cfg, num_microbatches=num_microbatches, + sampling_params=self.sampling_params, ) rerun_state_machine = get_rerun_state_machine() @@ -328,7 +330,7 @@ def train( # Forward pass. losses_reduced = megatron_forward_backward( model=self.model, - cfg=self.cfg, + sampling_params=self.sampling_params, data_iterator=data_iterator, num_microbatches=num_microbatches, seq_length=padded_seq_length, @@ -486,14 +488,19 @@ def get_logprobs( straggler_timer=self.mcore_state.straggler_timer, ) + logprobs_post_processor = LogprobsPostProcessor( + cfg=self.cfg, + sampling_params=self.sampling_params, + ) + list_of_logprobs = megatron_forward_backward( model=self.model, - cfg=self.cfg, + sampling_params=self.sampling_params, data_iterator=mb_iterator, seq_length=padded_seq_length, mbs=micro_batch_size, num_microbatches=num_microbatches, - post_processing_fn=LogprobsPostProcessor(cfg=self.cfg), + post_processing_fn=logprobs_post_processor, forward_only=True, defer_fp32_logits=self.defer_fp32_logits, straggler_timer=self.mcore_state.straggler_timer, @@ -523,8 +530,10 @@ def get_logprobs( def use_reference_model(self): """Context manager that temporarily swaps the reference model and active model. - On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references - On exit: Restores original references and re-flips cuda/cpu + On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references. + Also disables top-k/top-p filtering since the reference policy's distribution + is different from the current policy, making filtered logprobs incompatible. + On exit: Restores original references and re-flips cuda/cpu, restores sampling_params. """ ## disable overlap param gather when swapping weights if self.should_disable_forward_pre_hook: @@ -550,11 +559,29 @@ def use_reference_model(self): gc.collect() torch.cuda.empty_cache() + # Temporarily disable top-k/top-p filtering for reference policy logprobs. + # The reference policy has different weights, so its top-k/top-p set is + # inherently different from the current policy. Using filtered logprobs + # would cause -inf mismatches that cannot be resolved by masking. + # Note: We keep temperature scaling since it was applied to prev_logprobs. + saved_sampling_params = self.sampling_params + if saved_sampling_params is not None: + self.sampling_params = TrainingSamplingParams( + top_k=None, + top_p=1.0, + temperature=saved_sampling_params.temperature, + ) + else: + self.sampling_params = None + # - self.model is the original reference_model, now on CUDA # - self.reference_model is the original model, now on CPU yield finally: + # Restore sampling_params + self.sampling_params = saved_sampling_params + # Restore original references and device placement for k, v in self.model.state_dict().items(): if isinstance(v, torch.Tensor): @@ -613,7 +640,7 @@ def get_topk_logits( list_of_outputs = megatron_forward_backward( model=self.model, - cfg=self.cfg, + sampling_params=self.sampling_params, data_iterator=mb_iterator, seq_length=padded_seq_length, mbs=micro_batch_size, From cdaa31a0cadab94962b0dae60cd201d6da7d79b1 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 3 Mar 2026 02:27:08 -0800 Subject: [PATCH 07/20] update need_top_k_or_top_p_filtering Signed-off-by: Yuki Huang --- nemo_rl/algorithms/logits_sampling_utils.py | 51 +++++++++++-------- nemo_rl/algorithms/loss/utils.py | 4 +- nemo_rl/distributed/model_utils.py | 21 ++------ nemo_rl/models/automodel/train.py | 4 +- .../policy/workers/dtensor_policy_worker.py | 8 +-- pyrefly.toml | 1 + 6 files changed, 37 insertions(+), 52 deletions(-) diff --git a/nemo_rl/algorithms/logits_sampling_utils.py b/nemo_rl/algorithms/logits_sampling_utils.py index 54e9ceb992..db6fd43598 100644 --- a/nemo_rl/algorithms/logits_sampling_utils.py +++ b/nemo_rl/algorithms/logits_sampling_utils.py @@ -26,21 +26,6 @@ TOP_K_TOP_P_CHUNK_SIZE: int = 256 -def need_top_k_filtering(top_k: int | None) -> bool: - """Check if top-k filtering is needed.""" - return top_k is not None and top_k != -1 - - -def need_top_p_filtering(top_p: float | None) -> bool: - """Check if top-p filtering is needed.""" - return top_p is not None and top_p != 1.0 - - -def need_top_k_or_top_p_filtering(top_k: int | None, top_p: float | None) -> bool: - """Check if top-k or top-p filtering is needed.""" - return need_top_k_filtering(top_k) or need_top_p_filtering(top_p) - - @dataclass class TrainingSamplingParams: """Training-specific sampling parameters to match generation parameters. @@ -59,6 +44,28 @@ class TrainingSamplingParams: temperature: float = 1.0 +def _need_top_k_filtering(top_k: int | None) -> bool: + """Check if top-k filtering is needed.""" + return top_k is not None and top_k != -1 + + +def _need_top_p_filtering(top_p: float | None) -> bool: + """Check if top-p filtering is needed.""" + return top_p is not None and top_p != 1.0 + + +def need_top_k_or_top_p_filtering( + sampling_params: Optional[TrainingSamplingParams], +) -> bool: + """Check if top-k or top-p filtering is needed.""" + if sampling_params is None: + return False + + top_k = sampling_params.top_k + top_p = sampling_params.top_p + return _need_top_k_filtering(top_k) or _need_top_p_filtering(top_p) + + @torch.no_grad() def _apply_top_k_only_fn( logits: torch.Tensor, @@ -83,7 +90,7 @@ def _apply_top_k_only_fn( keep_mask: Mask tensor with the same shape as input logits, where 1 (True) indicates tokens to be kept, 0 (False) indicates tokens to be masked. None if top-k filtering is not needed. """ - if not need_top_k_filtering(top_k): + if not _need_top_k_filtering(top_k): return logits, None # Get top-k values and create mask @@ -133,8 +140,8 @@ def _apply_top_k_top_p_fn( keep_mask: Mask tensor with the same shape as input logits, where 1 (True) indicates tokens to be kept, 0 (False) indicates tokens to be masked. """ - if not need_top_p_filtering(top_p): - if not need_top_k_filtering(top_k): + if not _need_top_p_filtering(top_p): + if not _need_top_k_filtering(top_k): return logits, None # Avoid sorting vocab for top-k only case filtered_logits, top_k_keep_mask = _apply_top_k_only_fn(logits, top_k) @@ -162,7 +169,7 @@ def _apply_top_k_top_p_fn( logits_sort, logits_idx = chunk_logits.sort(dim=-1, descending=False) top_k_keep_mask_chunk = None - if need_top_k_filtering(top_k): + if _need_top_k_filtering(top_k): assert top_k is not None # Type narrowing # Apply top-k first top_k_index = logits_sort.size(-1) - top_k @@ -203,7 +210,7 @@ def _apply_top_k_top_p_fn( return filtered_logits, keep_mask -class ApplyTopKTopP(torch.autograd.Function): +class _ApplyTopKTopP(torch.autograd.Function): """Autograd function for top-k and top-p filtering with proper gradient handling.""" @staticmethod @@ -273,6 +280,6 @@ def apply_top_k_top_p( keep_mask: Mask tensor with the same shape as input logits, where 1 (True) indicates tokens to be kept, 0 (False) indicates tokens to be masked. """ - if not need_top_k_or_top_p_filtering(top_k, top_p): + if not _need_top_k_filtering(top_k) and not _need_top_p_filtering(top_p): return logits, None - return ApplyTopKTopP.apply(logits, top_k, top_p, chunk_size) + return _ApplyTopKTopP.apply(logits, top_k, top_p, chunk_size) diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py index 7d65440ed8..fbf12ed12d 100644 --- a/nemo_rl/algorithms/loss/utils.py +++ b/nemo_rl/algorithms/loss/utils.py @@ -102,9 +102,7 @@ def prepare_loss_input( ) # handle top-k/top-p filtering for logprobs, only used for ClippedPGLossFn now - if sampling_params is not None and need_top_k_or_top_p_filtering( - sampling_params.top_k, sampling_params.top_p - ): + if need_top_k_or_top_p_filtering(sampling_params): # mask out negative infinity logprobs mask = data["token_mask"] * data["sample_mask"].unsqueeze(-1) logprobs = _mask_out_neg_inf_logprobs( diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index c71de218a0..a5837e7753 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -781,12 +781,7 @@ def dtensor_from_parallel_logits_to_logprobs( else: target = target.roll(shifts=-1, dims=-1) - use_sampling = sampling_params is not None and need_top_k_or_top_p_filtering( - sampling_params.top_k, sampling_params.top_p - ) - - if use_sampling: - assert sampling_params is not None # Type narrowing + if need_top_k_or_top_p_filtering(sampling_params): if chunk_size is not None: logprobs: torch.Tensor = ChunkedDistributedLogprobWithSampling.apply( # type: ignore vocab_parallel_logits, @@ -882,12 +877,7 @@ def from_parallel_logits_to_logprobs( cp_rank = torch.distributed.get_rank(cp_group) target = _get_tokens_on_this_cp_rank(target, cp_rank, cp_size, seq_dim=1) - use_sampling = sampling_params is not None and need_top_k_or_top_p_filtering( - sampling_params.top_k, sampling_params.top_p - ) - - if use_sampling: - assert sampling_params is not None # Type narrowing + if need_top_k_or_top_p_filtering(sampling_params): if chunk_size is not None: logprobs: torch.Tensor = ChunkedDistributedLogprobWithSampling.apply( # type: ignore vocab_parallel_logits, @@ -1002,12 +992,7 @@ def from_parallel_logits_to_logprobs_packed_sequences( vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0) # Apply distributed log probability computation - use_sampling = sampling_params is not None and need_top_k_or_top_p_filtering( - sampling_params.top_k, sampling_params.top_p - ) - - if use_sampling: - assert sampling_params is not None # Type narrowing + if need_top_k_or_top_p_filtering(sampling_params): if chunk_size is not None: probs: torch.Tensor = ChunkedDistributedLogprobWithSampling.apply( # type: ignore vocab_parallel_logits, diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index 42fb1d1203..5fe8a600d7 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -156,9 +156,7 @@ def apply_top_k_top_p_filtering_for_local_logits( Returns: torch.Tensor: Filtered logits """ - if sampling_params is not None and need_top_k_or_top_p_filtering( - sampling_params.top_k, sampling_params.top_p - ): + if need_top_k_or_top_p_filtering(sampling_params): logits, _ = apply_top_k_top_p( logits, top_k=sampling_params.top_k, diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index df563456ef..9b6a823fd4 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -49,8 +49,7 @@ from nemo_rl.algorithms.logits_sampling_utils import ( TrainingSamplingParams, apply_top_k_top_p, - need_top_k_filtering, - need_top_p_filtering, + need_top_k_or_top_p_filtering, ) from nemo_rl.algorithms.loss import SequencePackingLossWrapper, prepare_loss_input from nemo_rl.algorithms.loss.interfaces import LossFunction, LossType @@ -499,10 +498,7 @@ def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor: def _apply_top_k_top_p_filtering(self, logits: torch.Tensor) -> torch.Tensor: """Apply top-k and top-p filtering to the logits locally when TP is disabled.""" - if self.sampling_params is not None and ( - need_top_k_filtering(self.sampling_params.top_k) - or need_top_p_filtering(self.sampling_params.top_p) - ): + if need_top_k_or_top_p_filtering(self.sampling_params): logits, _ = apply_top_k_top_p( logits, top_k=self.sampling_params.top_k, diff --git a/pyrefly.toml b/pyrefly.toml index 4c4ae33fa2..e43308f4fd 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -39,6 +39,7 @@ project-includes = [ "examples/custom_parallel/llama_nemotron_super_49b_custom_plan.py", "nemo_rl/algorithms/__init__.py", "nemo_rl/algorithms/advantage_estimator.py", + "nemo_rl/algorithms/logits_sampling_utils.py", "nemo_rl/algorithms/loss/__init__.py", "nemo_rl/algorithms/loss/interfaces.py", "nemo_rl/algorithms/loss/utils.py", From c40e49707e9033ca5931e9ba3d777e7da1a2052d Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 3 Mar 2026 08:22:37 -0800 Subject: [PATCH 08/20] fix test Signed-off-by: Yuki Huang --- nemo_rl/algorithms/loss/wrapper.py | 2 +- nemo_rl/models/automodel/setup.py | 21 +++--- nemo_rl/models/automodel/train.py | 6 +- nemo_rl/models/megatron/setup.py | 23 ++++--- .../policy/workers/dtensor_policy_worker.py | 13 ++-- .../workers/dtensor_policy_worker_v2.py | 2 +- .../sequence_packing_gradient_actor.py | 4 +- tests/unit/algorithms/test_loss_functions.py | 66 ++++++++++--------- tests/unit/test_recipes_and_test_suites.py | 6 +- 9 files changed, 72 insertions(+), 71 deletions(-) diff --git a/nemo_rl/algorithms/loss/wrapper.py b/nemo_rl/algorithms/loss/wrapper.py index 5e26d9b3e7..b669c494d3 100644 --- a/nemo_rl/algorithms/loss/wrapper.py +++ b/nemo_rl/algorithms/loss/wrapper.py @@ -165,7 +165,7 @@ def wrap_loss_fn_with_input_preparation( ) -> tuple[Tensor, dict[str, Any]]: """Wraps a loss function to handle input preparation for megatron policy worker.""" # prepare loss input - loss_input = prepare_fn( + loss_input, data = prepare_fn( logits=next_token_logits, data=data, loss_fn=loss_fn, diff --git a/nemo_rl/models/automodel/setup.py b/nemo_rl/models/automodel/setup.py index a5c50433db..f2c36f7e97 100644 --- a/nemo_rl/models/automodel/setup.py +++ b/nemo_rl/models/automodel/setup.py @@ -73,8 +73,17 @@ def validate_and_prepare_config( # Set basic configuration is_vlm = processor is not None is_generation_colocated = None + sampling_params = None if "generation" in config and config["generation"] is not None: - is_generation_colocated = config["generation"]["colocated"]["enabled"] + generation_cfg = config["generation"] + # set generation colocated + is_generation_colocated = generation_cfg["colocated"]["enabled"] + # set sampling params + sampling_params = TrainingSamplingParams( + top_k=generation_cfg.get("top_k", None), + top_p=generation_cfg.get("top_p", 1.0), + temperature=generation_cfg.get("temperature", 1.0), + ) # Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator. # See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details. @@ -91,16 +100,6 @@ def validate_and_prepare_config( raise ValueError(f"Unknown precision: {precision}") dtype = STRING_TO_DTYPE[precision] - # Sampling parameters configuration - sampling_params = None - if "generation" in config and config["generation"] is not None: - generation_cfg = config["generation"] - sampling_params = TrainingSamplingParams( - top_k=generation_cfg.get("top_k", None), - top_p=generation_cfg.get("top_p", 1.0), - temperature=generation_cfg.get("temperature", 1.0), - ) - # Get other configuration values cpu_offload = config["dtensor_cfg"]["cpu_offload"] offload_optimizer_for_logprob = config.get("offload_optimizer_for_logprob", False) diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index 5fe8a600d7..86b763b8e7 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -359,7 +359,7 @@ def forward_with_post_processing_fn( def automodel_forward_backward( model: nn.Module, - cfg: PolicyConfig, + sampling_params: TrainingSamplingParams, data_iterator: Iterator[ProcessedMicrobatch], post_processing_fn: PostProcessingFunction, forward_only: bool = False, @@ -385,7 +385,7 @@ def automodel_forward_backward( Args: model: The model to train - cfg: Configuration dictionary + sampling_params: Sampling parameters data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed) num_microbatches: Number of microbatches to process post_processing_fn: Post-processing function to apply to the logits @@ -430,7 +430,7 @@ def automodel_forward_backward( # Forward pass with post-processing result, metrics, _ = forward_with_post_processing_fn( model=model, - cfg=cfg, + sampling_params=sampling_params, post_processing_fn=post_processing_fn, processed_mb=processed_mb, is_reward_model=is_reward_model, diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py index 7c0c4fa25a..0b584cab90 100644 --- a/nemo_rl/models/megatron/setup.py +++ b/nemo_rl/models/megatron/setup.py @@ -196,10 +196,19 @@ def validate_and_set_config( pretrained_path, weights_path, ): - # Handle generation colocation + # Handle generation configuration is_generation_colocated = None + sampling_params = None if "generation" in config and config["generation"] is not None: - is_generation_colocated = config["generation"]["colocated"]["enabled"] + generation_cfg = config["generation"] + # set generation colocated + is_generation_colocated = generation_cfg["colocated"]["enabled"] + # set sampling params + sampling_params = TrainingSamplingParams( + top_k=generation_cfg.get("top_k", None), + top_p=generation_cfg.get("top_p", 1.0), + temperature=generation_cfg.get("temperature", 1.0), + ) # Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator. # See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details. @@ -218,16 +227,6 @@ def validate_and_set_config( optimizer_cpu_offload = config["megatron_cfg"]["optimizer"]["optimizer_cpu_offload"] offload_optimizer_for_logprob = config["offload_optimizer_for_logprob"] - # Sampling parameters configuration - sampling_params = None - if "generation" in config and config["generation"] is not None: - generation_cfg = config["generation"] - sampling_params = TrainingSamplingParams( - top_k=generation_cfg.get("top_k", None), - top_p=generation_cfg.get("top_p", 1.0), - temperature=generation_cfg.get("temperature", 1.0), - ) - # Reward models are not yet supported with Megatron. if "reward_model_cfg" in config and config["reward_model_cfg"]["enabled"]: raise NotImplementedError( diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index 9b6a823fd4..be47eba73f 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -174,18 +174,17 @@ def __init__( print(f"Initializing DTensorPolicyWorker with is_vlm={self.is_vlm}") self.is_generation_colocated = None + self.sampling_params = None if "generation" in config and config["generation"] is not None: - self.is_generation_colocated = config["generation"]["colocated"]["enabled"] - - if "generation" in self.cfg and self.cfg["generation"] is not None: - generation_cfg = self.cfg["generation"] + generation_cfg = config["generation"] + # set generation colocated + self.is_generation_colocated = generation_cfg["colocated"]["enabled"] + # set sampling params self.sampling_params = TrainingSamplingParams( top_k=generation_cfg.get("top_k", None), top_p=generation_cfg.get("top_p", 1.0), temperature=generation_cfg.get("temperature", 1.0), ) - else: - self.sampling_params = None # Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator. # See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details. @@ -816,7 +815,7 @@ def train( global_valid_toks, ) else: - loss_input = prepare_loss_input(logits, mb, loss_fn) + loss_input, mb = prepare_loss_input(logits, mb, loss_fn) loss, loss_metrics = loss_fn( data=mb, global_valid_seqs=global_valid_seqs, diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 4ccc1c378f..46f912726d 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -428,7 +428,7 @@ def on_microbatch_start(mb_idx): # Use automodel_forward_backward for the training loop mb_results = automodel_forward_backward( model=self.model, - cfg=self.cfg, + sampling_params=self.sampling_params, data_iterator=processed_iterator, post_processing_fn=loss_post_processor, forward_only=eval_mode, diff --git a/tests/unit/algorithms/sequence_packing_gradient_actor.py b/tests/unit/algorithms/sequence_packing_gradient_actor.py index a5e750d358..1c57975f11 100644 --- a/tests/unit/algorithms/sequence_packing_gradient_actor.py +++ b/tests/unit/algorithms/sequence_packing_gradient_actor.py @@ -150,7 +150,9 @@ def test_sequence_packing_gradients(self): global_valid_seqs = torch.tensor(batch_size, dtype=torch.float, device="cuda") # Forward pass - loss_input = prepare_loss_input(baseline_logits, data_dict, base_loss_fn) + loss_input, data_dict = prepare_loss_input( + baseline_logits, data_dict, base_loss_fn + ) baseline_loss, _ = base_loss_fn( data=data_dict, global_valid_seqs=global_valid_seqs, diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 84ee67eb93..a49eff739d 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -92,7 +92,7 @@ def test_nll_loss(): .unsqueeze(0) .to("cuda") ) - loss_input = prepare_loss_input(next_token_logits, data, loss_fn) + loss_input, data = prepare_loss_input(next_token_logits, data, loss_fn) loss, metrics_dict = loss_fn( data=data, global_valid_seqs=torch.sum(data["sample_mask"]), @@ -118,7 +118,7 @@ def test_nll_loss(): .unsqueeze(0) .to("cuda") ) - loss_input = prepare_loss_input(next_token_logits, data, loss_fn) + loss_input, data = prepare_loss_input(next_token_logits, data, loss_fn) loss, metrics_dict = loss_fn( data=data, global_valid_seqs=torch.sum(data["sample_mask"]), @@ -154,7 +154,7 @@ def test_dpo_loss(): } ) - loss_input = prepare_loss_input(next_token_logits, data, loss_fn) + loss_input, data = prepare_loss_input(next_token_logits, data, loss_fn) loss, _ = loss_fn( data=data, global_valid_seqs=torch.sum(data["sample_mask"]), @@ -177,7 +177,7 @@ def test_dpo_loss(): } ) - loss_input = prepare_loss_input(next_token_logits, data, loss_fn_with_sft) + loss_input, data = prepare_loss_input(next_token_logits, data, loss_fn_with_sft) loss_sft, _ = loss_fn_with_sft( data=data, global_valid_seqs=torch.sum(data["sample_mask"]), @@ -269,7 +269,7 @@ def test_dpo_loss_varying_sequence_lengths(): ) # Compute no averaging loss - loss_input = prepare_loss_input(next_token_logits, data, dpo_loss_fn_no_avg) + loss_input, data = prepare_loss_input(next_token_logits, data, dpo_loss_fn_no_avg) _, metrics = dpo_loss_fn_no_avg( data=data, global_valid_seqs=torch.sum(sample_mask), @@ -278,7 +278,7 @@ def test_dpo_loss_varying_sequence_lengths(): ) # Compute averaging loss - loss_input = prepare_loss_input(next_token_logits, data, dpo_loss_fn_avg) + loss_input, data = prepare_loss_input(next_token_logits, data, dpo_loss_fn_avg) _, metrics_avg = dpo_loss_fn_avg( data=data, global_valid_seqs=torch.sum(sample_mask), @@ -334,7 +334,9 @@ def test_dpo_sft_matches_nll_loss(): # Compute NLL loss nll_loss_fn = NLLLossFn() - loss_input = prepare_loss_input(next_token_logits[::2], sft_data, nll_loss_fn) + loss_input, sft_data = prepare_loss_input( + next_token_logits[::2], sft_data, nll_loss_fn + ) nll_loss, _ = nll_loss_fn( data=sft_data, global_valid_seqs=None, @@ -354,7 +356,7 @@ def test_dpo_sft_matches_nll_loss(): "sft_average_log_probs": False, } ) - loss_input = prepare_loss_input(next_token_logits, dpo_data, dpo_loss_fn) + loss_input, dpo_data = prepare_loss_input(next_token_logits, dpo_data, dpo_loss_fn) dpo_loss, _ = dpo_loss_fn( data=dpo_data, global_valid_seqs=torch.sum(dpo_data["sample_mask"]), @@ -518,7 +520,7 @@ def test_clipped_pg_loss_ppo_clipping(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - loss_input = prepare_loss_input(dummy_logits, data, loss_fn) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( data=data, @@ -566,7 +568,7 @@ def test_clipped_pg_loss_reinforce_mode(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - loss_input = prepare_loss_input(dummy_logits, data, loss_fn) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( data=data, @@ -612,7 +614,7 @@ def test_clipped_pg_loss_force_on_policy_ratio(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - loss_input = prepare_loss_input(dummy_logits, data, loss_fn) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, metrics = loss_fn( data=data, @@ -723,7 +725,7 @@ def test_clipped_pg_loss_kl_penalty(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - loss_input = prepare_loss_input(dummy_logits, data, loss_fn) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( data=data, @@ -764,7 +766,7 @@ def test_clipped_pg_loss_masking(): cfg = deepcopy(basic_pg_loss_test_config) cfg["reference_policy_kl_penalty"] = 0.1 loss_fn = ClippedPGLossFn(cfg) # Use original loss fn - loss_input = prepare_loss_input(dummy_logits, data, loss_fn) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) # --- Test 1: Token Mask --- # Default mask: [[0, 1, 1, 1], [0, 1, 1, 1]] -> 3 tokens per sample @@ -825,7 +827,7 @@ def test_clipped_pg_loss_masking(): data_only_b0 = BatchedDataDict(data_only_b0_dict) logits_only_b0 = dummy_logits[0:1] - loss_input = prepare_loss_input(logits_only_b0, data_only_b0, loss_fn) + loss_input, data_only_b0 = prepare_loss_input(logits_only_b0, data_only_b0, loss_fn) loss_only_b0, _ = loss_fn( data=data_only_b0, global_valid_seqs=torch.sum(data_only_b0["sample_mask"]), @@ -851,7 +853,7 @@ def test_clipped_pg_loss_zero_mask(): cfg = deepcopy(basic_pg_loss_test_config) cfg["reference_policy_kl_penalty"] = 0.1 loss_fn = ClippedPGLossFn(cfg) # Use original loss fn - loss_input = prepare_loss_input(dummy_logits, data, loss_fn) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) # Set token mask to all zeros data["token_mask"] = torch.zeros_like(data["token_mask"]) @@ -1002,7 +1004,7 @@ def test_clipped_pg_loss_on_policy_kl_importance_sampling(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - loss_input = prepare_loss_input(dummy_logits, data, loss_fn) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( data=data, @@ -1135,7 +1137,7 @@ def test_clipped_pg_loss_on_policy_truncated_importance_sampling( dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - loss_input = prepare_loss_input(dummy_logits, data, loss_fn) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( data=data, @@ -1183,7 +1185,7 @@ def test_clipped_pg_loss_icepop_importance_sampling(): dummy_logits = _create_exact_logits( prev_lp, data["input_ids"], batch_size, seq_len, vocab_size, device ) - loss_input = prepare_loss_input(dummy_logits, data, loss_fn) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( data=data, global_valid_seqs=torch.sum(data["sample_mask"]), @@ -1227,7 +1229,7 @@ def test_clipped_pg_loss_seq_mask_tis(): dummy_logits = _create_exact_logits( prev_lp, data["input_ids"], batch_size, seq_len, vocab_size, device ) - loss_input = prepare_loss_input(dummy_logits, data, loss_fn) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( data=data, global_valid_seqs=torch.sum(data["sample_mask"]), @@ -1359,7 +1361,7 @@ def test_clipped_pg_loss_dual_clip(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - loss_input = prepare_loss_input(dummy_logits, data, loss_fn) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( data=data, @@ -1409,7 +1411,7 @@ def test_clipped_pg_loss_entropy(): dummy_logits = _create_exact_logits( curr_lp_masked, data["input_ids"], batch_size, seq_len, vocab_size, device ) - loss_input = prepare_loss_input(dummy_logits, data, loss_fn) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) _, metrics = loss_fn( data=data, @@ -1494,7 +1496,7 @@ def test_clipped_pg_loss_gspo(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - loss_input = prepare_loss_input(dummy_logits, data, loss_fn) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( data=data, @@ -1593,7 +1595,7 @@ def test_clipped_pg_loss_gspo_batch_size_2(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - loss_input = prepare_loss_input(dummy_logits, data, loss_fn) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( data=data, @@ -1695,7 +1697,7 @@ def test_clipped_pg_loss_gspo_importance_sampling_correction(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) - loss_input = prepare_loss_input(dummy_logits, data, loss_fn) + loss_input, data = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( data=data, @@ -1758,7 +1760,7 @@ def test_distillation_loss_different_settings(kl_type, zero_outside_topk): } ) - loss_input = prepare_loss_input(student_logits, data, loss_fn) + loss_input, data = prepare_loss_input(student_logits, data, loss_fn) loss, metrics = loss_fn( data=data, global_valid_seqs=torch.sum(data["sample_mask"]), @@ -1803,7 +1805,7 @@ def test_distillation_loss_topk_filtering(k, zero_outside_topk): } ) - loss_input = prepare_loss_input(student_logits, data, loss_fn) + loss_input, data = prepare_loss_input(student_logits, data, loss_fn) loss, _ = loss_fn( data=data, global_valid_seqs=torch.sum(data["sample_mask"]), @@ -1843,7 +1845,7 @@ def test_distillation_loss_invalid_k_zero(): # This should raise a ValueError for k=0 with pytest.raises(ValueError, match="topk must be positive"): - _ = prepare_loss_input(student_logits, data, loss_fn) + prepare_loss_input(student_logits, data, loss_fn) def test_distillation_loss_gradient_flow(): @@ -1861,7 +1863,7 @@ def test_distillation_loss_gradient_flow(): } ) - loss_input = prepare_loss_input(student_logits, data, loss_fn) + loss_input, data = prepare_loss_input(student_logits, data, loss_fn) loss, _ = loss_fn( data=data, global_valid_seqs=torch.sum(data["sample_mask"]), @@ -1895,7 +1897,7 @@ def test_distillation_loss_edge_cases(): # Test with all-zero logits zero_logits = torch.zeros_like(student_logits) - loss_input = prepare_loss_input(zero_logits, data, loss_fn) + loss_input, data = prepare_loss_input(zero_logits, data, loss_fn) loss, _ = loss_fn( data=data, global_valid_seqs=torch.sum(data["sample_mask"]), @@ -1909,7 +1911,7 @@ def test_distillation_loss_edge_cases(): # Test with very large logits large_logits = torch.ones_like(student_logits) * 100.0 - loss_input = prepare_loss_input(large_logits, data, loss_fn) + loss_input, data = prepare_loss_input(large_logits, data, loss_fn) loss, _ = loss_fn( data=data, global_valid_seqs=torch.sum(data["sample_mask"]), @@ -1923,7 +1925,7 @@ def test_distillation_loss_edge_cases(): # Test with very small logits small_logits = torch.ones_like(student_logits) * -100.0 - loss_input = prepare_loss_input(small_logits, data, loss_fn) + loss_input, data = prepare_loss_input(small_logits, data, loss_fn) loss, _ = loss_fn( data=data, global_valid_seqs=torch.sum(data["sample_mask"]), @@ -1973,7 +1975,7 @@ def test_distillation_loss_fn_call(): } ) - loss_input = prepare_loss_input(student_logits, data, loss_fn) + loss_input, data = prepare_loss_input(student_logits, data, loss_fn) loss, metrics = loss_fn( data=data, global_valid_seqs=torch.sum(data["sample_mask"]), diff --git a/tests/unit/test_recipes_and_test_suites.py b/tests/unit/test_recipes_and_test_suites.py index 1fb39f41d2..6c4f0bed80 100644 --- a/tests/unit/test_recipes_and_test_suites.py +++ b/tests/unit/test_recipes_and_test_suites.py @@ -217,7 +217,7 @@ def test_all_recipe_yamls_accounted_for_in_test_suites( ) -def test_nightly_compute_stays_below_1300_hours(nightly_test_suite, tracker): +def test_nightly_compute_stays_below_1320_hours(nightly_test_suite, tracker): command = f"DRYRUN=1 HF_HOME=... HF_DATASETS_CACHE=... CONTAINER= ACCOUNT= PARTITION= ./tools/launch {' '.join(nightly_test_suite)}" print(f"Running command: {command}") @@ -249,8 +249,8 @@ def test_nightly_compute_stays_below_1300_hours(nightly_test_suite, tracker): f"Last line of output was not as expected: '{last_line}'" ) total_gpu_hours = float(last_line.split(":")[-1].strip()) - assert total_gpu_hours <= 1300, ( - f"Total GPU hours exceeded 1300: {last_line}. We should revisit the test suites to reduce the total GPU hours." + assert total_gpu_hours <= 1320, ( + f"Total GPU hours exceeded 1320: {last_line}. We should revisit the test suites to reduce the total GPU hours." ) tracker.track("total_nightly_gpu_hours", total_gpu_hours) From 5c2abebd06da5e77aba790db04d1c3d4b38a13f4 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 3 Mar 2026 21:20:15 -0800 Subject: [PATCH 09/20] apply filter prev_logprobs before compute_and_apply_seq_logprob_error_masking Signed-off-by: Yuki Huang --- nemo_rl/algorithms/grpo.py | 3 +++ nemo_rl/algorithms/loss/utils.py | 41 +++---------------------------- nemo_rl/algorithms/utils.py | 32 ++++++++++++++++++++++++ nemo_rl/models/automodel/train.py | 39 +++++++++++++++++++---------- 4 files changed, 64 insertions(+), 51 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index c060a05a50..e825f761bf 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1736,6 +1736,8 @@ def grpo_train( { "input_ids": train_data["input_ids"], "input_lengths": train_data["input_lengths"], + "token_mask": flat_messages["token_loss_mask"], + "sample_mask": repeated_batch["loss_multiplier"], **extra_multimodal_data, } ) @@ -1767,6 +1769,7 @@ def grpo_train( "seq_logprob_error_threshold" ], ) + # Compute advantages with adv_estimator using correct mask and logprobs with timer.time("advantage_calculation"): print("▶ Computing advantages...", flush=True) diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py index fbf12ed12d..70edc220e0 100644 --- a/nemo_rl/algorithms/loss/utils.py +++ b/nemo_rl/algorithms/loss/utils.py @@ -21,6 +21,7 @@ need_top_k_or_top_p_filtering, ) from nemo_rl.algorithms.loss.interfaces import LossFunction, LossInputType +from nemo_rl.algorithms.utils import mask_out_neg_inf_logprobs from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( get_distillation_topk_logprobs_from_logits, @@ -28,38 +29,6 @@ ) -def _mask_out_neg_inf_logprobs( - logprobs: torch.Tensor, mask: torch.Tensor, logprobs_name: str -) -> torch.Tensor: - """Mask out negative infinity log probabilities. - - Handling sampling mask mismatch: - vLLM samples token X from top-k/p filtered distribution -> generation_logprobs[X] is always finite (e.g., -5.41) - during training: policy computes logprobs with same top-k/p settings, but the distribution can be slightly different - token X may fall outside the training policy's top-k/p set -> curr_logprobs[X] = -inf, prev_logprobs[X] = -inf - Detect positions with -inf in any logprobs (generation_logprobs is always finite for valid tokens) - - Args: - logprobs: Log probabilities. - mask: Mask. - - Returns: - Masked log probabilities. - """ - is_neginf = torch.isinf(logprobs) - neginf_count = (is_neginf & mask.bool()).sum().item() - if neginf_count > 0: - print( - f"[WARNING]: {neginf_count}/{int(mask.sum().item())} valid tokens have -inf in {logprobs_name} " - "(policy top-k/top-p mismatch). Masking out these positions." - ) - - mask = mask * (~is_neginf).float() - logprobs = torch.where(mask.bool(), logprobs, 0.0) - - return logprobs - - def prepare_loss_input( logits: torch.Tensor, data: BatchedDataDict[Any], @@ -104,13 +73,9 @@ def prepare_loss_input( # handle top-k/top-p filtering for logprobs, only used for ClippedPGLossFn now if need_top_k_or_top_p_filtering(sampling_params): # mask out negative infinity logprobs + # prev_logprobs is already masked out in the previous step mask = data["token_mask"] * data["sample_mask"].unsqueeze(-1) - logprobs = _mask_out_neg_inf_logprobs( - logprobs, mask[:, 1:], "curr_logprobs" - ) - data["prev_logprobs"] = _mask_out_neg_inf_logprobs( - data["prev_logprobs"], mask, "prev_logprobs" - ) + logprobs = mask_out_neg_inf_logprobs(logprobs, mask[:, 1:], "curr_logprobs") # compute unfiltered logprobs for reference policy KL penalty if ( diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index 8e632ca5ee..c48d3b7937 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -182,6 +182,38 @@ def masked_mean( return torch.sum(values * mask, dim=dim) / (normalization_factor + 1e-8) +def mask_out_neg_inf_logprobs( + logprobs: torch.Tensor, mask: torch.Tensor, logprobs_name: str +) -> torch.Tensor: + """Mask out negative infinity log probabilities. + + Handling sampling mask mismatch: + vLLM samples token X from top-k/p filtered distribution -> generation_logprobs[X] is always finite (e.g., -5.41) + during training: policy computes logprobs with same top-k/p settings, but the distribution can be slightly different + token X may fall outside the training policy's top-k/p set -> curr_logprobs[X] = -inf, prev_logprobs[X] = -inf + Detect positions with -inf in any logprobs (generation_logprobs is always finite for valid tokens) + + Args: + logprobs: Log probabilities. + mask: Mask. + + Returns: + Masked log probabilities. + """ + is_neginf = torch.isinf(logprobs) + neginf_count = (is_neginf & mask.bool()).sum().item() + if neginf_count > 0: + print( + f"[WARNING]: {neginf_count}/{int(mask.sum().item())} valid tokens have -inf in {logprobs_name} " + "(policy top-k/top-p mismatch). Masking out these positions." + ) + + mask = mask * (~is_neginf).float() + logprobs = torch.where(mask.bool(), logprobs, 0.0) + + return logprobs + + def set_seed(seed: int) -> None: """Sets the seed for python, numpy, and pytorch.""" random.seed(seed) diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index 86b763b8e7..4886763281 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -39,6 +39,7 @@ ) from nemo_rl.algorithms.loss import SequencePackingLossWrapper, prepare_loss_input from nemo_rl.algorithms.loss.interfaces import LossFunction +from nemo_rl.algorithms.utils import mask_out_neg_inf_logprobs from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( allgather_cp_sharded_tensor, @@ -323,7 +324,7 @@ def forward_with_post_processing_fn( if isinstance(post_processing_fn, LossPostProcessor): result, metrics = post_processing_fn( logits=logits, - mb=data_dict, + data_dict=data_dict, processed_inputs=processed_inputs, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, @@ -334,8 +335,8 @@ def forward_with_post_processing_fn( ): result = post_processing_fn( logits=logits, + data_dict=data_dict, processed_inputs=processed_inputs, - input_lengths=data_dict["input_lengths"], original_batch_size=processed_mb.original_batch_size, original_seq_len=processed_mb.original_seq_len, sequence_dim=sequence_dim, @@ -517,7 +518,7 @@ def __init__( def __call__( self, logits: torch.Tensor, - mb: BatchedDataDict[Any], + data_dict: BatchedDataDict[Any], processed_inputs: ProcessedInputs, global_valid_seqs: torch.Tensor, global_valid_toks: torch.Tensor, @@ -527,7 +528,7 @@ def __call__( Args: logits: Model output logits - mb: Microbatch data + data_dict: Microbatch data processed_inputs: Processed inputs global_valid_seqs: Global valid sequence count global_valid_toks: Global valid token count @@ -538,8 +539,8 @@ def __call__( """ # Handle CP redistribution if self.cp_size > 1: - _, mb = prepare_data_for_cp( - mb, processed_inputs, self.cp_mesh, sequence_dim + _, data_dict = prepare_data_for_cp( + data_dict, processed_inputs, self.cp_mesh, sequence_dim ) logits = redistribute_logits_for_cp( logits, self.device_mesh, self.cp_mesh, sequence_dim @@ -559,14 +560,16 @@ def __call__( ) loss, loss_metrics = loss_fn( logits, - mb, + data_dict, global_valid_seqs, global_valid_toks, ) else: - loss_input, mb = prepare_loss_input_wrapped(logits, mb, self.loss_fn) + loss_input, data_dict = prepare_loss_input_wrapped( + logits, data_dict, self.loss_fn + ) loss, loss_metrics = self.loss_fn( - data=mb, + data=data_dict, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, **loss_input, @@ -611,8 +614,8 @@ def __init__( def __call__( self, logits: torch.Tensor, + data_dict: BatchedDataDict[Any], processed_inputs: ProcessedInputs, - input_lengths: torch.Tensor, original_batch_size: int, original_seq_len: int, sequence_dim: int = 1, @@ -621,8 +624,8 @@ def __call__( Args: logits: Model output logits + data_dict: Microbatch data processed_inputs: Processed inputs - input_lengths: Sequence lengths original_batch_size: Original batch size before packing original_seq_len: Original sequence length before packing sequence_dim: Sequence dimension @@ -631,6 +634,7 @@ def __call__( Token log probabilities tensor [batch_size, seq_length] """ seq_len = processed_inputs.seq_len + input_lengths = data_dict["input_lengths"] if self.cp_size > 1: seq_index_tensor = ( @@ -709,6 +713,13 @@ def __call__( post_attention_mask[i, :length] = 1 token_logprobs = token_logprobs * post_attention_mask + # handle top-k/top-p filtering for logprobs, only used for ClippedPGLossFn now + if need_top_k_or_top_p_filtering(self.sampling_params): + mask = data_dict["token_mask"] * data_dict["sample_mask"].unsqueeze(-1) + token_logprobs = mask_out_neg_inf_logprobs( + token_logprobs, mask, "prev_logprobs" + ) + return token_logprobs def _compute_local_logprobs( @@ -804,8 +815,8 @@ def __init__( def __call__( self, logits: torch.Tensor, + data_dict: BatchedDataDict[Any], processed_inputs: ProcessedInputs, - input_lengths: torch.Tensor, original_batch_size: int, original_seq_len: int, sequence_dim: int = 1, @@ -814,8 +825,8 @@ def __call__( Args: logits: Model output logits + data_dict: Microbatch data processed_inputs: Processed inputs - input_lengths: Sequence lengths original_batch_size: Original batch size before packing original_seq_len: Original sequence length before packing sequence_dim: Sequence dimension @@ -823,6 +834,8 @@ def __call__( Returns: Tuple of (top-k values, top-k indices) tensors """ + input_lengths = data_dict["input_lengths"] + if self.cp_size > 1: logits = redistribute_logits_for_cp( logits, self.device_mesh, self.cp_mesh, sequence_dim From 50f783799c3f2689f171e379e7f6eeb38dc18bc7 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 4 Mar 2026 00:19:12 -0800 Subject: [PATCH 10/20] update megatron and dtensorv1, and fix unit test Signed-off-by: Yuki Huang --- nemo_rl/models/automodel/train.py | 10 +- nemo_rl/models/megatron/train.py | 30 ++++-- .../policy/workers/dtensor_policy_worker.py | 11 ++- .../workers/dtensor_policy_worker_v2.py | 8 +- .../policy/workers/megatron_policy_worker.py | 6 +- .../sequence_packing_gradient_actor.py | 3 +- .../models/automodel/test_automodel_train.py | 82 ++++++++-------- tests/unit/models/megatron/test_train.py | 94 ++++++++----------- 8 files changed, 125 insertions(+), 119 deletions(-) diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index 4886763281..7a0396f4e1 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -258,13 +258,13 @@ def prepare_data_for_cp( def forward_with_post_processing_fn( model: nn.Module, - sampling_params: TrainingSamplingParams, post_processing_fn: PostProcessingFunction, processed_mb: ProcessedMicrobatch, is_reward_model: bool = False, allow_flash_attn_args: bool = True, global_valid_seqs: Optional[torch.Tensor] = None, global_valid_toks: Optional[torch.Tensor] = None, + sampling_params: Optional[TrainingSamplingParams] = None, sequence_dim: int = 1, ) -> Tuple[Any, dict[str, Any], ProcessedMicrobatch]: """Perform forward pass with pre-processed microbatch and apply post-processing. @@ -278,13 +278,13 @@ def forward_with_post_processing_fn( Args: model: The model to run forward pass on - sampling_params: Sampling parameters post_processing_fn: Post-processing function to apply to the logits processed_mb: Pre-fetched ProcessedMicrobatch containing data and processed inputs is_reward_model: Whether this is a reward model allow_flash_attn_args: Whether to pass flash_attn_kwargs to model global_valid_seqs: Global valid sequence count for loss normalization global_valid_toks: Global valid token count for loss normalization + sampling_params: Sampling parameters (top-k, top-p, temperature) sequence_dim: Sequence dimension Returns: @@ -360,7 +360,6 @@ def forward_with_post_processing_fn( def automodel_forward_backward( model: nn.Module, - sampling_params: TrainingSamplingParams, data_iterator: Iterator[ProcessedMicrobatch], post_processing_fn: PostProcessingFunction, forward_only: bool = False, @@ -368,6 +367,7 @@ def automodel_forward_backward( allow_flash_attn_args: bool = True, global_valid_seqs: Optional[torch.Tensor] = None, global_valid_toks: Optional[torch.Tensor] = None, + sampling_params: Optional[TrainingSamplingParams] = None, sequence_dim: int = 1, dp_size: int = 1, cp_size: int = 1, @@ -386,7 +386,6 @@ def automodel_forward_backward( Args: model: The model to train - sampling_params: Sampling parameters data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed) num_microbatches: Number of microbatches to process post_processing_fn: Post-processing function to apply to the logits @@ -395,6 +394,7 @@ def automodel_forward_backward( allow_flash_attn_args: Whether to pass flash_attn_kwargs to model global_valid_seqs: Global valid sequence count for loss normalization global_valid_toks: Global valid token count for loss normalization + sampling_params: Sampling parameters (top-k, top-p, temperature) sequence_dim: Sequence dimension dp_size: Data parallel size cp_size: Context parallel size @@ -431,13 +431,13 @@ def automodel_forward_backward( # Forward pass with post-processing result, metrics, _ = forward_with_post_processing_fn( model=model, - sampling_params=sampling_params, post_processing_fn=post_processing_fn, processed_mb=processed_mb, is_reward_model=is_reward_model, allow_flash_attn_args=allow_flash_attn_args, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, + sampling_params=sampling_params, sequence_dim=sequence_dim, ) diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py index 618b7803f7..cfbb913395 100644 --- a/nemo_rl/models/megatron/train.py +++ b/nemo_rl/models/megatron/train.py @@ -29,13 +29,17 @@ from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.utils import StragglerDetector -from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams +from nemo_rl.algorithms.logits_sampling_utils import ( + TrainingSamplingParams, + need_top_k_or_top_p_filtering, +) from nemo_rl.algorithms.loss import ( SequencePackingLossWrapper, prepare_loss_input, wrap_loss_fn_with_input_preparation, ) from nemo_rl.algorithms.loss.interfaces import LossFunction +from nemo_rl.algorithms.utils import mask_out_neg_inf_logprobs from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( allgather_cp_sharded_tensor, @@ -125,11 +129,11 @@ def apply_temperature_scaling( def forward_with_post_processing_fn( data_iterator: Iterator[ProcessedMicrobatch], model: GPTModel, - sampling_params: TrainingSamplingParams, post_processing_fn: PostProcessingFunction, defer_fp32_logits: Optional[bool] = False, global_valid_seqs: Optional[torch.Tensor] = None, global_valid_toks: Optional[torch.Tensor] = None, + sampling_params: Optional[TrainingSamplingParams] = None, straggler_timer: Optional[StragglerDetector] = None, ) -> Tuple[torch.Tensor, Callable]: """Perform forward pass with pre-processed microbatch and return output tensor and post-processing function. @@ -141,11 +145,11 @@ def forward_with_post_processing_fn( Args: data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed) model: The model to run forward pass on - sampling_params: Sampling parameters post_processing_fn: Post-processing function to post-process the logits defer_fp32_logits: Whether to defer FP32 conversion of logits global_valid_seqs: Global valid sequence count for loss normalization global_valid_toks: Global valid token count for loss normalization + sampling_params: Sampling parameters (top-k, top-p, temperature) straggler_timer: Straggler detector for profiling the forward pass Returns: @@ -216,7 +220,6 @@ def forward_with_post_processing_fn( def megatron_forward_backward( model: GPTModel, - sampling_params: TrainingSamplingParams, data_iterator: Iterator[ProcessedMicrobatch], num_microbatches: int, seq_length: int, @@ -226,6 +229,7 @@ def megatron_forward_backward( defer_fp32_logits: Optional[bool] = False, global_valid_seqs: Optional[torch.Tensor] = None, global_valid_toks: Optional[torch.Tensor] = None, + sampling_params: Optional[TrainingSamplingParams] = None, straggler_timer: Optional[StragglerDetector] = None, ) -> Any: """Execute forward and backward passes using Megatron's utilities. @@ -236,7 +240,6 @@ def megatron_forward_backward( Args: model: The model to train - sampling_params: Sampling parameters data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed) num_microbatches: Number of microbatches to process seq_length: Sequence length @@ -246,6 +249,7 @@ def megatron_forward_backward( defer_fp32_logits: Whether to skip the conversion of logits to fp32 global_valid_seqs: Global valid sequence count for loss normalization global_valid_toks: Global valid token count for loss normalization + sampling_params: Sampling parameters (top-k, top-p, temperature) straggler_timer: Straggler detector for profiling the forward pass Returns: @@ -253,11 +257,11 @@ def megatron_forward_backward( """ forward_step = partial( forward_with_post_processing_fn, - sampling_params=sampling_params, post_processing_fn=post_processing_fn, defer_fp32_logits=defer_fp32_logits, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, + sampling_params=sampling_params, straggler_timer=straggler_timer, ) forward_backward_func = get_forward_backward_func() @@ -370,7 +374,11 @@ def _counteract_mcore_loss_averaging(*args, **kwargs): class LogprobsPostProcessor: - def __init__(self, cfg: PolicyConfig, sampling_params: TrainingSamplingParams): + def __init__( + self, + cfg: PolicyConfig, + sampling_params: Optional[TrainingSamplingParams] = None, + ): self.cfg = cfg self.sampling_params = sampling_params @@ -430,6 +438,14 @@ def processor_fn_inner(output_tensor): token_logprobs = torch.cat( [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 ) + + # handle top-k/top-p filtering for logprobs, only used for ClippedPGLossFn now + if need_top_k_or_top_p_filtering(self.sampling_params): + mask = data_dict["token_mask"] * data_dict["sample_mask"].unsqueeze(-1) + token_logprobs = mask_out_neg_inf_logprobs( + token_logprobs, mask, "prev_logprobs" + ) + return torch.tensor(0.0, device=token_logprobs.device), { "logprobs": token_logprobs } diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index be47eba73f..7234ec2913 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -53,6 +53,7 @@ ) from nemo_rl.algorithms.loss import SequencePackingLossWrapper, prepare_loss_input from nemo_rl.algorithms.loss.interfaces import LossFunction, LossType +from nemo_rl.algorithms.utils import mask_out_neg_inf_logprobs from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( allgather_cp_sharded_tensor, @@ -1213,8 +1214,16 @@ def get_logprobs( lp, (0, padding_needed), mode="constant", value=0.0 ) all_log_probs_padded.append(lp) - return_data["logprobs"] = torch.cat(all_log_probs_padded, dim=0).cpu() + token_logprobs = torch.cat(all_log_probs_padded, dim=0) + # handle top-k/top-p filtering for logprobs, only used for ClippedPGLossFn now + if need_top_k_or_top_p_filtering(self.sampling_params): + mask = data["token_mask"] * data["sample_mask"].unsqueeze(-1) + token_logprobs = mask_out_neg_inf_logprobs( + token_logprobs, mask, "prev_logprobs" + ) + + return_data["logprobs"] = token_logprobs.cpu() return return_data # TODO @Rayen Tian: Related Issue: Refactor shared logic between score() and get_logprobs() (https://github.com/NVIDIA-NeMo/RL/issues/1094) diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 46f912726d..122598e64c 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -428,7 +428,6 @@ def on_microbatch_start(mb_idx): # Use automodel_forward_backward for the training loop mb_results = automodel_forward_backward( model=self.model, - sampling_params=self.sampling_params, data_iterator=processed_iterator, post_processing_fn=loss_post_processor, forward_only=eval_mode, @@ -436,6 +435,7 @@ def on_microbatch_start(mb_idx): allow_flash_attn_args=self.allow_flash_attn_args, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, + sampling_params=self.sampling_params, sequence_dim=sequence_dim, dp_size=self.dp_size, cp_size=self.cp_size, @@ -571,11 +571,11 @@ def get_logprobs( # Use forward_with_post_processing_fn for forward pass and post-processing token_logprobs, _metrics, _ = forward_with_post_processing_fn( model=self.model, - sampling_params=self.sampling_params, post_processing_fn=logprobs_post_processor, processed_mb=processed_mb, is_reward_model=False, allow_flash_attn_args=self.allow_flash_attn_args, + sampling_params=self.sampling_params, sequence_dim=sequence_dim, ) @@ -640,11 +640,11 @@ def score(self, data: BatchedDataDict) -> BatchedDataDict[ScoreOutputSpec]: # Use forward_with_post_processing_fn for forward pass and post-processing rm_scores, _metrics, _ = forward_with_post_processing_fn( model=self.model, - sampling_params=self.sampling_params, post_processing_fn=score_post_processor, processed_mb=processed_mb, is_reward_model=True, allow_flash_attn_args=False, + sampling_params=self.sampling_params, sequence_dim=sequence_dim, ) @@ -730,11 +730,11 @@ def get_topk_logits( # Use forward_with_post_processing_fn for forward pass and post-processing (vals, idx), _metrics, _ = forward_with_post_processing_fn( model=self.model, - sampling_params=self.sampling_params, post_processing_fn=topk_post_processor, processed_mb=processed_mb, is_reward_model=False, allow_flash_attn_args=self.allow_flash_attn_args, + sampling_params=self.sampling_params, sequence_dim=sequence_dim, ) diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 63958e4554..073b6bd9de 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -330,7 +330,6 @@ def train( # Forward pass. losses_reduced = megatron_forward_backward( model=self.model, - sampling_params=self.sampling_params, data_iterator=data_iterator, num_microbatches=num_microbatches, seq_length=padded_seq_length, @@ -340,6 +339,7 @@ def train( defer_fp32_logits=self.defer_fp32_logits, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, + sampling_params=self.sampling_params, straggler_timer=self.mcore_state.straggler_timer, ) @@ -495,7 +495,6 @@ def get_logprobs( list_of_logprobs = megatron_forward_backward( model=self.model, - sampling_params=self.sampling_params, data_iterator=mb_iterator, seq_length=padded_seq_length, mbs=micro_batch_size, @@ -503,6 +502,7 @@ def get_logprobs( post_processing_fn=logprobs_post_processor, forward_only=True, defer_fp32_logits=self.defer_fp32_logits, + sampling_params=self.sampling_params, straggler_timer=self.mcore_state.straggler_timer, ) @@ -640,7 +640,6 @@ def get_topk_logits( list_of_outputs = megatron_forward_backward( model=self.model, - sampling_params=self.sampling_params, data_iterator=mb_iterator, seq_length=padded_seq_length, mbs=micro_batch_size, @@ -648,6 +647,7 @@ def get_topk_logits( post_processing_fn=TopkLogitsPostProcessor(cfg=self.cfg, k=k), forward_only=True, defer_fp32_logits=self.defer_fp32_logits, + sampling_params=self.sampling_params, straggler_timer=self.mcore_state.straggler_timer, ) diff --git a/tests/unit/algorithms/sequence_packing_gradient_actor.py b/tests/unit/algorithms/sequence_packing_gradient_actor.py index 1c57975f11..0165c86cc3 100644 --- a/tests/unit/algorithms/sequence_packing_gradient_actor.py +++ b/tests/unit/algorithms/sequence_packing_gradient_actor.py @@ -340,13 +340,12 @@ def forward( pad_full_seq_to=max_seq_len * batch_size if cp_size > 1 else None, ), model=MockModel(), - cfg=cfg, post_processing_fn=post_processor, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, straggler_timer=mock_straggler_timer, ) - loss, metrics = wrapped_loss_fn(output_tensor) + loss, _ = wrapped_loss_fn(output_tensor) loss.backward() diff --git a/tests/unit/models/automodel/test_automodel_train.py b/tests/unit/models/automodel/test_automodel_train.py index 9eb5d72bb9..d04077cdaa 100644 --- a/tests/unit/models/automodel/test_automodel_train.py +++ b/tests/unit/models/automodel/test_automodel_train.py @@ -24,6 +24,7 @@ except ImportError: pytest.skip("nemo_automodel not available", allow_module_level=True) +from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams from nemo_rl.algorithms.loss.interfaces import LossInputType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.automodel.data import ( @@ -234,34 +235,38 @@ def test_output_with_last_hidden_state(self, mock_model): # ===================== @pytest.mark.automodel class TestApplyTemperatureScaling: - def test_temperature_scaling_applied(self): + """Tests for apply_temperature_scaling function.""" + + def test_temperature_scaling_sampling_params_is_none(self): + """Test that logits are unchanged when sampling_params is None.""" logits = torch.randn(4, 64, 32000) original_logits = logits.clone() - cfg = {"generation": {"temperature": 2.0}} - result = apply_temperature_scaling(logits, cfg) + result = apply_temperature_scaling(logits, None) - # Should be divided by temperature - expected = original_logits / 2.0 - assert torch.allclose(result, expected) + assert torch.equal(result, original_logits) - def test_no_scaling_without_generation_config(self): + def test_temperature_scaling_with_temperature_one(self): + """Test that temperature=1.0 leaves logits unchanged.""" logits = torch.randn(4, 64, 32000) original_logits = logits.clone() - cfg = {} + sampling_params = TrainingSamplingParams(temperature=1.0) - result = apply_temperature_scaling(logits, cfg) + result = apply_temperature_scaling(logits, sampling_params) assert torch.equal(result, original_logits) - def test_no_scaling_with_none_generation(self): + def test_temperature_scaling_with_temperature_two(self): + """Test that logits are divided by the configured temperature=2.0.""" logits = torch.randn(4, 64, 32000) original_logits = logits.clone() - cfg = {"generation": None} + sampling_params = TrainingSamplingParams(temperature=2.0) - result = apply_temperature_scaling(logits, cfg) + result = apply_temperature_scaling(logits, sampling_params) - assert torch.equal(result, original_logits) + # Should be divided by temperature + expected = original_logits / 2.0 + assert torch.allclose(result, expected) # ===================== @@ -293,7 +298,7 @@ def test_basic_loss_computation( vocab_size = 32000 logits = torch.randn(batch_size, seq_len, vocab_size) - mb = BatchedDataDict( + data_dict = BatchedDataDict( { "input_ids": torch.randint(0, vocab_size, (batch_size, seq_len)), "sample_mask": torch.ones(batch_size, dtype=torch.bool), @@ -302,9 +307,9 @@ def test_basic_loss_computation( global_valid_seqs = torch.tensor(8) global_valid_toks = torch.tensor(512) - loss, metrics = processor( + processor( logits=logits, - mb=mb, + data_dict=data_dict, processed_inputs=processed_inputs_no_flash, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, @@ -349,7 +354,7 @@ def test_loss_with_sequence_packing( vocab_size = 32000 logits = torch.randn(batch_size, seq_len, vocab_size) - mb = BatchedDataDict( + data_dict = BatchedDataDict( { "input_ids": torch.randint(0, vocab_size, (batch_size, seq_len)), "sample_mask": torch.ones(batch_size, dtype=torch.bool), @@ -358,9 +363,9 @@ def test_loss_with_sequence_packing( global_valid_seqs = torch.tensor(4) global_valid_toks = torch.tensor(128) - loss, metrics = processor( + processor( logits=logits, - mb=mb, + data_dict=data_dict, processed_inputs=processed_inputs_with_flash, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, @@ -446,6 +451,7 @@ def test_basic_logprobs_computation( logits = torch.randn(batch_size, seq_len, vocab_size) input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) input_lengths = torch.full((batch_size,), seq_len) + data_dict = BatchedDataDict({"input_lengths": input_lengths}) processed_inputs = ProcessedInputs( input_ids=input_ids, @@ -460,8 +466,8 @@ def test_basic_logprobs_computation( result = processor( logits=logits, + data_dict=data_dict, processed_inputs=processed_inputs, - input_lengths=input_lengths, original_batch_size=batch_size, original_seq_len=seq_len, ) @@ -487,6 +493,7 @@ def test_logprobs_with_chunking( logits = torch.randn(batch_size, seq_len, vocab_size) input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) input_lengths = torch.full((batch_size,), seq_len) + data_dict = BatchedDataDict({"input_lengths": input_lengths}) processed_inputs = ProcessedInputs( input_ids=input_ids, @@ -501,8 +508,8 @@ def test_logprobs_with_chunking( result = processor( logits=logits, + data_dict=data_dict, processed_inputs=processed_inputs, - input_lengths=input_lengths, original_batch_size=batch_size, original_seq_len=seq_len, ) @@ -532,6 +539,7 @@ def test_basic_topk(self, base_cfg, mock_device_mesh, mock_cp_mesh, mock_tp_mesh logits = torch.randn(batch_size, seq_len, vocab_size) input_lengths = torch.full((batch_size,), seq_len) + data_dict = BatchedDataDict({"input_lengths": input_lengths}) processed_inputs = ProcessedInputs( input_ids=torch.randint(0, vocab_size, (batch_size, seq_len)), @@ -546,8 +554,8 @@ def test_basic_topk(self, base_cfg, mock_device_mesh, mock_cp_mesh, mock_tp_mesh vals, idx = processor( logits=logits, + data_dict=data_dict, processed_inputs=processed_inputs, - input_lengths=input_lengths, original_batch_size=batch_size, original_seq_len=seq_len, ) @@ -775,9 +783,8 @@ def test_forward_with_loss_post_processor( ) # Call forward_with_post_processing_fn - result, metrics, returned_mb = forward_with_post_processing_fn( + _, _, returned_mb = forward_with_post_processing_fn( model=mock_model, - cfg=base_cfg, post_processing_fn=loss_post_processor, processed_mb=processed_mb, global_valid_seqs=torch.tensor(batch_size), @@ -839,9 +846,8 @@ def test_forward_with_score_post_processor( score_post_processor = ScorePostProcessor(cfg=base_cfg) # Call forward_with_post_processing_fn - result, metrics, returned_mb = forward_with_post_processing_fn( + result, metrics, _ = forward_with_post_processing_fn( model=mock_model, - cfg=base_cfg, post_processing_fn=score_post_processor, processed_mb=processed_mb, is_reward_model=True, @@ -919,7 +925,6 @@ def test_forward_backward_single_microbatch( # Call automodel_forward_backward in forward_only mode results = automodel_forward_backward( model=mock_model, - cfg=base_cfg, data_iterator=iter([processed_mb]), post_processing_fn=loss_post_processor, forward_only=True, @@ -994,7 +999,6 @@ def test_forward_backward_multiple_microbatches( # Call automodel_forward_backward in forward_only mode results = automodel_forward_backward( model=mock_model, - cfg=base_cfg, data_iterator=iter(processed_mbs), post_processing_fn=loss_post_processor, forward_only=True, @@ -1084,7 +1088,6 @@ def mock_train_context_fn(processed_inputs): # Call automodel_forward_backward with train_context_fn results = automodel_forward_backward( model=mock_model, - cfg=base_cfg, data_iterator=iter([processed_mb]), post_processing_fn=loss_post_processor, forward_only=True, @@ -1164,7 +1167,6 @@ def on_microbatch_start(mb_idx): # Call automodel_forward_backward with on_microbatch_start results = automodel_forward_backward( model=mock_model, - cfg=base_cfg, data_iterator=iter(processed_mbs), post_processing_fn=loss_post_processor, forward_only=True, @@ -1239,7 +1241,6 @@ def test_forward_backward_with_dummy_batches( # Call automodel_forward_backward with num_valid_microbatches results = automodel_forward_backward( model=mock_model, - cfg=base_cfg, data_iterator=iter(processed_mbs), post_processing_fn=loss_post_processor, forward_only=True, @@ -1318,7 +1319,6 @@ def test_forward_with_logprobs_post_processor( # Call forward_with_post_processing_fn result, metrics, returned_mb = forward_with_post_processing_fn( model=mock_model, - cfg=base_cfg, post_processing_fn=logprobs_post_processor, processed_mb=processed_mb, ) @@ -1390,9 +1390,8 @@ def test_forward_with_topk_post_processor( ) # Call forward_with_post_processing_fn - result, metrics, returned_mb = forward_with_post_processing_fn( + result, metrics, _ = forward_with_post_processing_fn( model=mock_model, - cfg=base_cfg, post_processing_fn=topk_post_processor, processed_mb=processed_mb, ) @@ -1459,7 +1458,6 @@ class UnknownPostProcessor: with pytest.raises(TypeError, match="Unknown post-processing function type"): forward_with_post_processing_fn( model=mock_model, - cfg=base_cfg, post_processing_fn=unknown_post_processor, processed_mb=processed_mb, ) @@ -1520,9 +1518,8 @@ def test_forward_with_processed_mb_directly( ) # Call forward_with_post_processing_fn with processed_mb directly (no iterator) - result, metrics, returned_mb = forward_with_post_processing_fn( + _, _, returned_mb = forward_with_post_processing_fn( model=mock_model, - cfg=base_cfg, post_processing_fn=loss_post_processor, processed_mb=processed_mb, # Directly provided global_valid_seqs=torch.tensor(batch_size), @@ -1694,6 +1691,7 @@ def test_logprobs_with_sequence_packing( logits = torch.randn(1, packed_seq_len, vocab_size) input_ids = torch.randint(0, vocab_size, (1, packed_seq_len)) input_lengths = torch.tensor([32, 32, 32, 32]) + data_dict = BatchedDataDict({"input_lengths": input_lengths}) @dataclass class MockFlashAttnKwargs: @@ -1716,8 +1714,8 @@ class MockFlashAttnKwargs: result = processor( logits=logits, + data_dict=data_dict, processed_inputs=processed_inputs, - input_lengths=input_lengths, original_batch_size=original_batch_size, original_seq_len=original_seq_len, ) @@ -1743,8 +1741,8 @@ def test_logprobs_masking_without_sequence_packing( logits = torch.randn(batch_size, seq_len, vocab_size) input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) - # Variable length sequences input_lengths = torch.tensor([32, 48, 64, 16]) + data_dict = BatchedDataDict({"input_lengths": input_lengths}) processed_inputs = ProcessedInputs( input_ids=input_ids, @@ -1759,8 +1757,8 @@ def test_logprobs_masking_without_sequence_packing( result = processor( logits=logits, + data_dict=data_dict, processed_inputs=processed_inputs, - input_lengths=input_lengths, original_batch_size=batch_size, original_seq_len=seq_len, ) @@ -1801,6 +1799,7 @@ def test_topk_with_sequence_packing( logits = torch.randn(1, packed_seq_len, vocab_size) input_lengths = torch.tensor([32, 32, 32, 32]) + data_dict = BatchedDataDict({"input_lengths": input_lengths}) @dataclass class MockFlashAttnKwargs: @@ -1823,8 +1822,8 @@ class MockFlashAttnKwargs: vals, idx = processor( logits=logits, + data_dict=data_dict, processed_inputs=processed_inputs, - input_lengths=input_lengths, original_batch_size=original_batch_size, original_seq_len=original_seq_len, ) @@ -1921,7 +1920,6 @@ def loss_fn(logits, data, global_valid_seqs, global_valid_toks): # Call automodel_forward_backward with forward_only=False results = automodel_forward_backward( model=model, - cfg=base_cfg, data_iterator=iter([processed_mb]), post_processing_fn=loss_post_processor, forward_only=False, # Enable backward pass diff --git a/tests/unit/models/megatron/test_train.py b/tests/unit/models/megatron/test_train.py index b80d1a7986..da3c9179f7 100644 --- a/tests/unit/models/megatron/test_train.py +++ b/tests/unit/models/megatron/test_train.py @@ -27,6 +27,7 @@ import pytest import torch +from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams from nemo_rl.algorithms.loss.interfaces import LossInputType @@ -48,12 +49,10 @@ def test_model_forward_basic(self): input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]) position_ids = torch.tensor([[0, 1, 2], [0, 1, 2]]) attention_mask = torch.ones(2, 3) - cfg = {} result = model_forward( model=mock_model, data_dict=mock_data_dict, - cfg=cfg, input_ids_cp_sharded=input_ids, position_ids=position_ids, attention_mask=attention_mask, @@ -80,7 +79,6 @@ def test_model_forward_with_straggler_timer(self): result = model_forward( model=mock_model, data_dict=mock_data_dict, - cfg={}, input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), position_ids=torch.tensor([[0, 1, 2]]), attention_mask=torch.ones(1, 3), @@ -108,7 +106,6 @@ def test_model_forward_with_packed_seq_params(self): model_forward( model=mock_model, data_dict=mock_data_dict, - cfg={}, input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), position_ids=torch.tensor([[0, 1, 2]]), attention_mask=torch.ones(1, 3), @@ -132,7 +129,6 @@ def test_model_forward_with_defer_fp32_logits(self): model_forward( model=mock_model, data_dict=mock_data_dict, - cfg={}, input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), position_ids=torch.tensor([[0, 1, 2]]), attention_mask=torch.ones(1, 3), @@ -157,7 +153,6 @@ def test_model_forward_clears_position_ids_for_multimodal(self): model_forward( model=mock_model, data_dict=mock_data_dict, - cfg={}, input_ids_cp_sharded=torch.tensor([[1, 2, 3]]), position_ids=torch.tensor([[0, 1, 2]]), attention_mask=torch.ones(1, 3), @@ -170,39 +165,14 @@ def test_model_forward_clears_position_ids_for_multimodal(self): class TestApplyTemperatureScaling: """Tests for apply_temperature_scaling function.""" - def test_temperature_scaling_with_generation_config(self): - """Test that logits are divided by the configured temperature.""" - from nemo_rl.models.megatron.train import apply_temperature_scaling - - logits = torch.ones(2, 10, 100) * 4.0 - cfg = {"generation": {"temperature": 2.0}} - - result = apply_temperature_scaling(logits, cfg) - - # 4.0 / 2.0 = 2.0 - assert torch.allclose(result, torch.ones_like(result) * 2.0) - # Verify in-place: result is the same tensor - assert result.data_ptr() == logits.data_ptr() - - def test_temperature_scaling_no_generation_key(self): - """Test that logits are unchanged when 'generation' key is absent.""" - from nemo_rl.models.megatron.train import apply_temperature_scaling - - logits = torch.ones(2, 10, 100) * 3.0 - cfg = {} - - result = apply_temperature_scaling(logits, cfg) - - assert torch.allclose(result, torch.ones_like(result) * 3.0) - - def test_temperature_scaling_generation_is_none(self): - """Test that logits are unchanged when generation config is None.""" + def test_temperature_scaling_sampling_params_is_none(self): + """Test that logits are unchanged when sampling_params is None.""" from nemo_rl.models.megatron.train import apply_temperature_scaling logits = torch.ones(2, 10, 100) * 3.0 - cfg = {"generation": None} + sampling_params = None - result = apply_temperature_scaling(logits, cfg) + result = apply_temperature_scaling(logits, sampling_params) assert torch.allclose(result, torch.ones_like(result) * 3.0) @@ -212,12 +182,26 @@ def test_temperature_scaling_with_temperature_one(self): logits = torch.randn(2, 10, 100) original = logits.clone() - cfg = {"generation": {"temperature": 1.0}} + sampling_params = TrainingSamplingParams(temperature=1.0) - result = apply_temperature_scaling(logits, cfg) + result = apply_temperature_scaling(logits, sampling_params) assert torch.allclose(result, original) + def test_temperature_scaling_with_temperature_two(self): + """Test that logits are divided by the configured temperature=2.0.""" + from nemo_rl.models.megatron.train import apply_temperature_scaling + + logits = torch.ones(2, 10, 100) * 4.0 + sampling_params = TrainingSamplingParams(temperature=2.0) + + result = apply_temperature_scaling(logits, sampling_params) + + # 4.0 / 2.0 = 2.0 + assert torch.allclose(result, torch.ones_like(result) * 2.0) + # Verify in-place: result is the same tensor + assert result.data_ptr() == logits.data_ptr() + class TestForwardWithPostProcessingFn: """Tests for forward_with_post_processing_fn function.""" @@ -268,7 +252,6 @@ def test_forward_with_loss_post_processor( output, wrapped_fn = forward_with_post_processing_fn( data_iterator=data_iterator, model=mock_model, - cfg=cfg, post_processing_fn=post_processor, ) @@ -304,10 +287,9 @@ def test_forward_with_logprobs_post_processor(self, mock_model_forward): post_processor = LogprobsPostProcessor(cfg=cfg) with patch.object(post_processor, "__call__", return_value=MagicMock()): - output, wrapped_fn = forward_with_post_processing_fn( + forward_with_post_processing_fn( data_iterator=data_iterator, model=MagicMock(), - cfg=cfg, post_processing_fn=post_processor, ) @@ -342,10 +324,9 @@ def test_forward_with_topk_post_processor(self, mock_model_forward): post_processor = TopkLogitsPostProcessor(cfg=cfg, k=5) with patch.object(post_processor, "__call__", return_value=MagicMock()): - output, wrapped_fn = forward_with_post_processing_fn( + forward_with_post_processing_fn( data_iterator=data_iterator, model=MagicMock(), - cfg=cfg, post_processing_fn=post_processor, ) @@ -398,16 +379,19 @@ def test_forward_applies_temperature_scaling_for_loss( "generation": {"temperature": 0.7}, } post_processor = LossPostProcessor(loss_fn=MagicMock(), cfg=cfg) + sampling_params = TrainingSamplingParams( + temperature=cfg["generation"]["temperature"] + ) forward_with_post_processing_fn( data_iterator=iter([processed_mb]), model=MagicMock(), - cfg=cfg, post_processing_fn=post_processor, + sampling_params=sampling_params, ) # Verify apply_temperature_scaling was called with the output tensor and cfg - mock_temp_scaling.assert_called_once_with(output_tensor, cfg) + mock_temp_scaling.assert_called_once_with(output_tensor, sampling_params) @patch("nemo_rl.models.megatron.train.model_forward") @patch("nemo_rl.models.megatron.train.apply_temperature_scaling") @@ -439,16 +423,19 @@ def test_forward_applies_temperature_scaling_for_logprobs( "generation": {"temperature": 0.5}, } post_processor = LogprobsPostProcessor(cfg=cfg) + sampling_params = TrainingSamplingParams( + temperature=cfg["generation"]["temperature"] + ) with patch.object(post_processor, "__call__", return_value=MagicMock()): forward_with_post_processing_fn( data_iterator=iter([processed_mb]), model=MagicMock(), - cfg=cfg, post_processing_fn=post_processor, + sampling_params=sampling_params, ) - mock_temp_scaling.assert_called_once_with(output_tensor, cfg) + mock_temp_scaling.assert_called_once_with(output_tensor, sampling_params) @patch("nemo_rl.models.megatron.train.model_forward") @patch("nemo_rl.models.megatron.train.apply_temperature_scaling") @@ -481,16 +468,18 @@ def test_forward_applies_temperature_scaling_for_topk( "generation": {"temperature": 1.5}, } post_processor = TopkLogitsPostProcessor(cfg=cfg, k=5) - + sampling_params = TrainingSamplingParams( + temperature=cfg["generation"]["temperature"] + ) with patch.object(post_processor, "__call__", return_value=MagicMock()): forward_with_post_processing_fn( data_iterator=iter([processed_mb]), model=MagicMock(), - cfg=cfg, post_processing_fn=post_processor, + sampling_params=sampling_params, ) - mock_temp_scaling.assert_called_once_with(output_tensor, cfg) + mock_temp_scaling.assert_called_once_with(output_tensor, sampling_params) @patch("nemo_rl.models.megatron.train.model_forward") @patch("nemo_rl.models.megatron.train.apply_temperature_scaling") @@ -517,7 +506,6 @@ def test_forward_does_not_apply_temperature_scaling_for_unknown_type( forward_with_post_processing_fn( data_iterator=iter([processed_mb]), model=MagicMock(), - cfg={"generation": {"temperature": 2.0}}, post_processing_fn="not_a_processor", ) @@ -563,7 +551,6 @@ def test_forward_with_straggler_timer( forward_with_post_processing_fn( data_iterator=iter([processed_mb]), model=MagicMock(), - cfg=cfg, post_processing_fn=post_processor, straggler_timer=mock_timer, ) @@ -597,7 +584,6 @@ def test_forward_with_unknown_post_processor_raises(self, mock_model_forward): forward_with_post_processing_fn( data_iterator=data_iterator, model=MagicMock(), - cfg={}, post_processing_fn=unknown_processor, ) @@ -621,9 +607,8 @@ def test_megatron_forward_backward_calls_forward_backward_func(self, mock_get_fb cfg = {"sequence_packing": {"enabled": False}} post_processor = LossPostProcessor(loss_fn=mock_loss_fn, cfg=cfg) - result = megatron_forward_backward( + megatron_forward_backward( model=mock_model, - cfg=cfg, data_iterator=iter([]), num_microbatches=4, seq_length=128, @@ -656,7 +641,6 @@ def test_megatron_forward_backward_forward_only(self, mock_get_fb): megatron_forward_backward( model=MagicMock(), - cfg=cfg, data_iterator=iter([]), num_microbatches=1, seq_length=64, From 0b333e39b3f6eb3d9c41e35b42b016eaaa931dc0 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 4 Mar 2026 07:47:57 -0800 Subject: [PATCH 11/20] update nightly test Signed-off-by: Yuki Huang --- ...-1n8g-fsdp2tp2-temp0.8-topp0.9-topk50.yaml | 14 +++++++ ....2-1b-instruct-1n8g-megatron-sampling.yaml | 37 ------------------- ...3.2-1b-instruct-1n8g-megatron-temp0.6.yaml | 35 ------------------ ...-1n8g-megatron-temp0.8-topp0.9-topk50.yaml | 15 ++++++++ ...t-1n8g-fsdp2tp2-temp0.8-topp0.9-topk50.sh} | 4 +- ...t-1n8g-megatron-temp0.8-topp0.9-topk50.sh} | 4 +- tests/test_suites/nightly.txt | 6 ++- 7 files changed, 37 insertions(+), 78 deletions(-) create mode 100644 examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp2-temp0.8-topp0.9-topk50.yaml delete mode 100644 examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.yaml delete mode 100644 examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.yaml create mode 100644 examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50.yaml rename tests/test_suites/llm/{grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.sh => grpo-llama3.2-1b-instruct-1n8g-fsdp2tp2-temp0.8-topp0.9-topk50.sh} (90%) rename tests/test_suites/llm/{grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.sh => grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50.sh} (90%) diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp2-temp0.8-topp0.9-topk50.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp2-temp0.8-topp0.9-topk50.yaml new file mode 100644 index 0000000000..ec2765c2f2 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp2-temp0.8-topp0.9-topk50.yaml @@ -0,0 +1,14 @@ +defaults: grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml +checkpointing: + checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp2-temp0.8-topp0.9-topk50 +policy: + dtensor_cfg: + tensor_parallel_size: 2 + generation: + temperature: 0.8 + top_p: 0.9 + top_k: 50 +logger: + log_dir: logs/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp2-temp0.8-topp0.9-topk50 + wandb: + name: grpo-llama3.2-1b-instruct-1n8g-fsdp2tp2-temp0.8-topp0.9-topk50 diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.yaml deleted file mode 100644 index 34d76a7421..0000000000 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.yaml +++ /dev/null @@ -1,37 +0,0 @@ -defaults: ../../grpo_math_1B.yaml -grpo: - max_num_steps: 500 -checkpointing: - enabled: false - checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron - save_period: 100 -policy: - model_name: meta-llama/Llama-3.2-1B-Instruct - tokenizer: - name: meta-llama/Llama-3.2-1B-Instruct - optimizer: null - megatron_cfg: - enabled: true - scheduler: - lr_warmup_iters: 50 - dtensor_cfg: - enabled: false - make_sequence_length_divisible_by: 1 - generation: - max_new_tokens: 512 - vllm_cfg: - max_model_len: 512 - temperature: 0.8 - top_p: 0.9 - top_k: 50 -data: - max_input_seq_length: 512 -logger: - log_dir: logs/grpo-llama3.2-1b-instruct-1n8g-megatron - wandb_enabled: true - tensorboard_enabled: true - wandb: - project: nemo-rl - name: grpo-llama3.2-1b-instruct-1n8g-megatron -cluster: - gpus_per_node: 8 diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.yaml deleted file mode 100644 index 82b5639f73..0000000000 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.yaml +++ /dev/null @@ -1,35 +0,0 @@ -defaults: ../../grpo_math_1B.yaml -grpo: - max_num_steps: 500 -checkpointing: - enabled: false - checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron - save_period: 100 -policy: - model_name: meta-llama/Llama-3.2-1B-Instruct - tokenizer: - name: meta-llama/Llama-3.2-1B-Instruct - optimizer: null - megatron_cfg: - enabled: true - scheduler: - lr_warmup_iters: 50 - dtensor_cfg: - enabled: false - make_sequence_length_divisible_by: 1 - generation: - max_new_tokens: 512 - vllm_cfg: - max_model_len: 512 - temperature: 0.6 -data: - max_input_seq_length: 512 -logger: - log_dir: logs/grpo-llama3.2-1b-instruct-1n8g-megatron - wandb_enabled: true - tensorboard_enabled: true - wandb: - project: nemo-rl - name: grpo-llama3.2-1b-instruct-1n8g-megatron -cluster: - gpus_per_node: 8 diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50.yaml new file mode 100644 index 0000000000..f70020ffd4 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50.yaml @@ -0,0 +1,15 @@ +defaults: grpo-llama3.2-1b-instruct-1n8g-megatron.yaml +checkpointing: + checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50 +policy: + megatron_cfg: + tensor_model_parallel_size: 2 + make_sequence_length_divisible_by: 2 + generation: + temperature: 0.8 + top_p: 0.9 + top_k: 50 +logger: + log_dir: logs/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50 + wandb: + name: grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50 diff --git a/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.sh b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp2-temp0.8-topp0.9-topk50.sh similarity index 90% rename from tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.sh rename to tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp2-temp0.8-topp0.9-topk50.sh index 1cdfad6fd3..ce065f181b 100755 --- a/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.sh +++ b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp2-temp0.8-topp0.9-topk50.sh @@ -34,10 +34,10 @@ uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS # Only run metrics if the target step is reached if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then uv run tests/check_metrics.py $JSON_METRICS \ - 'mean(data["train/token_mult_prob_error"], ignore_top_p=0.01) < 1.05' \ + 'median(data["train/token_mult_prob_error"]) < 1.1' \ 'data["train/token_mult_prob_error"]["500"] < 1.1' \ 'data["train/reward"]["500"] > 0.1' \ - 'mean(data["timing/train/total_step_time"], -6, -1) < 12.5' + 'mean(data["timing/train/total_step_time"], -6, -1) < 20' # Clean up checkpoint directory after successful run to save space. rm -rf "$CKPT_DIR" diff --git a/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.sh b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50.sh similarity index 90% rename from tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.sh rename to tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50.sh index 21dd483626..ce065f181b 100755 --- a/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.sh +++ b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50.sh @@ -34,10 +34,10 @@ uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS # Only run metrics if the target step is reached if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then uv run tests/check_metrics.py $JSON_METRICS \ - 'mean(data["train/token_mult_prob_error"], ignore_top_p=0.01) < 1.05' \ + 'median(data["train/token_mult_prob_error"]) < 1.1' \ 'data["train/token_mult_prob_error"]["500"] < 1.1' \ 'data["train/reward"]["500"] > 0.1' \ - 'mean(data["timing/train/total_step_time"], -6, -1) < 10.5' + 'mean(data["timing/train/total_step_time"], -6, -1) < 20' # Clean up checkpoint directory after successful run to save space. rm -rf "$CKPT_DIR" diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index bd05ad6def..072945331e 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -23,8 +23,6 @@ tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh # Megatron tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.sh tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh -tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.sh -tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.sh # Functional 32b run tests/test_suites/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8-actckpt.v3.sh @@ -80,6 +78,10 @@ tests/test_suites/llm/grpo-nanov3-30BA3B-2n8g-megatron-lora.sh tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh tests/test_suites/llm/grpo-qwen3-8b-base-1n8g-megatron-lora.sh +# Sampling (temperature, top-p, top-k) +tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp2-temp0.8-topp0.9-topk50.sh +tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50.sh + ####### # SFT # ####### From eeddf3294d7a2921e91e4aaa2ae684cbdc864d31 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 4 Mar 2026 21:28:40 -0800 Subject: [PATCH 12/20] make megatron worker respect to make_sequence_length_divisible_by Signed-off-by: Yuki Huang --- nemo_rl/models/megatron/data.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/nemo_rl/models/megatron/data.py b/nemo_rl/models/megatron/data.py index 13daee1352..9a68b524c2 100644 --- a/nemo_rl/models/megatron/data.py +++ b/nemo_rl/models/megatron/data.py @@ -175,6 +175,7 @@ def get_microbatch_iterator( pad_full_seq_to, ) = _get_pack_sequence_parameters_for_megatron( cfg["megatron_cfg"], + cfg["make_sequence_length_divisible_by"], pack_seq_dim_size, ) micro_batch_size = 1 @@ -529,12 +530,14 @@ def _pack_sequences_for_megatron( def _get_pack_sequence_parameters_for_megatron( megatron_cfg: dict, + pad_individual_seqs_to_multiple_of: int, max_seq_len_in_batch: int, ): """Get pack sequence parameters for Megatron model processing with optional context parallelism. Args: megatron_cfg: Megatron configuration + pad_individual_seqs_to_multiple_of: Pad individual sequences to a multiple of this value max_seq_len_in_batch: Maximum sequence length in batch Returns: @@ -551,11 +554,15 @@ def _get_pack_sequence_parameters_for_megatron( use_fp8 = fp8_cfg.get("enabled", False) # individual sequence needs to be splitted to CP domain, and to TP domain when SP is enabled. - pad_individual_seqs_to_multiple_of = 1 + minimum_pad_factor = 1 if cp_size > 1: - pad_individual_seqs_to_multiple_of *= cp_size * 2 + minimum_pad_factor *= cp_size * 2 if tp_size > 1 and sp: - pad_individual_seqs_to_multiple_of *= tp_size + minimum_pad_factor *= tp_size + assert pad_individual_seqs_to_multiple_of % minimum_pad_factor == 0, ( + f"make_sequence_length_divisible_by ({pad_individual_seqs_to_multiple_of}) is not a multiple of minimum_pad_factor ({minimum_pad_factor}). " + "Please set policy.make_sequence_length_divisible_by to a multiple of the minimum_pad_factor." + ) # packed sequence length, after splitted to TP and CP domains, needs to be divisible by 128 if using blockwise FP8, and divisible by 16 if using other FP8 recipes. if use_fp8: From 3cc9cfa048b13e2d46bd128df481a9fc58657d94 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 4 Mar 2026 08:11:46 -0800 Subject: [PATCH 13/20] fix unit test Signed-off-by: Yuki Huang --- tests/unit/models/megatron/test_megatron_setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/models/megatron/test_megatron_setup.py b/tests/unit/models/megatron/test_megatron_setup.py index 7b2a5d3622..793faffc44 100644 --- a/tests/unit/models/megatron/test_megatron_setup.py +++ b/tests/unit/models/megatron/test_megatron_setup.py @@ -751,7 +751,6 @@ def test_reward_model_not_supported(self): hf_model_name="test-model", pretrained_path="/path/to/model", weights_path=None, - tokenizer=MagicMock(), ) assert "Reward models are not yet supported" in str(exc_info.value) @@ -794,7 +793,6 @@ def test_generation_colocation_detection(self): hf_model_name="test-model", pretrained_path="/path/to/model", weights_path=None, - tokenizer=MagicMock(), ) assert runtime_config.is_generation_colocated is True @@ -815,6 +813,7 @@ def test_runtime_config_fields(self): optimizer_cpu_offload=False, offload_optimizer_for_logprob=True, is_generation_colocated=True, + sampling_params=None, final_padded_vocab_size=32000, ) @@ -822,6 +821,7 @@ def test_runtime_config_fields(self): assert runtime_config.optimizer_cpu_offload is False assert runtime_config.offload_optimizer_for_logprob is True assert runtime_config.is_generation_colocated is True + assert runtime_config.sampling_params is None assert runtime_config.final_padded_vocab_size == 32000 From 120594ab416503f6b647ee9ba647edfd8c9b087b Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Thu, 5 Mar 2026 19:40:27 -0800 Subject: [PATCH 14/20] limit nightly test threshold Signed-off-by: Yuki Huang --- ...ama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50.sh b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50.sh index ce065f181b..d50118b9fc 100755 --- a/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50.sh +++ b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.8-topp0.9-topk50.sh @@ -7,7 +7,7 @@ NUM_NODES=1 STEPS_PER_RUN=500 MAX_STEPS=500 NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up -NUM_MINUTES=180 +NUM_MINUTES=150 # ===== END CONFIG ===== exit_if_max_steps_reached @@ -37,7 +37,7 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma 'median(data["train/token_mult_prob_error"]) < 1.1' \ 'data["train/token_mult_prob_error"]["500"] < 1.1' \ 'data["train/reward"]["500"] > 0.1' \ - 'mean(data["timing/train/total_step_time"], -6, -1) < 20' + 'mean(data["timing/train/total_step_time"], -6, -1) < 15' # Clean up checkpoint directory after successful run to save space. rm -rf "$CKPT_DIR" From 2283932b60f281d5c033809177f5117eeee40c6f Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Thu, 5 Mar 2026 19:58:37 -0800 Subject: [PATCH 15/20] remove default value and update top_k check Signed-off-by: Yuki Huang --- nemo_rl/algorithms/logits_sampling_utils.py | 2 +- nemo_rl/models/automodel/setup.py | 6 +++--- nemo_rl/models/megatron/setup.py | 6 +++--- nemo_rl/models/policy/workers/dtensor_policy_worker.py | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/nemo_rl/algorithms/logits_sampling_utils.py b/nemo_rl/algorithms/logits_sampling_utils.py index db6fd43598..d904515934 100644 --- a/nemo_rl/algorithms/logits_sampling_utils.py +++ b/nemo_rl/algorithms/logits_sampling_utils.py @@ -46,7 +46,7 @@ class TrainingSamplingParams: def _need_top_k_filtering(top_k: int | None) -> bool: """Check if top-k filtering is needed.""" - return top_k is not None and top_k != -1 + return top_k is not None and top_k > 0 def _need_top_p_filtering(top_p: float | None) -> bool: diff --git a/nemo_rl/models/automodel/setup.py b/nemo_rl/models/automodel/setup.py index f2c36f7e97..ad90c78c67 100644 --- a/nemo_rl/models/automodel/setup.py +++ b/nemo_rl/models/automodel/setup.py @@ -80,9 +80,9 @@ def validate_and_prepare_config( is_generation_colocated = generation_cfg["colocated"]["enabled"] # set sampling params sampling_params = TrainingSamplingParams( - top_k=generation_cfg.get("top_k", None), - top_p=generation_cfg.get("top_p", 1.0), - temperature=generation_cfg.get("temperature", 1.0), + top_k=generation_cfg["top_k"], + top_p=generation_cfg["top_p"], + temperature=generation_cfg["temperature"], ) # Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator. diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py index 0b584cab90..2e8c2c3bfe 100644 --- a/nemo_rl/models/megatron/setup.py +++ b/nemo_rl/models/megatron/setup.py @@ -205,9 +205,9 @@ def validate_and_set_config( is_generation_colocated = generation_cfg["colocated"]["enabled"] # set sampling params sampling_params = TrainingSamplingParams( - top_k=generation_cfg.get("top_k", None), - top_p=generation_cfg.get("top_p", 1.0), - temperature=generation_cfg.get("temperature", 1.0), + top_k=generation_cfg["top_k"], + top_p=generation_cfg["top_p"], + temperature=generation_cfg["temperature"], ) # Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator. diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index 7234ec2913..a50bda53df 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -182,9 +182,9 @@ def __init__( self.is_generation_colocated = generation_cfg["colocated"]["enabled"] # set sampling params self.sampling_params = TrainingSamplingParams( - top_k=generation_cfg.get("top_k", None), - top_p=generation_cfg.get("top_p", 1.0), - temperature=generation_cfg.get("temperature", 1.0), + top_k=generation_cfg["top_k"], + top_p=generation_cfg["top_p"], + temperature=generation_cfg["temperature"], ) # Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator. From 793088f6aae9de8fb638f46cffc0e7267ef61ad8 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Thu, 5 Mar 2026 22:14:35 -0800 Subject: [PATCH 16/20] address comments Signed-off-by: Yuki Huang --- nemo_rl/algorithms/loss/loss_functions.py | 4 +- nemo_rl/algorithms/utils.py | 1 + nemo_rl/models/megatron/data.py | 7 +- .../policy/workers/dtensor_policy_worker.py | 77 ++++++++------- .../workers/dtensor_policy_worker_v2.py | 66 +++++++------ .../policy/workers/megatron_policy_worker.py | 94 +++++++++---------- 6 files changed, 126 insertions(+), 123 deletions(-) diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index 7244d868c7..f25e1a22b2 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -287,13 +287,13 @@ def __call__( if self.use_on_policy_kl_approximation: # See: docs/guides/grpo.md#on-policy-kl-approximation kl_importance_weights = torch.exp( - curr_logprobs_unfiltered - generation_logprobs + curr_logprobs - generation_logprobs ).detach() kl_importance_weights = torch.nan_to_num( kl_importance_weights, nan=0.0, posinf=0.0, neginf=0.0 ) else: - kl_importance_weights = torch.ones_like(curr_logprobs_unfiltered) + kl_importance_weights = torch.ones_like(curr_logprobs) # Compute KL loss kl = ( diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index c48d3b7937..17b33935e9 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -196,6 +196,7 @@ def mask_out_neg_inf_logprobs( Args: logprobs: Log probabilities. mask: Mask. + logprobs_name: Name of the logprobs tensor. Used for printing warning messages. Returns: Masked log probabilities. diff --git a/nemo_rl/models/megatron/data.py b/nemo_rl/models/megatron/data.py index 9a68b524c2..fd68ce2fbc 100644 --- a/nemo_rl/models/megatron/data.py +++ b/nemo_rl/models/megatron/data.py @@ -560,8 +560,11 @@ def _get_pack_sequence_parameters_for_megatron( if tp_size > 1 and sp: minimum_pad_factor *= tp_size assert pad_individual_seqs_to_multiple_of % minimum_pad_factor == 0, ( - f"make_sequence_length_divisible_by ({pad_individual_seqs_to_multiple_of}) is not a multiple of minimum_pad_factor ({minimum_pad_factor}). " - "Please set policy.make_sequence_length_divisible_by to a multiple of the minimum_pad_factor." + f"make_sequence_length_divisible_by ({pad_individual_seqs_to_multiple_of}) is not a multiple of minimum_pad_factor ({minimum_pad_factor}).\n" + f"Please set policy.make_sequence_length_divisible_by to a multiple of {minimum_pad_factor}.\n" + f" - If CP is enabled, the minimum pad factor is `cp_size * 2`.\n" + f" - If TP+SP is enabled, the minimum pad factor is `tp_size`.\n" + f" - If both are enabled, the minimum pad factor is `cp_size * 2 * tp_size`." ) # packed sequence length, after splitted to TP and CP domains, needs to be divisible by 128 if using blockwise FP8, and divisible by 16 if using other FP8 recipes. diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index a50bda53df..1e6752ac41 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -19,6 +19,7 @@ import warnings from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext +from functools import partial from typing import Any, Generator, Iterable, Optional, Set, Union, cast import ray @@ -801,11 +802,15 @@ def train( placements=[Shard(sequence_dim), Shard(-1)], ) + # Wrap prepare_loss_input with sampling_params + prepare_loss_input_wrapped = partial( + prepare_loss_input, sampling_params=self.sampling_params + ) # Wrap loss function for sequence packing if needed if self.enable_seq_packing: loss_fn_ = SequencePackingLossWrapper( loss_fn=loss_fn, - prepare_fn=prepare_loss_input, + prepare_fn=prepare_loss_input_wrapped, cu_seqlens_q=flash_attn_kwargs.cu_seqlens_q, cu_seqlens_q_padded=flash_attn_kwargs.cu_seqlens_q, ) @@ -816,7 +821,9 @@ def train( global_valid_toks, ) else: - loss_input, mb = prepare_loss_input(logits, mb, loss_fn) + loss_input, mb = prepare_loss_input_wrapped( + logits, mb, loss_fn + ) loss, loss_metrics = loss_fn( data=mb, global_valid_seqs=global_valid_seqs, @@ -1662,44 +1669,42 @@ def use_reference_model(self) -> Generator[None, None, None]: On exit: Restores original references and re-flips cuda/cpu, restores sampling_params. """ with torch.no_grad(): - try: - # Save train model state_dict - curr_state_dict = get_cpu_state_dict( - self.model.state_dict().items(), pin_memory=True - ) + # Save train model state_dict + curr_state_dict = get_cpu_state_dict( + self.model.state_dict().items(), pin_memory=True + ) - # Swap reference model state_dict to self.model - for k, v in self.model.state_dict().items(): - val = to_local_if_dtensor(v) - val.copy_(self.reference_model_state_dict[k]) - - # Temporarily disable top-k/top-p filtering for reference policy logprobs. - # The reference policy has different weights, so its top-k/top-p set is - # inherently different from the current policy. Using filtered logprobs - # would cause -inf mismatches that cannot be resolved by masking. - # Note: We keep temperature scaling since it was applied to prev_logprobs. - saved_sampling_params = self.sampling_params - if saved_sampling_params is not None: - self.sampling_params = TrainingSamplingParams( - top_k=None, # Disable top-k - top_p=1.0, # Disable top-p - temperature=saved_sampling_params.temperature, # Keep temperature - ) - else: - self.sampling_params = None + # Swap reference model state_dict to self.model + for k, v in self.model.state_dict().items(): + val = to_local_if_dtensor(v) + val.copy_(self.reference_model_state_dict[k]) + + # Temporarily disable top-k/top-p filtering for reference policy logprobs. + # The reference policy has different weights, so its top-k/top-p set is + # inherently different from the current policy. Using filtered logprobs + # would cause -inf mismatches that cannot be resolved by masking. + # Note: We keep temperature scaling since it was applied to prev_logprobs. + saved_sampling_params = self.sampling_params + if saved_sampling_params is not None: + self.sampling_params = TrainingSamplingParams( + top_k=None, # Disable top-k + top_p=1.0, # Disable top-p + temperature=saved_sampling_params.temperature, # Keep temperature + ) + else: + self.sampling_params = None - # - self.model is the original reference_model, now on CUDA - # - curr_state_dict is the train model, now on CPU - yield + # - self.model is the original reference_model, now on CUDA + # - curr_state_dict is the train model, now on CPU + yield - finally: - # Restore sampling_params - self.sampling_params = saved_sampling_params + # Restore sampling_params + self.sampling_params = saved_sampling_params - # Restore train model state_dict - for k, v in self.model.state_dict().items(): - val = to_local_if_dtensor(v) - val.copy_(curr_state_dict[k]) + # Restore train model state_dict + for k, v in self.model.state_dict().items(): + val = to_local_if_dtensor(v) + val.copy_(curr_state_dict[k]) def _add_noise_to_weights(self) -> None: """Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only.""" diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 122598e64c..86cd83993c 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -787,44 +787,42 @@ def use_reference_model(self) -> Generator[None, None, None]: On exit: Restores original references and re-flips cuda/cpu, restores sampling_params. """ with torch.no_grad(): - try: - # Save train model state_dict - curr_state_dict = get_cpu_state_dict( - self.model.state_dict().items(), pin_memory=True - ) + # Save train model state_dict + curr_state_dict = get_cpu_state_dict( + self.model.state_dict().items(), pin_memory=True + ) - # Swap reference model state_dict to self.model - for k, v in self.model.state_dict().items(): - val = to_local_if_dtensor(v) - val.copy_(self.reference_model_state_dict[k]) - - # Temporarily disable top-k/top-p filtering for reference policy logprobs. - # The reference policy has different weights, so its top-k/top-p set is - # inherently different from the current policy. Using filtered logprobs - # would cause -inf mismatches that cannot be resolved by masking. - # Note: We keep temperature scaling since it was applied to prev_logprobs. - saved_sampling_params = self.sampling_params - if saved_sampling_params is not None: - self.sampling_params = TrainingSamplingParams( - top_k=None, - top_p=1.0, - temperature=saved_sampling_params.temperature, - ) - else: - self.sampling_params = None + # Swap reference model state_dict to self.model + for k, v in self.model.state_dict().items(): + val = to_local_if_dtensor(v) + val.copy_(self.reference_model_state_dict[k]) + + # Temporarily disable top-k/top-p filtering for reference policy logprobs. + # The reference policy has different weights, so its top-k/top-p set is + # inherently different from the current policy. Using filtered logprobs + # would cause -inf mismatches that cannot be resolved by masking. + # Note: We keep temperature scaling since it was applied to prev_logprobs. + saved_sampling_params = self.sampling_params + if saved_sampling_params is not None: + self.sampling_params = TrainingSamplingParams( + top_k=None, + top_p=1.0, + temperature=saved_sampling_params.temperature, + ) + else: + self.sampling_params = None - # - self.model is the original reference_model, now on CUDA - # - curr_state_dict is the train model, now on CPU - yield + # - self.model is the original reference_model, now on CUDA + # - curr_state_dict is the train model, now on CPU + yield - finally: - # Restore sampling_params - self.sampling_params = saved_sampling_params + # Restore sampling_params + self.sampling_params = saved_sampling_params - # Restore train model state_dict - for k, v in self.model.state_dict().items(): - val = to_local_if_dtensor(v) - val.copy_(curr_state_dict[k]) + # Restore train model state_dict + for k, v in self.model.state_dict().items(): + val = to_local_if_dtensor(v) + val.copy_(curr_state_dict[k]) def _add_noise_to_weights(self) -> None: """Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only.""" diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 073b6bd9de..a10fd8ce6e 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -540,60 +540,56 @@ def use_reference_model(self): self.disable_forward_pre_hook() with torch.no_grad(): - try: - # Save original references - model_state_dict = {} - for name, item in self.model.state_dict().items(): - if isinstance(item, torch.Tensor): - item = item.detach().to( - device="cpu", non_blocking=True, copy=True - ) - model_state_dict[name] = item - - # Swap reference model state_dict to self.model - for k, v in self.model.state_dict().items(): - if isinstance(v, torch.Tensor): - v.copy_(self.reference_state_dict[k]) - - if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1: - gc.collect() - torch.cuda.empty_cache() - - # Temporarily disable top-k/top-p filtering for reference policy logprobs. - # The reference policy has different weights, so its top-k/top-p set is - # inherently different from the current policy. Using filtered logprobs - # would cause -inf mismatches that cannot be resolved by masking. - # Note: We keep temperature scaling since it was applied to prev_logprobs. - saved_sampling_params = self.sampling_params - if saved_sampling_params is not None: - self.sampling_params = TrainingSamplingParams( - top_k=None, - top_p=1.0, - temperature=saved_sampling_params.temperature, - ) - else: - self.sampling_params = None + # Save original references + model_state_dict = {} + for name, item in self.model.state_dict().items(): + if isinstance(item, torch.Tensor): + item = item.detach().to(device="cpu", non_blocking=True, copy=True) + model_state_dict[name] = item + + # Swap reference model state_dict to self.model + for k, v in self.model.state_dict().items(): + if isinstance(v, torch.Tensor): + v.copy_(self.reference_state_dict[k]) + + if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1: + gc.collect() + torch.cuda.empty_cache() + + # Temporarily disable top-k/top-p filtering for reference policy logprobs. + # The reference policy has different weights, so its top-k/top-p set is + # inherently different from the current policy. Using filtered logprobs + # would cause -inf mismatches that cannot be resolved by masking. + # Note: We keep temperature scaling since it was applied to prev_logprobs. + saved_sampling_params = self.sampling_params + if saved_sampling_params is not None: + self.sampling_params = TrainingSamplingParams( + top_k=None, + top_p=1.0, + temperature=saved_sampling_params.temperature, + ) + else: + self.sampling_params = None - # - self.model is the original reference_model, now on CUDA - # - self.reference_model is the original model, now on CPU - yield + # - self.model is the original reference_model, now on CUDA + # - self.reference_model is the original model, now on CPU + yield - finally: - # Restore sampling_params - self.sampling_params = saved_sampling_params + # Restore sampling_params + self.sampling_params = saved_sampling_params - # Restore original references and device placement - for k, v in self.model.state_dict().items(): - if isinstance(v, torch.Tensor): - v.copy_(model_state_dict[k]) + # Restore original references and device placement + for k, v in self.model.state_dict().items(): + if isinstance(v, torch.Tensor): + v.copy_(model_state_dict[k]) - if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1: - gc.collect() - torch.cuda.empty_cache() + if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1: + gc.collect() + torch.cuda.empty_cache() - ## re-enable overlap param gather after weight swap - if self.should_disable_forward_pre_hook: - self.enable_forward_pre_hook() + ## re-enable overlap param gather after weight swap + if self.should_disable_forward_pre_hook: + self.enable_forward_pre_hook() @wrap_with_nvtx_name("megatron_policy_worker/get_topk_logits") def get_topk_logits( From 77e42338dd7346efe0ac38e400ca3749978960bb Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sun, 8 Mar 2026 23:42:01 -0700 Subject: [PATCH 17/20] add doc Signed-off-by: Yuki Huang --- docs/guides/grpo.md | 18 ++++++++++++++++++ nemo_rl/algorithms/loss/loss_functions.py | 6 +++--- .../models/generation/vllm/vllm_generation.py | 6 +++--- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 2e7b410de9..88d1a01cad 100755 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -338,6 +338,12 @@ We support vLLM through the [VllmGeneration](../../nemo_rl/models/generation/vll The function, [grpo_train](../../nemo_rl/algorithms/grpo.py), contains the core GRPO training loop. +### Generation Sampling Parameters (temperature, top-p, top-k) + +GRPO uses temperature, top-p (nucleus sampling), and top-k sampling during rollout generation via vLLM; these settings are aligned with the training. For a detailed description of top-p and top-k filtering, see [Top-p and top-k filtering](#top-p-and-top-k-filtering) below. + +**Known issue (Qwen models):** For some Qwen-based models, a `ValueError: Token id 151708 is out of vocabulary` error may occur when the policy drifts from its initial distribution. Setting `top_p` to `0.9999` in the generation config is a recommended workaround. For details and discussion, see [#237](https://github.com/NVIDIA-NeMo/RL/issues/237). + ## Performance Optimizations RL generations typically produce highly variable sequence lengths, which result in a significant amount of padding if approached naively. We address this with Sequence Packing and Dynamic Batching, which are techniques to reduce the amount of padding required. You can read more about these in the [design doc](../design-docs/sequence-packing-and-dynamic-batching.md). @@ -450,6 +456,18 @@ grpo: Set `overlong_filtering` to true when training on tasks where truncation at the maximum sequence length is expected, such as long-form reasoning or mathematical proofs. +#### Top-p and top-k filtering + +The implementation aligns with vLLM’s top-p and top-k filtering by applying an equivalent process to the logits. + +When top-p or top-k filtering is enabled, the following conventions apply: + +- **`curr_logprobs` and `prev_logprobs`** are computed *with* filtering applied, for compatibility with the actor loss. +- **`reference_policy_logprobs`** is computed *without* filtering (see the `use_reference_model` in the policy worker). +- **KL divergence** uses `curr_logprobs_unfiltered`(`curr_logprobs` *without* filtering) so that it is consistent with the reference policy logprobs. + +Under tensor parallelism (TP), enabling top-p or top-k adds communication overhead. The vocabulary is sharded across GPUs (vocab-parallel), while top-p and top-k require full-vocabulary probabilities. A naive all-gather of logits would require large additional memory. The implementation therefore switches to a batch–sequence-parallel layout via all-to-all communication, applies filtering over the full vocabulary, then switches back, avoiding materialization of the full vocabulary on any single rank. + ## Metrics This feature is controlled by the parameters `wandb_name` and `tb_name`. We track a few metrics during training for scientific experimentation and to validate correctness as the run progresses. diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index f25e1a22b2..c72269eee1 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -280,20 +280,20 @@ def __call__( # - reference_policy_logprobs is computed **without** filtering (see use_reference_model) # - curr_logprobs/prev_logprobs are computed **with** filtering (for actor loss compatibility) # - For KL, we need curr_logprobs **without** filtering to be consistent with ref logprobs - # - For importance weights, we also use unfiltered curr_logprobs_for_kl since we're + # - For importance weights, we also use unfiltered curr_logprobs_unfiltered since we're # reweighting samples from π_gen_filtered to π_curr_unfiltered # On-policy KL approximation if self.use_on_policy_kl_approximation: # See: docs/guides/grpo.md#on-policy-kl-approximation kl_importance_weights = torch.exp( - curr_logprobs - generation_logprobs + curr_logprobs_unfiltered - generation_logprobs ).detach() kl_importance_weights = torch.nan_to_num( kl_importance_weights, nan=0.0, posinf=0.0, neginf=0.0 ) else: - kl_importance_weights = torch.ones_like(curr_logprobs) + kl_importance_weights = torch.ones_like(curr_logprobs_unfiltered) # Compute KL loss kl = ( diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index dfee1190c3..0faaad17a1 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -86,14 +86,14 @@ def __init__( ) # Validate sampling parameters early to avoid resource allocation with unsupported configs. - top_k: int | None = self.cfg.get("top_k", None) + top_k: int | None = self.cfg["top_k"] if top_k is not None and top_k != -1 and top_k < 1: raise ValueError( f"top_k valid values: i) None or -1: no filtering. ii) >= 1: top-k filtering. Got top_k={top_k}." ) - top_p: float = self.cfg.get("top_p", 1.0) - if top_p <= 0: + top_p: float = self.cfg["top_p"] + if top_p <= 0 or top_p > 1.0: raise ValueError( f"top_p valid values: i) 1.0: no filtering. ii) (0, 1]: top-p filtering. Got top_p={top_p}." ) From 2a7202a5469f4a668f9ba613af6815d8b5e73033 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 9 Mar 2026 01:32:53 -0700 Subject: [PATCH 18/20] add functional test Signed-off-by: Yuki Huang --- tests/functional/L1_Functional_Tests_GPU.sh | 1 + tests/functional/grpo_topp_topk.sh | 49 +++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100755 tests/functional/grpo_topp_topk.sh diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index ac99d767e5..08dad5d885 100644 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -59,6 +59,7 @@ run_test uv run --no-sync bash ./tests/functional/grpo_multiturn.sh run_test uv run --no-sync bash ./tests/functional/grpo_non_colocated.sh run_test uv run --no-sync bash ./tests/functional/grpo_rm_env.sh run_test uv run --no-sync bash ./tests/functional/grpo_sglang.sh +run_test fast uv run --no-sync bash ./tests/functional/grpo_topp_topk.sh run_test uv run --no-sync bash ./tests/functional/prorlv2.sh run_test uv run --no-sync bash ./tests/functional/rm.sh run_test fast uv run --no-sync bash ./tests/functional/sft.sh diff --git a/tests/functional/grpo_topp_topk.sh b/tests/functional/grpo_topp_topk.sh new file mode 100755 index 0000000000..8bb635f07c --- /dev/null +++ b/tests/functional/grpo_topp_topk.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo.py \ + policy.model_name=Qwen/Qwen3-0.6B \ + grpo.num_prompts_per_step=2 \ + grpo.num_generations_per_prompt=4 \ + policy.train_global_batch_size=4 \ + policy.train_micro_batch_size=1 \ + policy.generation.temperature=0.8 \ + policy.generation.top_p=0.9 \ + policy.generation.top_k=50 \ + cluster.gpus_per_node=2 \ + grpo.max_num_steps=2 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'max(data["train/token_mult_prob_error"]) < 1.05' \ + 'max(data["train/gen_kl_error"]) < 0.03' \ + 'min(data["train/probs_ratio_clamped_min"]) > 0.79' \ + 'max(data["train/probs_ratio_clamped_min"]) < 1.21' \ + 'min(data["train/probs_ratio_clamped_max"]) > 0.79' \ + 'max(data["train/probs_ratio_clamped_max"]) < 1.21' From 4b020fce4088a42b5c2b783791bf1508f55b01b1 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sat, 7 Mar 2026 22:48:51 -0800 Subject: [PATCH 19/20] fix unit test Signed-off-by: Yuki Huang --- tests/unit/algorithms/test_distillation.py | 12 + tests/unit/algorithms/test_grpo.py | 18 ++ tests/unit/algorithms/test_utils.py | 5 +- .../models/automodel/test_automodel_setup.py | 15 +- .../models/automodel/test_automodel_train.py | 2 +- .../models/megatron/megatron_data_actors.py | 261 +++++++----------- .../models/megatron/test_megatron_data.py | 22 +- .../models/megatron/test_megatron_setup.py | 3 + tests/unit/models/megatron/test_train.py | 6 +- .../unit/models/policy/test_dtensor_worker.py | 2 +- .../models/policy/test_dtensor_worker_v2.py | 2 +- .../models/policy/test_megatron_worker.py | 6 +- .../models/policy/test_policy_validation.py | 4 +- tests/unit/utils/test_native_checkpoint.py | 4 +- tools/refit_verifier.py | 4 +- 15 files changed, 184 insertions(+), 182 deletions(-) diff --git a/tests/unit/algorithms/test_distillation.py b/tests/unit/algorithms/test_distillation.py index 1bcbe2e2bd..9fd94f23a2 100644 --- a/tests/unit/algorithms/test_distillation.py +++ b/tests/unit/algorithms/test_distillation.py @@ -141,6 +141,9 @@ def val_iter(self): "make_sequence_length_divisible_by": 8, "max_total_sequence_length": 2048, "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, "colocated": { "enabled": False, }, @@ -420,6 +423,9 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_single_node(): master_config = { "policy": { "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, "backend": "vllm", "colocated": { "enabled": False, # Non-colocated @@ -485,6 +491,9 @@ def test_distillation_setup_non_colocated_smoke(monkeypatch): master_config = { "policy": { "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, "backend": "vllm", "colocated": { "enabled": False, @@ -600,6 +609,9 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_multi_node(): master_config = { "policy": { "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, "backend": "vllm", "colocated": { "enabled": False, # Non-colocated diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index 7a0783f132..32ca0be5c6 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -719,6 +719,9 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_single_node(): master_config = { "policy": { "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, "backend": "vllm", "colocated": { "enabled": False, # Non-colocated @@ -798,6 +801,9 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_multi_node(): master_config = { "policy": { "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, "backend": "vllm", "colocated": { "enabled": False, # Non-colocated @@ -966,6 +972,9 @@ def init_collective(self, *_args, **_kwargs): "dtensor_cfg": {"enabled": False}, "megatron_cfg": {"enabled": False, "pipeline_model_parallel_size": 1}, "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, "backend": "sglang", "colocated": { "enabled": colocated_inference, @@ -1338,6 +1347,9 @@ def val_iter(self): "max_total_sequence_length": 2048, "make_sequence_length_divisible_by": 1, "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, "backend": "vllm", "colocated": {"enabled": True}, "vllm_cfg": {"async_engine": True}, # Support async mode @@ -1902,6 +1914,9 @@ def capture_log(data, filename): "policy": { "max_total_sequence_length": 2048, "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, "backend": "vllm", "colocated": {"enabled": True}, "vllm_cfg": {"async_engine": False}, @@ -1995,6 +2010,9 @@ def test_validate_works_without_logger(self): "policy": { "max_total_sequence_length": 2048, "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, "backend": "vllm", "colocated": {"enabled": True}, "vllm_cfg": {"async_engine": False}, diff --git a/tests/unit/algorithms/test_utils.py b/tests/unit/algorithms/test_utils.py index edc8d0a812..4560faabd0 100755 --- a/tests/unit/algorithms/test_utils.py +++ b/tests/unit/algorithms/test_utils.py @@ -228,10 +228,13 @@ def _base_master_config(colocated: bool): "cluster": {"num_nodes": 2, "gpus_per_node": 8}, "policy": { "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, "colocated": { "enabled": colocated, "resources": {"num_nodes": 1, "gpus_per_node": 8}, - } + }, } }, "grpo": {"num_prompts_per_step": 8, "num_generations_per_prompt": 10}, diff --git a/tests/unit/models/automodel/test_automodel_setup.py b/tests/unit/models/automodel/test_automodel_setup.py index 619a71dc1b..8a5864b80c 100644 --- a/tests/unit/models/automodel/test_automodel_setup.py +++ b/tests/unit/models/automodel/test_automodel_setup.py @@ -56,7 +56,12 @@ def mock_config(): "use_hf_tp_plan": False, "activation_checkpointing": False, }, - "generation": None, + "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, + "colocated": {"enabled": True}, + }, "hf_config_overrides": {}, "optimizer": { "name": "torch.optim.AdamW", @@ -325,7 +330,7 @@ def test_generation_colocated( mock_resolve_class.return_value = Mock # Test with generation colocated enabled - mock_config["generation"] = {"colocated": {"enabled": True}} + mock_config["generation"]["colocated"]["enabled"] = True result = validate_and_prepare_config(mock_config, None, 0) assert result.is_generation_colocated is True # NCCL_CUMEM_ENABLE should not be set when colocated @@ -348,7 +353,7 @@ def test_generation_not_colocated( mock_resolve_class.return_value = Mock # Test with generation colocated disabled - mock_config["generation"] = {"colocated": {"enabled": False}} + mock_config["generation"]["colocated"]["enabled"] = False result = validate_and_prepare_config(mock_config, None, 0) assert result.is_generation_colocated is False # NCCL_CUMEM_ENABLE should be set when not colocated @@ -605,6 +610,7 @@ def mock_runtime_config(self): cpu_offload=False, offload_optimizer_for_logprob=False, is_generation_colocated=None, + sampling_params=None, is_reward_model=False, ) @@ -645,6 +651,7 @@ def test_setup_distributed_with_cpu_offload( cpu_offload=True, # CPU offload enabled offload_optimizer_for_logprob=False, is_generation_colocated=None, + sampling_params=None, is_reward_model=False, ) @@ -712,6 +719,7 @@ def mock_runtime_config(self, mock_autoconfig): cpu_offload=False, offload_optimizer_for_logprob=False, is_generation_colocated=None, + sampling_params=None, is_reward_model=False, ) @@ -1469,6 +1477,7 @@ def test_setup_model_with_cpu_offload( cpu_offload=True, # CPU offload enabled offload_optimizer_for_logprob=False, is_generation_colocated=None, + sampling_params=None, is_reward_model=False, ) diff --git a/tests/unit/models/automodel/test_automodel_train.py b/tests/unit/models/automodel/test_automodel_train.py index d04077cdaa..be246d41c5 100644 --- a/tests/unit/models/automodel/test_automodel_train.py +++ b/tests/unit/models/automodel/test_automodel_train.py @@ -96,7 +96,7 @@ def base_cfg(): return { "dtensor_cfg": {"sequence_parallel": False}, "sequence_packing": {"train_mb_tokens": 256}, - "generation": {"temperature": 1.0}, + "generation": {"temperature": 1.0, "top_p": 1.0, "top_k": None}, } diff --git a/tests/unit/models/megatron/megatron_data_actors.py b/tests/unit/models/megatron/megatron_data_actors.py index 0687cf076b..baa08fa0a0 100644 --- a/tests/unit/models/megatron/megatron_data_actors.py +++ b/tests/unit/models/megatron/megatron_data_actors.py @@ -574,122 +574,82 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self): _get_pack_sequence_parameters_for_megatron, ) - # Test 1: Basic configuration - no parallelism, no FP8 - megatron_cfg = { - "tensor_model_parallel_size": 1, - "sequence_parallel": False, - "pipeline_model_parallel_size": 1, - "context_parallel_size": 1, - } max_seq_len = 1023 - pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( - megatron_cfg, max_seq_len - ) - - if pad_individual != 1 or pad_packed != 1 or pad_to is not None: - return { - "success": False, - "error": f"Expected pad_individual=1, pad_packed=1, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", - } - - # Test 2: Context parallelism only - megatron_cfg = { - "tensor_model_parallel_size": 1, - "sequence_parallel": False, - "pipeline_model_parallel_size": 1, - "context_parallel_size": 4, - } - - pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( - megatron_cfg, max_seq_len - ) - - if pad_individual != 4 * 2 or pad_packed != 1 or pad_to is not None: - return { - "success": False, - "error": f"Expected pad_individual=4*2, pad_packed=1, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", - } - - # Test 3: Tensor parallelism with sequence parallelism - megatron_cfg = { - "tensor_model_parallel_size": 2, - "sequence_parallel": True, - "pipeline_model_parallel_size": 1, - "context_parallel_size": 1, - } - - pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( - megatron_cfg, max_seq_len - ) - - expected_individual = 2 # tp_size when SP is enabled - if pad_individual != 2 or pad_packed != 1 or pad_to is not None: - return { - "success": False, - "error": f"Expected pad_individual=2, pad_packed=1, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", + # test with different combinations of parallelism + for tp, sp, pp, cp, expected_individual, expected_packed in [ + [1, False, 1, 1, 1, 1], # no parallelism + [2, True, 1, 1, 2, 1], # tp + [2, False, 1, 1, 1, 1], # tp+sp + [1, False, 1, 4, 8, 1], # cp + [2, True, 1, 4, 16, 1], # cp+tp+sp + [1, False, 4, 1, 1, 1], # pp + ]: + megatron_cfg = { + "tensor_model_parallel_size": tp, + "sequence_parallel": sp, + "pipeline_model_parallel_size": pp, + "context_parallel_size": cp, } + pad_individual, pad_packed, pad_to = ( + _get_pack_sequence_parameters_for_megatron( + megatron_cfg, expected_individual, max_seq_len + ) + ) - # Test 4: Tensor parallelism without sequence parallelism - megatron_cfg = { - "tensor_model_parallel_size": 2, - "sequence_parallel": False, - "pipeline_model_parallel_size": 1, - "context_parallel_size": 1, - } - - pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( - megatron_cfg, max_seq_len - ) - - if pad_individual != 1 or pad_packed != 1 or pad_to is not None: - return { - "success": False, - "error": f"Expected pad_individual=1, pad_packed=1, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", + if pp > 1: + if ( + pad_individual != expected_individual + or pad_packed != expected_packed + or pad_to != max_seq_len + ): + return { + "success": False, + "error": f"Expected pad_individual={expected_individual}, pad_packed={expected_packed}, pad_to={max_seq_len}, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", + } + else: + if ( + pad_individual != expected_individual + or pad_packed != expected_packed + or pad_to is not None + ): + return { + "success": False, + "error": f"Expected pad_individual={expected_individual}, pad_packed={expected_packed}, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", + } + + # Edge case - different max_seq_len values with PP + for test_seq_len in [512, 2048, 4096]: + megatron_cfg = { + "tensor_model_parallel_size": 1, + "sequence_parallel": False, + "pipeline_model_parallel_size": 2, + "context_parallel_size": 1, } - # Test 5: Pipeline parallelism - megatron_cfg = { - "tensor_model_parallel_size": 1, - "sequence_parallel": False, - "pipeline_model_parallel_size": 4, - "context_parallel_size": 1, - } - - pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( - megatron_cfg, max_seq_len - ) + pad_individual, pad_packed, pad_to = ( + _get_pack_sequence_parameters_for_megatron( + megatron_cfg, 1, test_seq_len + ) + ) - if pad_individual != 1 or pad_packed != 1 or pad_to != max_seq_len: - return { - "success": False, - "error": f"Expected pad_individual=1, pad_packed=1, pad_to={max_seq_len}, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", - } + if pad_individual != 1 or pad_packed != 1 or pad_to != test_seq_len: + return { + "success": False, + "error": f"Expected pad_individual=1, pad_packed=1, pad_to={test_seq_len}, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", + } - # Test 6: Combined CP and TP with SP - megatron_cfg = { - "tensor_model_parallel_size": 2, - "sequence_parallel": True, - "pipeline_model_parallel_size": 1, - "context_parallel_size": 4, - } + return {"success": True, "error": None} - pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( - megatron_cfg, max_seq_len + def run_all_get_pack_sequence_parameters_for_megatron_fp8_tests(self): + """Test _get_pack_sequence_parameters_for_megatron function with various configurations with FP8 enabled.""" + from nemo_rl.models.megatron.data import ( + _get_pack_sequence_parameters_for_megatron, ) - expected_individual = 4 * 2 * 2 # cp_size * 2 * tp_size - if ( - pad_individual != expected_individual - or pad_packed != 1 - or pad_to is not None - ): - return { - "success": False, - "error": f"Expected pad_individual={expected_individual}, pad_packed=1, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", - } + max_seq_len = 1023 - # Test 7: FP8 enabled with default recipe + # Test 1: FP8 enabled with default recipe megatron_cfg = { "tensor_model_parallel_size": 1, "sequence_parallel": False, @@ -704,7 +664,7 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self): } pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( - megatron_cfg, max_seq_len + megatron_cfg, 1, max_seq_len ) if pad_individual != 1 or pad_packed != 16 or pad_to is not None: @@ -713,7 +673,7 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self): "error": f"Expected pad_individual=1, pad_packed=16, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", } - # Test 8: FP8 enabled with blockwise recipe + # Test 2: FP8 enabled with blockwise recipe megatron_cfg = { "tensor_model_parallel_size": 1, "sequence_parallel": False, @@ -728,7 +688,7 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self): } pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( - megatron_cfg, max_seq_len + megatron_cfg, 1, max_seq_len ) if pad_individual != 1 or pad_packed != 128 or pad_to is not None: @@ -737,7 +697,7 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self): "error": f"Expected pad_individual=1, pad_packed=128, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", } - # Test 9: FP8 with CP and TP+SP + # Test 3: FP8 with CP and TP+SP megatron_cfg = { "tensor_model_parallel_size": 2, "sequence_parallel": True, @@ -751,12 +711,13 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self): }, } + expected_individual = 4 * 2 * 2 # cp_size * 2 * tp_size + expected_packed = 128 * 4 * 2 * 2 # divisor * cp_size * 2 * tp_size + pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( - megatron_cfg, max_seq_len + megatron_cfg, expected_individual, max_seq_len ) - expected_individual = 4 * 2 * 2 # cp_size * 2 * tp_size - expected_packed = 128 * 4 * 2 * 2 # divisor * cp_size * 2 * tp_size if ( pad_individual != expected_individual or pad_packed != expected_packed @@ -767,7 +728,7 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self): "error": f"Expected pad_individual={expected_individual}, pad_packed={expected_packed}, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", } - # Test 10: All parallelism types with FP8 and PP + # Test 4: All parallelism types with FP8 and PP megatron_cfg = { "tensor_model_parallel_size": 2, "sequence_parallel": True, @@ -781,27 +742,28 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self): }, } - pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( - megatron_cfg, max_seq_len - ) + def _round_up_to_multiple_of(x, y): + return (x + y - 1) // y * y expected_individual = 2 * 2 * 2 # cp_size * 2 * tp_size expected_packed = 16 * 2 * 2 * 2 # divisor * cp_size * 2 * tp_size + expected_pad_to = _round_up_to_multiple_of(max_seq_len, expected_packed) - def _round_up_to_multiple_of(x, y): - return (x + y - 1) // y * y + pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( + megatron_cfg, expected_individual, max_seq_len + ) if ( pad_individual != expected_individual or pad_packed != expected_packed - or pad_to != _round_up_to_multiple_of(max_seq_len, expected_packed) + or pad_to != expected_pad_to ): return { "success": False, "error": f"Expected pad_individual={expected_individual}, pad_packed={expected_packed}, pad_to={max_seq_len}, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", } - # Test 11: FP8 disabled explicitly + # Test 5: FP8 disabled explicitly megatron_cfg = { "tensor_model_parallel_size": 1, "sequence_parallel": False, @@ -816,7 +778,7 @@ def _round_up_to_multiple_of(x, y): } pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( - megatron_cfg, max_seq_len + megatron_cfg, 1, max_seq_len ) if pad_individual != 1 or pad_packed != 1 or pad_to is not None: @@ -825,7 +787,7 @@ def _round_up_to_multiple_of(x, y): "error": f"Expected pad_individual=1, pad_packed=1, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", } - # Test 12: Missing fp8_cfg (should default to disabled) + # Test 6: Missing fp8_cfg (should default to disabled) megatron_cfg = { "tensor_model_parallel_size": 1, "sequence_parallel": False, @@ -835,7 +797,7 @@ def _round_up_to_multiple_of(x, y): } pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( - megatron_cfg, max_seq_len + megatron_cfg, 1, max_seq_len ) if pad_individual != 1 or pad_packed != 1 or pad_to is not None: @@ -844,7 +806,7 @@ def _round_up_to_multiple_of(x, y): "error": f"Expected pad_individual=1, pad_packed=1, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", } - # Test 13: Edge case - very large parallelism values + # Test 7: Edge case - very large parallelism values megatron_cfg = { "tensor_model_parallel_size": 8, "sequence_parallel": True, @@ -858,12 +820,13 @@ def _round_up_to_multiple_of(x, y): }, } + expected_individual = 8 * 2 * 8 # cp_size * 2 * tp_size = 128 + expected_packed = 128 * 8 * 2 * 8 # divisor * cp_size * 2 * tp_size = 16384 + pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( - megatron_cfg, max_seq_len + megatron_cfg, expected_individual, max_seq_len ) - expected_individual = 8 * 2 * 8 # cp_size * 2 * tp_size = 128 - expected_packed = 128 * 8 * 2 * 8 # divisor * cp_size * 2 * tp_size = 16384 if ( pad_individual != expected_individual or pad_packed != expected_packed @@ -874,26 +837,7 @@ def _round_up_to_multiple_of(x, y): "error": f"Expected pad_individual={expected_individual}, pad_packed={expected_packed}, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", } - # Test 14: Edge case - different max_seq_len values with PP - for test_seq_len in [512, 2048, 4096]: - megatron_cfg = { - "tensor_model_parallel_size": 1, - "sequence_parallel": False, - "pipeline_model_parallel_size": 2, - "context_parallel_size": 1, - } - - pad_individual, pad_packed, pad_to = ( - _get_pack_sequence_parameters_for_megatron(megatron_cfg, test_seq_len) - ) - - if pad_individual != 1 or pad_packed != 1 or pad_to != test_seq_len: - return { - "success": False, - "error": f"Expected pad_individual=1, pad_packed=1, pad_to={test_seq_len}, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", - } - - # Test 15: FP8 with MXFP8 recipe + # Test 8: FP8 with MXFP8 recipe megatron_cfg = { "tensor_model_parallel_size": 1, "sequence_parallel": False, @@ -908,7 +852,7 @@ def _round_up_to_multiple_of(x, y): } pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( - megatron_cfg, max_seq_len + megatron_cfg, 1, max_seq_len ) if pad_individual != 1 or pad_packed != 32 or pad_to is not None: @@ -917,7 +861,7 @@ def _round_up_to_multiple_of(x, y): "error": f"Expected pad_individual=1, pad_packed=32, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", } - # Test 16: FP8 with MXFP8 recipe, CP, and TP+SP + # Test 9: FP8 with MXFP8 recipe, CP, and TP+SP megatron_cfg = { "tensor_model_parallel_size": 2, "sequence_parallel": True, @@ -931,13 +875,13 @@ def _round_up_to_multiple_of(x, y): }, } - pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( - megatron_cfg, max_seq_len - ) - expected_individual = 4 * 2 * 2 # cp_size * 2 * tp_size expected_packed = 32 * 4 * 2 * 2 # divisor * cp_size * 2 * tp_size + pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( + megatron_cfg, expected_individual, max_seq_len + ) + if ( pad_individual != expected_individual or pad_packed != expected_packed @@ -948,7 +892,7 @@ def _round_up_to_multiple_of(x, y): "error": f"Expected pad_individual={expected_individual}, pad_packed={expected_packed}, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", } - # Test 17: FP8 with MXFP8 recipe, CP, TP+SP, and PP + # Test 10: FP8 with MXFP8 recipe, CP, TP+SP, and PP megatron_cfg = { "tensor_model_parallel_size": 2, "sequence_parallel": True, @@ -962,17 +906,18 @@ def _round_up_to_multiple_of(x, y): }, } - pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( - megatron_cfg, max_seq_len - ) - expected_individual = 4 * 2 * 2 # cp_size * 2 * tp_size expected_packed = 32 * 4 * 2 * 2 # divisor * cp_size * 2 * tp_size * pp_size + expected_pad_to = _round_up_to_multiple_of(max_seq_len, expected_packed) + + pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( + megatron_cfg, expected_individual, max_seq_len + ) if ( pad_individual != expected_individual or pad_packed != expected_packed - or pad_to != _round_up_to_multiple_of(max_seq_len, expected_packed) + or pad_to != expected_pad_to ): return { "success": False, diff --git a/tests/unit/models/megatron/test_megatron_data.py b/tests/unit/models/megatron/test_megatron_data.py index 3610b77d9f..14a5cd89c1 100644 --- a/tests/unit/models/megatron/test_megatron_data.py +++ b/tests/unit/models/megatron/test_megatron_data.py @@ -415,6 +415,7 @@ def test_get_microbatch_iterator_sequence_packing( "pipeline_model_parallel_size": 1, "context_parallel_size": 1, }, + "make_sequence_length_divisible_by": 1, } ( @@ -512,6 +513,7 @@ def test_get_microbatch_iterator_auto_detects_seq_length_key( "pipeline_model_parallel_size": 1, "context_parallel_size": 1, }, + "make_sequence_length_divisible_by": 1, } get_microbatch_iterator( @@ -864,12 +866,14 @@ def test_get_pack_sequence_parameters_for_megatron(get_pack_sequence_parameters_ """Comprehensive test of pack sequences functionality without context parallelism.""" worker_group = get_pack_sequence_parameters_setup - # Run all tests in a single call to the actor - futures = worker_group.run_all_workers_single_data( - "run_all_get_pack_sequence_parameters_for_megatron_tests" - ) - results = ray.get(futures) - - # Check that all workers succeeded - for i, result in enumerate(results): - assert result["success"], f"Worker {i} failed: {result['error']}" + for test_name in [ + "run_all_get_pack_sequence_parameters_for_megatron_tests", + "run_all_get_pack_sequence_parameters_for_megatron_fp8_tests", + ]: + # Run all tests in a single call to the actor + futures = worker_group.run_all_workers_single_data(test_name) + results = ray.get(futures) + + # Check that all workers succeeded + for i, result in enumerate(results): + assert result["success"], f"Worker {i} failed: {result['error']}" diff --git a/tests/unit/models/megatron/test_megatron_setup.py b/tests/unit/models/megatron/test_megatron_setup.py index 793faffc44..948ace54b7 100644 --- a/tests/unit/models/megatron/test_megatron_setup.py +++ b/tests/unit/models/megatron/test_megatron_setup.py @@ -763,6 +763,9 @@ def test_generation_colocation_detection(self): config = { "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, "colocated": {"enabled": True}, }, "precision": "bfloat16", diff --git a/tests/unit/models/megatron/test_train.py b/tests/unit/models/megatron/test_train.py index da3c9179f7..eccf41defb 100644 --- a/tests/unit/models/megatron/test_train.py +++ b/tests/unit/models/megatron/test_train.py @@ -376,7 +376,7 @@ def test_forward_applies_temperature_scaling_for_loss( cfg = { "sequence_packing": {"enabled": False}, - "generation": {"temperature": 0.7}, + "generation": {"temperature": 0.7, "top_p": 1.0, "top_k": None}, } post_processor = LossPostProcessor(loss_fn=MagicMock(), cfg=cfg) sampling_params = TrainingSamplingParams( @@ -420,7 +420,7 @@ def test_forward_applies_temperature_scaling_for_logprobs( cfg = { "sequence_packing": {"enabled": False}, - "generation": {"temperature": 0.5}, + "generation": {"temperature": 0.5, "top_p": 1.0, "top_k": None}, } post_processor = LogprobsPostProcessor(cfg=cfg) sampling_params = TrainingSamplingParams( @@ -465,7 +465,7 @@ def test_forward_applies_temperature_scaling_for_topk( cfg = { "sequence_packing": {"enabled": False}, "megatron_cfg": {"context_parallel_size": 1}, - "generation": {"temperature": 1.5}, + "generation": {"temperature": 1.5, "top_p": 1.0, "top_k": None}, } post_processor = TopkLogitsPostProcessor(cfg=cfg, k=5) sampling_params = TrainingSamplingParams( diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index 363c23ae09..cef068c7b7 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -54,9 +54,9 @@ def create_test_config( "generation": { "backend": "hf", "temperature": 1.0, - "max_new_tokens": 16, # Small number of tokens for testing "top_p": 1.0, "top_k": None, + "max_new_tokens": 16, # Small number of tokens for testing "stop_token_ids": None, "stop_strings": None, "colocated": { diff --git a/tests/unit/models/policy/test_dtensor_worker_v2.py b/tests/unit/models/policy/test_dtensor_worker_v2.py index 648cbbaa1b..aa5ec5c0c5 100644 --- a/tests/unit/models/policy/test_dtensor_worker_v2.py +++ b/tests/unit/models/policy/test_dtensor_worker_v2.py @@ -70,9 +70,9 @@ def create_test_config( "generation": { "backend": "hf", "temperature": 1.0, - "max_new_tokens": 16, # Small number of tokens for testing "top_p": 1.0, "top_k": None, + "max_new_tokens": 16, # Small number of tokens for testing "stop_token_ids": None, "stop_strings": None, "colocated": { diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index 1835c9908b..ffd2be0445 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -81,9 +81,9 @@ def create_megatron_test_config( "generation": { "backend": generation_backend, "temperature": 1.0, - "max_new_tokens": 32, # Small number of tokens for testing "top_p": 1.0, "top_k": None, + "max_new_tokens": 32, # Small number of tokens for testing "stop_token_ids": None, "stop_strings": None, "mcore_generation_config": { @@ -178,6 +178,7 @@ def create_megatron_test_config( "fp8_param": True, }, }, + "make_sequence_length_divisible_by": tp, "optimizer": None, # Remove default FSDP optimizer "scheduler": None, # Remove default scheduler "max_grad_norm": 1.0, @@ -1747,6 +1748,7 @@ def test_megatron_context_parallel_topk_agreement(tiny_qwen2_model_path): ) # Enable context parallel config_cp["megatron_cfg"]["context_parallel_size"] = 2 + config_cp["make_sequence_length_divisible_by"] *= 4 # Enable sequence packing config_cp["sequence_packing"] = { @@ -2004,6 +2006,7 @@ def test_megatron_context_parallel_logprob_agreement(tiny_llama_model_path): ) # Enable context parallel config_cp["megatron_cfg"]["context_parallel_size"] = 2 + config_cp["make_sequence_length_divisible_by"] *= 4 # Enable sequence packing config_cp["sequence_packing"] = { @@ -2198,6 +2201,7 @@ def test_megatron_context_parallel_training_agreement(tiny_llama_model_path): ) # Enable context parallel config_cp["megatron_cfg"]["context_parallel_size"] = 2 + config_cp["make_sequence_length_divisible_by"] *= 4 config_cp["train_global_batch_size"] = 2 # Enable sequence packing diff --git a/tests/unit/models/policy/test_policy_validation.py b/tests/unit/models/policy/test_policy_validation.py index e5aedeeb12..6e8fceb760 100644 --- a/tests/unit/models/policy/test_policy_validation.py +++ b/tests/unit/models/policy/test_policy_validation.py @@ -71,9 +71,9 @@ def create_dtensor_config( "generation": { "backend": "hf", "temperature": 1.0, - "max_new_tokens": 16, "top_p": 1.0, "top_k": None, + "max_new_tokens": 16, "stop_token_ids": None, "stop_strings": None, "colocated": { @@ -127,9 +127,9 @@ def create_megatron_config( "generation": { "backend": "hf", "temperature": 1.0, - "max_new_tokens": 16, "top_p": 1.0, "top_k": None, + "max_new_tokens": 16, "stop_token_ids": None, "stop_strings": None, "colocated": { diff --git a/tests/unit/utils/test_native_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py index f94da69c54..33240d9288 100755 --- a/tests/unit/utils/test_native_checkpoint.py +++ b/tests/unit/utils/test_native_checkpoint.py @@ -70,8 +70,10 @@ }, "max_grad_norm": 1.0, "generation": { - "backend": "vllm", "temperature": 1.0, + "top_p": 1.0, + "top_k": None, + "backend": "vllm", "colocated": {"enabled": True}, }, } diff --git a/tools/refit_verifier.py b/tools/refit_verifier.py index 30bdadb62d..7cd240f564 100644 --- a/tools/refit_verifier.py +++ b/tools/refit_verifier.py @@ -155,10 +155,12 @@ def setup_configs(args, tokenizer): "learning_rate": 0.0001, "logprob_batch_size": 1, "generation": { + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, "max_total_sequence_length": args.max_sequence_length, "max_new_tokens": args.max_sequence_length, "do_sample": False, - "temperature": 1.0, "pad_token_id": tokenizer.eos_token_id, "colocated": { "enabled": True, From 0e6e0f783e788f0523bc01a7d84cc83e8b056bf0 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 9 Mar 2026 18:27:09 -0700 Subject: [PATCH 20/20] update threshold Signed-off-by: Yuki Huang --- tests/functional/grpo_multiturn.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/functional/grpo_multiturn.sh b/tests/functional/grpo_multiturn.sh index 9af8a8dc97..cb1f6d6517 100755 --- a/tests/functional/grpo_multiturn.sh +++ b/tests/functional/grpo_multiturn.sh @@ -29,7 +29,7 @@ uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJE policy.max_total_sequence_length=1024 \ policy.train_global_batch_size=4 \ policy.train_micro_batch_size=1 \ - policy.generation.top_p=0.99 \ + policy.generation.top_p=0.9 \ policy.generation.top_k=8000 \ logger.tensorboard_enabled=true \ logger.log_dir=$LOG_DIR \ @@ -41,5 +41,4 @@ uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJE uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS uv run tests/check_metrics.py $JSON_METRICS \ - 'max(data["train/token_mult_prob_error"]) < 1.1' - + 'median(data["train/token_mult_prob_error"]) < 1.1'