diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.yaml new file mode 100644 index 0000000000..34d76a7421 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.yaml @@ -0,0 +1,37 @@ +defaults: ../../grpo_math_1B.yaml +grpo: + max_num_steps: 500 +checkpointing: + enabled: false + checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling + save_period: 100 +policy: + model_name: meta-llama/Llama-3.2-1B-Instruct + tokenizer: + name: meta-llama/Llama-3.2-1B-Instruct + optimizer: null + megatron_cfg: + enabled: true + scheduler: + lr_warmup_iters: 50 + dtensor_cfg: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + max_new_tokens: 512 + vllm_cfg: + max_model_len: 512 + temperature: 0.8 + top_p: 0.9 + top_k: 50 +data: + max_input_seq_length: 512 +logger: + log_dir: logs/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling + wandb_enabled: true + tensorboard_enabled: true + wandb: + project: nemo-rl + name: grpo-llama3.2-1b-instruct-1n8g-megatron-sampling +cluster: + gpus_per_node: 8 diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.yaml new file mode 100644 index 0000000000..82b5639f73 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.yaml @@ -0,0 +1,35 @@ +defaults: ../../grpo_math_1B.yaml +grpo: + max_num_steps: 500 +checkpointing: + enabled: false + checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6 + save_period: 100 +policy: + model_name: meta-llama/Llama-3.2-1B-Instruct + tokenizer: + name: meta-llama/Llama-3.2-1B-Instruct + optimizer: null + megatron_cfg: + enabled: true + scheduler: + lr_warmup_iters: 50 + dtensor_cfg: + enabled: false + make_sequence_length_divisible_by: 1 + generation: + max_new_tokens: 512 + 
vllm_cfg: + max_model_len: 512 + temperature: 0.6 +data: + max_input_seq_length: 512 +logger: + log_dir: logs/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6 + wandb_enabled: true + tensorboard_enabled: true + wandb: + project: nemo-rl + name: grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6 +cluster: + gpus_per_node: 8 diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index c61cb5f0ce..28bf864dad 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -25,9 +25,12 @@ ChunkedDistributedGatherLogprob, _get_tokens_on_this_cp_rank, allgather_cp_sharded_tensor, - from_parallel_logits_to_logprobs, + compute_logprobs_from_logits, gather_logits_at_global_indices, - get_logprobs_from_vocab_parallel_logits, +) +from nemo_rl.models.policy.utils import ( + TrainingSamplingParams, + need_top_k_or_top_p_filtering, ) Tensor = TypeVar("Tensor", bound=torch.Tensor) @@ -192,6 +195,7 @@ def __call__( vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + sampling_params: TrainingSamplingParams | None = None, ) -> tuple[torch.Tensor, dict]: """Clipped Policy Gradient RL loss function.""" token_mask = data["token_mask"][:, 1:] @@ -205,6 +209,26 @@ def __call__( mask = token_mask * sample_mask.unsqueeze(-1) + if sampling_params is not None and need_top_k_or_top_p_filtering( + sampling_params.top_k, sampling_params.top_p + ): + # sampling mask mismatch handling + # vllm samples token X from top-k/p filtered distribution -> generation_logprobs[X] is always finite (e.g., -5.41) + # during training: policy computes logprobs with same top-k/p settings, but the distribution can be slightly different + # token X may fall outside the training policy's top-k/p set -> prev_logprobs[X] = -inf + # Detect positions with -inf in any logprobs (generation_logprobs is always finite for valid tokens) + 
prev_neginf_positions = torch.isinf(prev_logprobs) + prev_neginf_count = (prev_neginf_positions & mask.bool()).sum().item() + if prev_neginf_count > 0: + print( + f"[WARNING]: {prev_neginf_count}/{int(mask.sum().item())} valid tokens have -inf in prev_logprobs " + "(policy top-k/top-p mismatch). Masking out these positions." + ) + + # Update mask for actor loss (only based on prev_logprobs) + mask = mask * (~prev_neginf_positions).float() + prev_logprobs = torch.where(mask.bool(), prev_logprobs, 0.0) + # token_mult_prob_error # See more details and other metrics in docs/guides/grpo.md#metrics lp_error = torch.abs(generation_logprobs - prev_logprobs) # noqa: F841 (precommit ignore for now) @@ -269,56 +293,73 @@ def __call__( global_normalization_factor=global_valid_toks, ).item() - next_token_logits = next_token_logits.to(torch.float32) + curr_logprobs = compute_logprobs_from_logits( + next_token_logits, + data["input_ids"], + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, + seq_index=seq_index, + sampling_params=sampling_params, + ) - if vocab_parallel_group is not None: - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" - ) - curr_logprobs = from_parallel_logits_to_logprobs( - next_token_logits, - data["input_ids"], - vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], - vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - tp_group=vocab_parallel_group, - inference_only=False, - cp_group=context_parallel_group, - ) - # slice off to the correct length to remove potential CP padding - curr_logprobs = curr_logprobs[:, : data["input_ids"].shape[1] - 1] - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): - curr_logprobs = get_logprobs_from_vocab_parallel_logits( - next_token_logits, data["input_ids"], seq_index=seq_index - ) - else: - 
next_token_logits_wo_last = next_token_logits[ - :, :-1 - ] # Remove last position's logits - next_token_logprobs = torch.nn.functional.log_softmax( - next_token_logits_wo_last, dim=-1 - ) - next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token - curr_logprobs = next_token_logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) + if sampling_params is not None and need_top_k_or_top_p_filtering( + sampling_params.top_k, sampling_params.top_p + ): + # Handle -inf in curr_logprobs as well (same top-k/top-p mismatch issue) + curr_is_neginf = torch.isinf(curr_logprobs) + curr_neginf_count = (curr_is_neginf & mask.bool()).sum().item() + if curr_neginf_count > 0: + print( + f"[WARNING]: {curr_neginf_count} additional -inf positions detected in curr_logprobs, masking out." + ) + + # mask out -inf positions in curr_logprobs + mask = mask * (~curr_is_neginf).float() + curr_logprobs = torch.where(mask.bool(), curr_logprobs, 0.0) # Calculate KL regularization. if self.reference_policy_kl_penalty != 0: + # When top-k/top-p filtering is enabled, we need special handling for KL: + # - reference_policy_logprobs is computed **without** filtering (see use_reference_model) + # - curr_logprobs is computed **with** filtering (for actor loss compatibility) + # - For KL, we need curr_logprobs **without** filtering to be consistent with ref logprobs + # - For importance weights, we also use unfiltered curr_logprobs_for_kl since we're + # reweighting samples from π_gen_filtered to π_curr_unfiltered + if sampling_params is not None and need_top_k_or_top_p_filtering( + sampling_params.top_k, sampling_params.top_p + ): + # Compute unfiltered logprobs for KL calculation + curr_logprobs_for_kl = compute_logprobs_from_logits( + next_token_logits, + data["input_ids"], + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, + seq_index=seq_index, + sampling_params=None, # No filtering for 
KL + ) + else: + curr_logprobs_for_kl = curr_logprobs + if self.use_on_policy_kl_approximation: # See: docs/guides/grpo.md#on-policy-kl-approximation + # Use curr_logprobs_for_kl (unfiltered when filtering is enabled) for importance weights + # This correctly reweights samples from π_gen to π_curr (unfiltered) kl_importance_weights = torch.exp( - curr_logprobs - generation_logprobs + curr_logprobs_for_kl - generation_logprobs ).detach() kl_importance_weights = torch.nan_to_num( kl_importance_weights, nan=0.0, posinf=0.0, neginf=0.0 ) else: - kl_importance_weights = torch.ones_like(curr_logprobs) + kl_importance_weights = torch.ones_like(curr_logprobs_for_kl) + kl = ( kl_importance_weights * self.reference_policy_kl_penalty * calculate_kl( - logprobs=curr_logprobs, + logprobs=curr_logprobs_for_kl, logprobs_reference=reference_policy_logprobs, kl_type=self.reference_policy_kl_type, input_clamp_value=self.kl_input_clamp_value, @@ -548,6 +589,7 @@ def __call__( context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, dpo_loss: bool = False, dpo_average_log_probs: bool = False, + sampling_params: TrainingSamplingParams | None = None, ) -> tuple[torch.Tensor, dict[str, Any]]: # logits shape: [batch_size, seq_len, vocab_size] # Get the next token logits for each position @@ -556,37 +598,18 @@ def __call__( mask = token_mask * sample_mask.unsqueeze(-1) seq_index = data.get("seq_index", None) - next_token_logits = next_token_logits.to(torch.float32) + token_logprobs = compute_logprobs_from_logits( + next_token_logits, + data["input_ids"], + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, + seq_index=seq_index, + sampling_params=sampling_params, + ) - # Gather the logprobs for the actual next tokens - if vocab_parallel_group is not None: - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" - ) - 
token_logprobs = from_parallel_logits_to_logprobs( - next_token_logits, - data["input_ids"], - vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], - vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - tp_group=vocab_parallel_group, - inference_only=False, - cp_group=context_parallel_group, - ) - # slice off to the correct length to remove potential CP padding - token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1] - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): - token_logprobs = get_logprobs_from_vocab_parallel_logits( - next_token_logits, data["input_ids"], seq_index=seq_index - ) - else: - next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token - next_token_logprobs = torch.nn.functional.log_softmax( - next_token_logits, dim=-1 - ) - logprobs = next_token_logprobs[:, :-1] # Remove last position's logits - token_logprobs = logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) + # Apply masking to avoid NaN when logprobs are -inf at masked positions + token_logprobs = torch.where(mask.bool(), token_logprobs, 0.0) if dpo_loss: ## shape: [batch_size] @@ -691,7 +714,10 @@ def __call__( data: BatchedDataDict[PreferenceLossDataDict], global_valid_seqs: Tensor, global_valid_toks: Tensor | None, + sampling_params: TrainingSamplingParams | None = None, ) -> tuple[torch.Tensor, dict[str, Any]]: + del sampling_params # not used + sample_mask = data["sample_mask"] rewards = rewards.squeeze(-1) @@ -806,41 +832,25 @@ def _dpo_loss( vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + sampling_params: TrainingSamplingParams | None = None, ) -> tuple[Tensor, Tensor, Tensor, Tensor]: ## TODO(@ashors): there's some duplicate code here with the NLLLoss function. 
We should refactor token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] seq_index = data.get("seq_index", None) - next_token_logits = next_token_logits.to(torch.float32) - if vocab_parallel_group is not None: - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" - ) - token_logprobs = from_parallel_logits_to_logprobs( - next_token_logits, - data["input_ids"], - vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], - vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - tp_group=vocab_parallel_group, - inference_only=False, - cp_group=context_parallel_group, - ) - # slice off to the correct length to remove potential CP padding - token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1] - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): - token_logprobs = get_logprobs_from_vocab_parallel_logits( - next_token_logits, data["input_ids"], seq_index=seq_index - ) - else: - next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token - next_token_logprobs = torch.nn.functional.log_softmax( - next_token_logits, dim=-1 - ) - logprobs = next_token_logprobs[:, :-1] # Remove last position's logits - token_logprobs = logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) + token_logprobs = compute_logprobs_from_logits( + next_token_logits, + data["input_ids"], + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, + seq_index=seq_index, + sampling_params=sampling_params, + ) + + # Apply masking to avoid NaN when logprobs are -inf at masked positions + token_logprobs = torch.where(token_mask.bool(), token_logprobs, 0.0) ref_logprobs = data["reference_policy_logprobs"][:, :-1] @@ -864,6 +874,7 @@ def __call__( # type: ignore vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: 
Optional[torch.distributed.ProcessGroup] = None, context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + sampling_params: TrainingSamplingParams | None = None, ) -> tuple[torch.Tensor, dict[str, Any]]: sft_loss_chosen = torch.tensor(0.0) if self.sft_loss_weight > 0: @@ -880,6 +891,7 @@ def __call__( # type: ignore context_parallel_group=context_parallel_group, dpo_loss=True, dpo_average_log_probs=self.sft_average_log_probs, + sampling_params=sampling_params, ) sft_loss_chosen, sft_loss_rejected = self.split_output_tensor(sft_loss) sft_loss_chosen = masked_mean( @@ -900,6 +912,7 @@ def __call__( # type: ignore vocab_parallel_rank=vocab_parallel_rank, vocab_parallel_group=vocab_parallel_group, context_parallel_group=context_parallel_group, + sampling_params=sampling_params, ) dpo_loss = ( @@ -941,6 +954,7 @@ def __call__( vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + sampling_params: TrainingSamplingParams | None = None, ) -> tuple[Tensor, dict[str, Any]]: """Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding.""" unpadded_cu_seqlens = self.cu_seqlens_q @@ -992,6 +1006,7 @@ def __call__( vocab_parallel_rank=vocab_parallel_rank, vocab_parallel_group=vocab_parallel_group, context_parallel_group=context_parallel_group, + sampling_params=sampling_params, ) loss_accum += loss for k, v in metrics.items(): @@ -1057,8 +1072,16 @@ def __call__( vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + sampling_params: TrainingSamplingParams | None = None, ) -> tuple[torch.Tensor, dict[str, Any]]: """Compute distillation loss between teacher and student logits.""" + # The generation sampling params top-k and top-p are not supported 
yet for distillation loss + if sampling_params is not None and need_top_k_or_top_p_filtering( + sampling_params.top_k, sampling_params.top_p + ): + raise ValueError( + "Generation sampling params top-k and top-p are not supported yet for distillation loss" + ) # Basic shapes input_ids = data["input_ids"] batch_size = input_ids.shape[0] diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index fb17ee1661..d18893cb9e 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -17,6 +17,12 @@ import torch from torch.distributed.tensor import DTensor, distribute_tensor +from nemo_rl.models.policy.utils import ( + TrainingSamplingParams, + apply_top_k_top_p, + need_top_k_or_top_p_filtering, +) + @torch.no_grad() def _compute_distributed_log_softmax( @@ -255,6 +261,345 @@ def backward( return grad_input, None, None, None, None, None, None +class DistributedLogprobWithSampling(torch.autograd.Function): + """Custom autograd function for computing log probabilities with top-k/top-p sampling. + + This function materializes the full vocabulary by converting from vocab-parallel to + batch-sequence-parallel layout, applies filtering, and computes log probabilities. + """ + + @staticmethod + def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class + ctx: Any, + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + tp_group: torch.distributed.ProcessGroup, + top_k: int | None, + top_p: float, + inference_only: bool = False, + ) -> torch.Tensor: + """Forward pass for sampling-based logprob computation. 
+ + Args: + vocab_parallel_logits: [B, S, V_local] logits sharded by vocab + target: [B, S] target token ids (already shifted) + tp_group: Tensor parallel process group + top_k: Top-k filtering parameter (None or -1 to disable) + top_p: Top-p filtering parameter (1.0 to disable) + inference_only: If True, don't save tensors for backward + + Returns: + Log probabilities [B, S] + """ + world_size = torch.distributed.get_world_size(tp_group) + rank = torch.distributed.get_rank(tp_group) + B, S, V_local = vocab_parallel_logits.shape + BS = B * S + + if BS % world_size != 0: + raise ValueError( + f"B*S={BS} must be divisible by tensor parallel size {world_size} when using top-p/top-k sampling. " + "Please set policy.make_sequence_length_divisible_by to tensor parallel size." + ) + BS_local = BS // world_size + + # Reshape to 2D for all_to_all + reshaped_vocab_parallel_logits = vocab_parallel_logits.view(BS, V_local) + + # Flatten target: [B, S] -> [BS] + target_flat = target.flatten() # [BS] + + # Extract local portion + start_idx = rank * BS_local + end_idx = (rank + 1) * BS_local + target_local = target_flat[start_idx:end_idx] # [BS_local] + + # All-to-all to get batch-sequence parallel logits + seq_parallel_logits = all_to_all_vp2sq(reshaped_vocab_parallel_logits, tp_group) + + # Apply top-k and top-p filtering locally (returns keep_mask for gradient) + logits, keep_mask = apply_top_k_top_p( + seq_parallel_logits, top_k=top_k, top_p=top_p + ) + + # Compute log softmax + log_probs = torch.nn.functional.log_softmax( + logits.to(dtype=torch.float32), dim=-1 + ) + + # Gather log probs for target tokens + token_logprobs = torch.gather( + log_probs, -1, target_local.unsqueeze(-1) + ).squeeze(-1) + + # All-gather across TP to get full sequence [BS] + gathered_list = [torch.empty_like(token_logprobs) for _ in range(world_size)] + torch.distributed.all_gather(gathered_list, token_logprobs, group=tp_group) + token_logprobs = torch.cat(gathered_list, dim=0) # [BS] + + # Reshape 
back to [B, S] + token_logprobs = token_logprobs.view(B, S) + + if not inference_only: + # Save softmax and mask for backward + softmax_output = log_probs.exp() + ctx.save_for_backward(softmax_output, target_local, keep_mask) + ctx.tp_group = tp_group + ctx.world_size = world_size + ctx.rank = rank + ctx.BS_local = BS_local + ctx.B = B + ctx.S = S + + return token_logprobs + + @staticmethod + def backward( + ctx: Any, + *grad_outputs: torch.Tensor, + ) -> tuple[torch.Tensor, None, None, None, None, None]: + """Backward pass for sampling-based logprob computation.""" + grad_output = grad_outputs[0] # [B, S] + softmax_output, target_local, keep_mask = ctx.saved_tensors + tp_group = ctx.tp_group + world_size = ctx.world_size + rank = ctx.rank + BS_local = ctx.BS_local + B = ctx.B + S = ctx.S + + # Flatten and extract local portion + grad_output_flat = grad_output.flatten() # [BS] + start_idx = rank * BS_local + end_idx = (rank + 1) * BS_local + grad_output_local = grad_output_flat[start_idx:end_idx] # [BS_local] + + # Compute gradient + V = softmax_output.shape[-1] + is_chosen = torch.nn.functional.one_hot(target_local, num_classes=V) + grad_logits_local = is_chosen.float().sub_(softmax_output) + grad_logits_local.mul_(grad_output_local.unsqueeze(-1)) + + # Apply keep_mask to gradients - filtered tokens don't get gradients + if keep_mask is not None: + grad_logits_local.mul_(keep_mask) + + # Convert back to vocab-parallel: [BS_local, V] -> [BS, V_local] + grad_vocab_parallel = all_to_all_sq2vp(grad_logits_local, tp_group) + grad_vocab_parallel = grad_vocab_parallel.view(B, S, V // world_size) + + return grad_vocab_parallel, None, None, None, None, None + + +class ChunkedDistributedLogprobWithSampling(torch.autograd.Function): + """Chunked version of DistributedLogprobWithSampling for memory efficiency. + + Uses delayed rematerialization to avoid storing large intermediate tensors. 
+ """ + + @staticmethod + def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class + ctx: Any, + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + tp_group: torch.distributed.ProcessGroup, + top_k: int | None, + top_p: float, + chunk_size: int, + inference_only: bool = False, + ) -> torch.Tensor: + """Forward pass with chunked processing. + + Args: + vocab_parallel_logits: [B, S, V_local] logits sharded by vocab + target: [B, S] target token ids (already shifted) + tp_group: Tensor parallel process group + top_k: Top-k filtering parameter (None or -1 to disable) + top_p: Top-p filtering parameter (1.0 to disable) + chunk_size: Chunk size for memory efficiency (in sequence dimension) + inference_only: If True, don't save tensors for backward + + Returns: + Log probabilities [B, S] + """ + world_size = torch.distributed.get_world_size(tp_group) + rank = torch.distributed.get_rank(tp_group) + B, S, V_local = vocab_parallel_logits.shape + BS = B * S + + if BS % world_size != 0: + raise ValueError( + f"B*S={BS} must be divisible by tensor parallel size {world_size} when using top-p/top-k sampling. " + "Please set policy.make_sequence_length_divisible_by to tensor parallel size." + ) + + # Convert chunk_size from sequence dimension to batch-sequence dimension + effective_chunk_size = chunk_size * B + reshaped_vocab_parallel_logits = vocab_parallel_logits.view(BS, V_local) + + # Make sure the effective chunk size is divisible by the world size + # This ensures all the chunks (including the last one) meet the world size requirement. + if effective_chunk_size % world_size != 0: + raise ValueError( + f"Effective chunk size {effective_chunk_size} = chunk_size {chunk_size} * B {B} must be divisible " + f"by the tensor parallel size {world_size}." 
+ ) + + # Flatten target: [B, S] -> [BS] + target_flat = target.flatten() # [BS] + + # Process in chunks + num_chunks = (BS + effective_chunk_size - 1) // effective_chunk_size + all_token_logprobs = [] + + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * effective_chunk_size + chunk_end = min(BS, (chunk_idx + 1) * effective_chunk_size) + current_chunk_size = chunk_end - chunk_start + local_chunk_size = current_chunk_size // world_size + + # Slice the chunk + vocab_parallel_logits_chunk = reshaped_vocab_parallel_logits[ + chunk_start:chunk_end, : + ] + + # Extract target chunk for this rank + target_chunk = target_flat[chunk_start:chunk_end] + target_local = target_chunk[ + rank * local_chunk_size : (rank + 1) * local_chunk_size + ] + + # All-to-all to get batch-sequence parallel logits + seq_parallel_logits_chunk = all_to_all_vp2sq( + vocab_parallel_logits_chunk, tp_group + ) + + # Apply top-k and top-p filtering locally + logits_chunk, _ = apply_top_k_top_p( + seq_parallel_logits_chunk, top_k=top_k, top_p=top_p + ) + + # Compute log softmax + log_probs_chunk = torch.nn.functional.log_softmax( + logits_chunk.to(dtype=torch.float32), dim=-1 + ) + + # Gather log probs for target tokens in this chunk + token_logprobs_chunk = torch.gather( + log_probs_chunk, -1, target_local.unsqueeze(-1) + ).squeeze(-1) + + # All-gather across TP to get full chunk [current_chunk_size] + gathered_list = [ + torch.empty_like(token_logprobs_chunk) for _ in range(world_size) + ] + torch.distributed.all_gather( + gathered_list, token_logprobs_chunk, group=tp_group + ) + token_logprobs_chunk = torch.cat(gathered_list, dim=0) + + all_token_logprobs.append(token_logprobs_chunk) + + # Concatenate all chunks and reshape + token_logprobs = torch.cat(all_token_logprobs, dim=0) # [BS] + token_logprobs = token_logprobs.view(B, S) + + if not inference_only: + ctx.save_for_backward(vocab_parallel_logits) + ctx.target = target + ctx.tp_group = tp_group + ctx.top_k = top_k + ctx.top_p = 
top_p + ctx.chunk_size = chunk_size + + return token_logprobs + + @staticmethod + def backward( + ctx: Any, + *grad_outputs: torch.Tensor, + ) -> tuple[torch.Tensor, None, None, None, None, None, None]: + """Backward pass with chunked rematerialization.""" + grad_output = grad_outputs[0] # [B, S] + (vocab_parallel_logits,) = ctx.saved_tensors + target = ctx.target + tp_group = ctx.tp_group + top_k = ctx.top_k + top_p = ctx.top_p + chunk_size = ctx.chunk_size + + world_size = torch.distributed.get_world_size(tp_group) + rank = torch.distributed.get_rank(tp_group) + B, S, V_local = vocab_parallel_logits.shape + BS = B * S + + effective_chunk_size = chunk_size * B + reshaped_vocab_parallel_logits = vocab_parallel_logits.view(BS, V_local) + target_flat = target.flatten() + + num_chunks = (BS + effective_chunk_size - 1) // effective_chunk_size + grad_output_flat = grad_output.flatten() + + all_grad_chunks = [] + + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * effective_chunk_size + chunk_end = min(BS, (chunk_idx + 1) * effective_chunk_size) + current_chunk_size = chunk_end - chunk_start + local_chunk_size = current_chunk_size // world_size + + # Rematerialize forward pass for this chunk + vocab_parallel_logits_chunk = reshaped_vocab_parallel_logits[ + chunk_start:chunk_end, : + ] + + target_chunk = target_flat[chunk_start:chunk_end] + target_local = target_chunk[ + rank * local_chunk_size : (rank + 1) * local_chunk_size + ] + + # Rematerialize all-to-all + seq_parallel_logits_chunk = all_to_all_vp2sq( + vocab_parallel_logits_chunk, tp_group + ) + + # Rematerialize filtering + logits_chunk, keep_mask = apply_top_k_top_p( + seq_parallel_logits_chunk, top_k=top_k, top_p=top_p + ) + + # Rematerialize softmax + log_probs_chunk = torch.nn.functional.log_softmax( + logits_chunk.to(dtype=torch.float32), dim=-1 + ) + softmax_chunk = log_probs_chunk.exp() + + # Extract local portion of grad_output + grad_chunk = grad_output_flat[chunk_start:chunk_end] + 
grad_local = grad_chunk[ + rank * local_chunk_size : (rank + 1) * local_chunk_size + ] + + # Compute gradient: (one_hot - softmax) * grad_output + V = softmax_chunk.shape[-1] + is_chosen = torch.nn.functional.one_hot(target_local, num_classes=V) + grad_logits_local = is_chosen.float().sub_(softmax_chunk) + grad_logits_local.mul_(grad_local.unsqueeze(-1)) + + # Apply keep_mask + if keep_mask is not None: + grad_logits_local.mul_(keep_mask) + + # Convert back to vocab-parallel + grad_vocab_parallel_chunk = all_to_all_sq2vp(grad_logits_local, tp_group) + all_grad_chunks.append(grad_vocab_parallel_chunk) + + grad_vocab_parallel = torch.cat(all_grad_chunks, dim=0) + grad_vocab_parallel = grad_vocab_parallel.view(B, S, V_local) + + return grad_vocab_parallel, None, None, None, None, None, None + + class ChunkedDistributedGatherLogprob(torch.autograd.Function): """Compute distributed log-softmax once and gather logprobs at given global indices. @@ -390,6 +735,7 @@ def dtensor_from_parallel_logits_to_logprobs( inference_only: bool = False, seq_index: Optional[torch.Tensor] = None, chunk_size: Optional[int] = None, + sampling_params: Optional[TrainingSamplingParams] = None, ) -> torch.Tensor: """Get log probabilities from TP+CP sharded vocab logits. @@ -405,6 +751,7 @@ def dtensor_from_parallel_logits_to_logprobs( seq_index (Optional[torch.Tensor]): Sequence index tensor with shape [seq_len]. It is only provided for cp sharded logits. It represents how tensor is sharded across the sequence dimension. chunk_size (Optional[int]): Sequence dimension chunk size for computing the log probabilities. + sampling_params (TrainingSamplingParams, optional): Sampling parameters for Top-k/Top-p filtering and temperature scaling. Returns: torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. 
@@ -436,25 +783,51 @@ def dtensor_from_parallel_logits_to_logprobs( else: target = target.roll(shifts=-1, dims=-1) - if chunk_size is not None: - logprobs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore - vocab_parallel_logits, - target, - vocab_start_index, - vocab_end_index, - chunk_size, - tp_group, - inference_only, - ).contiguous() + use_sampling = sampling_params is not None and need_top_k_or_top_p_filtering( + sampling_params.top_k, sampling_params.top_p + ) + + if use_sampling: + assert sampling_params is not None # Type narrowing + if chunk_size is not None: + logprobs: torch.Tensor = ChunkedDistributedLogprobWithSampling.apply( # type: ignore + vocab_parallel_logits, + target, + tp_group, + sampling_params.top_k, + sampling_params.top_p, + chunk_size, + inference_only, + ).contiguous() + else: + logprobs: torch.Tensor = DistributedLogprobWithSampling.apply( # type: ignore + vocab_parallel_logits, + target, + tp_group, + sampling_params.top_k, + sampling_params.top_p, + inference_only, + ).contiguous() else: - logprobs: torch.Tensor = DistributedLogprob.apply( # type: ignore - vocab_parallel_logits, - target, - vocab_start_index, - vocab_end_index, - tp_group, - inference_only, - ).contiguous() + if chunk_size is not None: + logprobs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + chunk_size, + tp_group, + inference_only, + ).contiguous() + else: + logprobs: torch.Tensor = DistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + tp_group, + inference_only, + ).contiguous() if cp_size > 1: # logprobs is sharded on the sequence dimension. 
@@ -475,6 +848,7 @@ def from_parallel_logits_to_logprobs( inference_only: bool = False, cp_group: Optional[torch.distributed.ProcessGroup] = None, chunk_size: Optional[int] = None, + sampling_params: Optional[TrainingSamplingParams] = None, ) -> torch.Tensor: """Get log probabilities from TP+CP sharded vocab logits. @@ -489,6 +863,7 @@ def from_parallel_logits_to_logprobs( inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False. cp_group (torch.distributed.ProcessGroup, optional): Context parallelism process group. Defaults to None. chunk_size (int, optional): Sequence dimension chunk size for computing the log probabilities. + sampling_params (TrainingSamplingParams, optional): Sampling parameters for Top-k/Top-p filtering and temperature scaling. Returns: torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. @@ -509,25 +884,51 @@ def from_parallel_logits_to_logprobs( cp_rank = torch.distributed.get_rank(cp_group) target = _get_tokens_on_this_cp_rank(target, cp_rank, cp_size, seq_dim=1) - if chunk_size is not None: - logprobs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore - vocab_parallel_logits, - target, - vocab_start_index, - vocab_end_index, - chunk_size, - tp_group, - inference_only, - ).contiguous() + use_sampling = sampling_params is not None and need_top_k_or_top_p_filtering( + sampling_params.top_k, sampling_params.top_p + ) + + if use_sampling: + assert sampling_params is not None # Type narrowing + if chunk_size is not None: + logprobs: torch.Tensor = ChunkedDistributedLogprobWithSampling.apply( # type: ignore + vocab_parallel_logits, + target, + tp_group, + sampling_params.top_k, + sampling_params.top_p, + chunk_size, + inference_only, + ).contiguous() + else: + logprobs: torch.Tensor = DistributedLogprobWithSampling.apply( # type: ignore + vocab_parallel_logits, + target, + tp_group, + sampling_params.top_k, + sampling_params.top_p, + inference_only, + ).contiguous() 
else: - logprobs: torch.Tensor = DistributedLogprob.apply( # type: ignore - vocab_parallel_logits, - target, - vocab_start_index, - vocab_end_index, - tp_group, - inference_only, - ).contiguous() + if chunk_size is not None: + logprobs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + chunk_size, + tp_group, + inference_only, + ).contiguous() + else: + logprobs: torch.Tensor = DistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + tp_group, + inference_only, + ).contiguous() if cp_size > 1: # we need to gather the logits by context parallelism @@ -552,6 +953,7 @@ def from_parallel_logits_to_logprobs_packed_sequences( inference_only: bool = False, cp_group: Optional[torch.distributed.ProcessGroup] = None, chunk_size: Optional[int] = None, + sampling_params: Optional[TrainingSamplingParams] = None, ) -> torch.Tensor: """Get log probabilities from TP sharded vocab logits for packed sequences. 
@@ -602,25 +1004,51 @@ def from_parallel_logits_to_logprobs_packed_sequences( vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0) # Apply distributed log probability computation - if chunk_size is not None: - probs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore - vocab_parallel_logits, - rolled_targets, - vocab_start_index, - vocab_end_index, - chunk_size, - group, - inference_only, - ).contiguous() + use_sampling = sampling_params is not None and need_top_k_or_top_p_filtering( + sampling_params.top_k, sampling_params.top_p + ) + + if use_sampling: + assert sampling_params is not None # Type narrowing + if chunk_size is not None: + probs: torch.Tensor = ChunkedDistributedLogprobWithSampling.apply( # type: ignore + vocab_parallel_logits, + rolled_targets, + group, + sampling_params.top_k, + sampling_params.top_p, + chunk_size, + inference_only, + ).contiguous() + else: + probs: torch.Tensor = DistributedLogprobWithSampling.apply( # type: ignore + vocab_parallel_logits, + rolled_targets, + group, + sampling_params.top_k, + sampling_params.top_p, + inference_only, + ).contiguous() else: - probs: torch.Tensor = DistributedLogprob.apply( # type: ignore - vocab_parallel_logits, - rolled_targets, - vocab_start_index, - vocab_end_index, - group, - inference_only, - ).contiguous() + if chunk_size is not None: + probs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + rolled_targets, + vocab_start_index, + vocab_end_index, + chunk_size, + group, + inference_only, + ).contiguous() + else: + probs: torch.Tensor = DistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + rolled_targets, + vocab_start_index, + vocab_end_index, + group, + inference_only, + ).contiguous() # Remove batch dimension for filtering probs = probs.squeeze(0) @@ -780,6 +1208,7 @@ def get_logprobs_from_vocab_parallel_logits( input_ids: torch.Tensor | DTensor, seq_index: Optional[torch.Tensor] = None, chunk_size: Optional[int] = None, 
+ sampling_params: Optional[TrainingSamplingParams] = None, ): """Computes log probabilities from vocabulary-parallel logits. @@ -794,6 +1223,7 @@ def get_logprobs_from_vocab_parallel_logits( seq_index (Optional[torch.Tensor]): Sequence index for the input IDs, with shape [sequence_length]. chunk_size (Optional[int]): Sequence dimension chunk size for computing log probabilities. + sampling_params (TrainingSamplingParams, optional): Sampling parameters for Top-k/Top-p filtering and temperature scaling. Returns: torch.Tensor: Log probabilities for the given input IDs. @@ -822,6 +1252,7 @@ def get_logprobs_from_vocab_parallel_logits( inference_only=not torch.is_grad_enabled(), seq_index=seq_index, chunk_size=chunk_size, + sampling_params=sampling_params, ) @@ -1054,3 +1485,180 @@ def backward( grad_input = torch.cat(grads, dim=1) if len(grads) > 1 else grads[0] return grad_input, None, None, None + + +def compute_logprobs_from_logits( + next_token_logits: torch.Tensor, + input_ids: torch.Tensor, + vocab_parallel_rank: Optional[int], + vocab_parallel_group: Optional[torch.distributed.ProcessGroup], + context_parallel_group: Optional[torch.distributed.ProcessGroup], + seq_index: Optional[torch.Tensor], + sampling_params: Optional[TrainingSamplingParams], +) -> torch.Tensor: + """Compute token log-probabilities from logits, handling parallel and non-parallel cases. + + This function handles three cases: + 1. Vocab parallel (Megatron-style): uses from_parallel_logits_to_logprobs + 2. DTensor: uses get_logprobs_from_vocab_parallel_logits + 3. 
Non-parallel: applies top-k/top-p filtering, log_softmax, and gather + + Args: + next_token_logits: Logits tensor of shape [batch_size, seq_len, vocab_size] + input_ids: Input token IDs of shape [batch_size, seq_len] + vocab_parallel_rank: Rank in the vocab parallel group (required if vocab_parallel_group is provided) + vocab_parallel_group: Process group for vocab parallelism + context_parallel_group: Process group for context parallelism + seq_index: Sequence index tensor for DTensor path + sampling_params: Sampling parameters for top-k/top-p filtering + + Returns: + Token log-probabilities of shape [batch_size, seq_len - 1] + """ + next_token_logits = next_token_logits.to(torch.float32) + + if vocab_parallel_group is not None: + assert vocab_parallel_rank is not None, ( + "vocab_parallel_rank must be provided when vocab_parallel_group is provided" + ) + token_logprobs = from_parallel_logits_to_logprobs( + next_token_logits, + input_ids, + vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], + vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], + tp_group=vocab_parallel_group, + inference_only=False, + cp_group=context_parallel_group, + sampling_params=sampling_params, + ) + # slice off to the correct length to remove potential CP padding + token_logprobs = token_logprobs[:, : input_ids.shape[1] - 1] + elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): + token_logprobs = get_logprobs_from_vocab_parallel_logits( + next_token_logits, + input_ids, + seq_index=seq_index, + sampling_params=sampling_params, + ) + else: + next_token_logits_wo_last = next_token_logits[ + :, :-1 + ] # Remove last position's logits + # Apply top-k and top-p filtering + next_token_logits_wo_last, _ = apply_top_k_top_p( + next_token_logits_wo_last, + top_k=sampling_params.top_k if sampling_params is not None else None, + top_p=sampling_params.top_p if sampling_params is not None else 1.0, + ) + next_token_logprobs = 
torch.nn.functional.log_softmax( + next_token_logits_wo_last, dim=-1 + ) + next_tokens = input_ids[:, 1:].cuda() # Skip first token + token_logprobs = next_token_logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) + + return token_logprobs + + +def all_to_all_vp2sq( + vocab_parallel_logits: torch.Tensor, + tp_group: torch.distributed.ProcessGroup, +) -> torch.Tensor: + """Convert vocab-parallel logits to batch-sequence-parallel logits via all-to-all. + + Note: This partitions the flattened B*S dimension, not just S. The input vocab_parallel_logits + need to be 2D tensor. + + Transforms [BS, V_local] -> [BS_local, V] where: + - V_local = V / tp_size (vocab is sharded) + - BS_local = BS / tp_size (batch-sequence will be sharded) + - Requires BS to be divisible by tp_size + + Args: + vocab_parallel_logits: [BS, V_local] tensor with vocab dimension sharded + tp_group: Tensor parallel process group + + Returns: + Batch-sequence-parallel logits [BS_local, V] with batch-sequence dimension sharded + """ + if vocab_parallel_logits.ndim != 2: + raise ValueError( + "For all_to_all_vp2sq, vocab_parallel_logits must be a 2D tensor, " + f"got {vocab_parallel_logits.ndim}D tensor with shape {vocab_parallel_logits.shape}" + ) + + world_size = torch.distributed.get_world_size(tp_group) + BS, V_local = vocab_parallel_logits.shape + + if BS % world_size != 0: + raise ValueError( + f"BS={BS} must be divisible by tensor parallel size {world_size}. " + f"Set policy.make_sequence_length_divisible_by to ensure divisibility." 
+ ) + + BS_local = BS // world_size + + # Flatten and perform all-to-all: exchanges B*S chunks for vocab slices + input_flat = vocab_parallel_logits.flatten() + output_flat = torch.empty_like(input_flat) + torch.distributed.all_to_all_single(output_flat, input_flat, group=tp_group) + + # Rearrange output: merge vocab slices from all ranks into full vocabulary + # Equivalent to: "(w bs v) -> bs (w v)", w=world_size, bs=BS_local, v=V_local + output_tensor = output_flat.view(world_size, BS_local, V_local) + output_tensor = output_tensor.permute(1, 0, 2) + output_tensor = output_tensor.reshape(BS_local, world_size * V_local) + + return output_tensor + + +def all_to_all_sq2vp( + seq_parallel_logits: torch.Tensor, + tp_group: torch.distributed.ProcessGroup, +) -> torch.Tensor: + """Convert batch-sequence-parallel logits to vocab-parallel logits via all-to-all. + + Inverse operation of all_to_all_vp2sq. + + Transforms [BS_local, V] -> [BS, V_local] where: + - BS_local = BS / tp_size (batch-sequence is sharded) + - V_local = V / tp_size (vocab will be sharded) + + Args: + seq_parallel_logits: [BS_local, V] tensor with batch-sequence dimension sharded + tp_group: Tensor parallel process group + + Returns: + Vocab-parallel logits [BS, V_local] with vocab dimension sharded + """ + if seq_parallel_logits.ndim != 2: + raise ValueError( + "For all_to_all_sq2vp, seq_parallel_logits must be a 2D tensor, " + f"got {seq_parallel_logits.ndim}D tensor with shape {seq_parallel_logits.shape}" + ) + + world_size = torch.distributed.get_world_size(tp_group) + BS_local, V = seq_parallel_logits.shape + + if V % world_size != 0: + raise ValueError( + f"Vocabulary size {V} must be divisible by tensor parallel size {world_size}" + ) + + V_local = V // world_size + + # Rearrange input: split vocab into chunks for sending to different ranks + # Equivalent to: "bs (w v) -> (w bs v)", w=world_size, bs=BS_local, v=V_local + input_reshaped = seq_parallel_logits.view(BS_local, world_size, V_local) + 
input_permuted = input_reshaped.permute(1, 0, 2).contiguous() + input_flat = input_permuted.flatten() + + # Perform all-to-all: exchanges vocab slices for B*S chunks + output_flat = torch.empty_like(input_flat) + torch.distributed.all_to_all_single(output_flat, input_flat, group=tp_group) + + # Reshape output: merge B*S slices from all ranks into full batch-sequence dimension + output_tensor = output_flat.reshape(world_size * BS_local, V_local) + + return output_tensor diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 6138dfdb43..dfee1190c3 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -42,12 +42,6 @@ compute_spec_decode_metrics, ) -# Global thresholds for top_k and top_p validation. -# While top-k/p are not supported, these values allow for token filtering while the logprobs should be compatible. -# See https://github.com/NVIDIA-NeMo/RL/issues/69 and https://github.com/NVIDIA-NeMo/RL/issues/237 for more details. -TOP_K_THRESHOLD = 8000 # Allow top_k >= 8000 (effectively no filtering) -TOP_P_THRESHOLD = 0.99 # Allow top_p >= 0.99 (close to 1.0) - class VllmGeneration(GenerationInterface): def __init__( @@ -92,30 +86,16 @@ def __init__( ) # Validate sampling parameters early to avoid resource allocation with unsupported configs. - # The vLLM sampler patch only supports temperature scaling and does not handle top_p/top_k correctly. - # However, we allow values above certain thresholds for token filtering purposes. - top_k = self.cfg["top_k"] - if top_k is not None and top_k != -1 and top_k < TOP_K_THRESHOLD: + top_k: int | None = self.cfg.get("top_k", None) + if top_k is not None and top_k != -1 and top_k < 1: raise ValueError( - ( - f"top_k sampling with values < {TOP_K_THRESHOLD} is not supported because the vLLM V1 engine " - "does not return logprobs after top_k filtering. 
Values >= {TOP_K_THRESHOLD} are allowed " - "for token filtering purposes. If you understand the implications and still want to use " - f"a lower top_k value, please manually comment out this check. Got top_k={top_k}. " - "See https://github.com/NVIDIA-NeMo/RL/issues/69 for more details." - ) + f"top_k valid values: i) None or -1: no filtering. ii) >= 1: top-k filtering. Got top_k={top_k}." ) top_p: float = self.cfg.get("top_p", 1.0) - if top_p < TOP_P_THRESHOLD: + if top_p <= 0: raise ValueError( - ( - f"top_p sampling with values < {TOP_P_THRESHOLD} is not supported because the vLLM V1 engine " - "does not return logprobs after top_p filtering. Values >= {TOP_P_THRESHOLD} are allowed " - "for token filtering purposes. If you understand the implications and still want to use " - f"a lower top_p value, please manually comment out this check. Got top_p={top_p}. " - "See https://github.com/NVIDIA-NeMo/RL/issues/69 for more details." - ) + f"top_p valid values: i) 1.0: no filtering. ii) (0, 1]: top-p filtering. Got top_p={top_p}." ) # Ensure all required VllmConfig fields are present diff --git a/nemo_rl/models/megatron/common.py b/nemo_rl/models/megatron/common.py index 28af36b11b..f8469dda4b 100644 --- a/nemo_rl/models/megatron/common.py +++ b/nemo_rl/models/megatron/common.py @@ -32,6 +32,7 @@ from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.models.policy.utils import TrainingSamplingParams def _round_up_to_multiple(value: int, multiple: int) -> int: @@ -52,7 +53,7 @@ def forward_step_arbitrary_loss( pack_sequences: bool = False, defer_fp32_logits: Optional[bool] = None, cp_normalize: bool = True, - policy_cfg: Optional[dict] = None, + sampling_params: Optional[TrainingSamplingParams] = None, ): """Forward training step with support for packed sequences and context parallelism. 
@@ -66,7 +67,7 @@ def forward_step_arbitrary_loss( pack_sequences (bool): Whether to pack sequences for efficiency defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32 cp_normalize (bool): Whether to normalize the loss by the cp_size - policy_cfg (Optional[dict]): Policy configuration containing generation parameters + sampling_params (Optional[TrainingSamplingParams]): Sampling parameters for temperature scaling and top-k/top-p filtering Notes on packed sequences with context parallelism (CP): - When CP > 1, each sequence is padded to a multiple of (cp_size * 2) @@ -114,12 +115,8 @@ def forward_step_arbitrary_loss( # Apply temperature scaling to logits for training # This matches the dtensor worker's _apply_temperature_scaling in the train method - if ( - policy_cfg is not None - and "generation" in policy_cfg - and policy_cfg["generation"] is not None - ): - output_tensor.div_(policy_cfg["generation"]["temperature"]) + if sampling_params is not None and sampling_params.temperature != 1.0: + output_tensor.div_(sampling_params.temperature) # Unpack the output tensor if we did packed sequences if pack_sequences and packed_seq_params is not None: @@ -140,6 +137,7 @@ def forward_step_arbitrary_loss( vocab_parallel_rank=get_tensor_model_parallel_rank(), vocab_parallel_group=get_tensor_model_parallel_group(), context_parallel_group=get_context_parallel_group(), + sampling_params=sampling_params, ) if cp_normalize: diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index 6061b0f143..97f011b704 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -15,6 +15,7 @@ import gc import os import traceback +from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Optional, cast @@ -75,6 +76,15 @@ } +# Default chunk size for top-k/top-p filtering. 
+# The sort operation in top-p filtering is memory intensive because it creates +# intermediate tensors of shape [bsz, seq_len, vocab_size] for both sorted values +# and indices. For large vocab sizes (e.g., 152K) and long sequences (e.g., 32K), +# this can cause OOM. Chunking along the sequence dimension reduces peak memory. +# Different chunk sizes have minor performance differences. +TOP_K_TOP_P_CHUNK_SIZE: int = 256 + + class IPCProtocol(Enum): """IPC protocol constants for ZMQ weight streaming.""" @@ -82,68 +92,44 @@ class IPCProtocol(Enum): ACK = "ack" -def apply_top_k_top_p( - logits: torch.Tensor, - top_k: Optional[int] = None, - top_p: Optional[float] = None, -) -> torch.Tensor: - """Apply top-k and top-p masks to the logits. +def need_top_k_filtering(top_k: int | None) -> bool: + """Check if top-k filtering is needed.""" + return top_k is not None and top_k != -1 - Simplified version of VLLM's implementation for scalar parameters. - Based on VLLM's implementation: - https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py - SPDX-License-Identifier: Apache-2.0 - Copyright contributors to the vLLM project +def need_top_p_filtering(top_p: float | None) -> bool: + """Check if top-p filtering is needed.""" + return top_p is not None and top_p != 1.0 - Args: - logits: Input logits tensor of shape [batch_size, seq_len, vocab_size] - top_k: Top-k sampling parameter. Set to -1 to consider all tokens. - top_p: Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens. 
- Returns: - Filtered logits with sampling parameters applied - """ - if top_p is None or top_p == 1.0: - if top_k is None or top_k == -1: - return logits - # Avoid sorting vocab for top-k only case - return apply_top_k_only(logits, top_k) - - # Apply top-p (requires sorting) - logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - - if top_k is not None and top_k != -1: - # Apply top-k first - top_k_index = logits_sort.size(-1) - top_k - # Get all the top_k values - need to broadcast the index across all dimensions - index_tensor = torch.full( - logits_sort.shape[:-1], - top_k_index, - device=logits_sort.device, - dtype=torch.long, - ) - top_k_threshold = logits_sort.gather(-1, index_tensor.unsqueeze(-1)) - top_k_mask = logits_sort < top_k_threshold - logits_sort.masked_fill_(top_k_mask, -float("inf")) +def need_top_k_or_top_p_filtering(top_k: int | None, top_p: float | None) -> bool: + """Check if top-k or top-p filtering is needed.""" + return need_top_k_filtering(top_k) or need_top_p_filtering(top_p) - # Apply top-p - probs_sort = logits_sort.softmax(dim=-1) - probs_sum = torch.cumsum(probs_sort, dim=-1) - top_p_mask = probs_sum <= 1 - top_p - # at least one - top_p_mask[..., -1] = False - logits_sort.masked_fill_(top_p_mask, -float("inf")) - # Re-sort the probabilities - logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) - return logits +@dataclass +class TrainingSamplingParams: + """Training-specific sampling parameters to match generation parameters. + Used to ensure consistency between training and inference by applying the same sampling strategy during + logprob computation. Not directly using vLLM's SamplingParams class to avoid dependency on vLLM in this env. 
-def apply_top_k_only( + Attributes: + top_k: Top-k filtering parameter (None or -1 to disable) + top_p: Top-p filtering parameter (1.0 to disable) + temperature: Temperature for scaling logits (default: 1.0) + """ + + top_k: int | None = None + top_p: float = 1.0 + temperature: float = 1.0 + + +@torch.no_grad() +def _apply_top_k_only_fn( logits: torch.Tensor, - top_k: int, -) -> torch.Tensor: + top_k: int | None, +) -> tuple[torch.Tensor, torch.Tensor | None]: """Apply top-k mask to the logits. Simplified version of VLLM's implementation for scalar parameters. @@ -155,27 +141,207 @@ def apply_top_k_only( Copyright contributors to the vLLM project Args: - logits: Input logits tensor of shape [batch_size, seq_len, vocab_size] + logits: Input logits tensor of shape [*, vocab_size]. top_k: Top-k sampling parameter. Returns: - Filtered logits with top-k applied + filtered_logits: Filtered logits tensor with the same shape as input logits. + keep_mask: Mask tensor with the same shape as input logits, where 1 (True) indicates tokens to be + kept, 0 (False) indicates tokens to be masked. None if top-k filtering is not needed. 
""" - if top_k >= logits.shape[-1] or top_k == -1: - return logits + if not need_top_k_filtering(top_k): + return logits, None # Get top-k values and create mask + assert top_k is not None # Type narrowing top_k_values, _ = torch.topk(logits, top_k, dim=-1) threshold = top_k_values[..., -1:].expand_as(logits) - mask = logits >= threshold + keep_mask = logits >= threshold # Apply mask: keep top-k values, set others to -inf logits = torch.where( - mask, + keep_mask, logits, torch.tensor(-float("inf"), device=logits.device, dtype=logits.dtype), ) - return logits + return logits, keep_mask + + +@torch.no_grad() +def _apply_top_k_top_p_fn( + logits: torch.Tensor, + top_k: int | None, + top_p: float, + chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Apply top-k and top-p masks to the logits with chunking for memory efficiency. + + The sort operation in top-p filtering is memory intensive because it creates + intermediate tensors of shape [num_tokens, vocab_size] for both sorted values + and indices. For large vocab sizes (e.g., 152K) and many tokens, this can cause OOM. + This function flattens the input to 2D and processes in chunks along the token + dimension (controlled by chunk_size) to reduce peak memory. + + Based on VLLM's implementation: + https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py + SPDX-License-Identifier: Apache-2.0 + Copyright contributors to the vLLM project + + Args: + logits: Input logits tensor of shape [*, vocab_size] (e.g., [batch_size, seq_len, vocab_size] + or [batch_size, vocab_size]). Internally flattened to [num_tokens, vocab_size] for processing. + top_k: Top-k sampling parameter. Set to -1 or None to consider all tokens. + top_p: Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens + chunk_size: Number of tokens to process per chunk for memory efficiency. 
Defaults to TOP_K_TOP_P_CHUNK_SIZE. + + Returns: + filtered_logits: Filtered logits tensor with the same shape as input logits. + keep_mask: Mask tensor with the same shape as input logits, where 1 (True) indicates + tokens to be kept, 0 (False) indicates tokens to be masked. + """ + if not need_top_p_filtering(top_p): + if not need_top_k_filtering(top_k): + return logits, None + # Avoid sorting vocab for top-k only case + filtered_logits, top_k_keep_mask = _apply_top_k_only_fn(logits, top_k) + return filtered_logits, top_k_keep_mask + + # Save original shape and flatten to 2D for consistent chunking + original_shape = logits.shape + vocab_size = logits.shape[-1] + logits = logits.view(-1, vocab_size) # [*, vocab_size] -> [num_tokens, vocab_size] + num_tokens = logits.shape[0] + + chunk_size = chunk_size if chunk_size is not None else num_tokens + + # Pre-allocate output tensors + filtered_logits = torch.empty_like(logits) + keep_mask = torch.empty( + num_tokens, vocab_size, dtype=torch.bool, device=logits.device + ) + + for start_idx in range(0, num_tokens, chunk_size): + end_idx = min(start_idx + chunk_size, num_tokens) + chunk_logits = logits[start_idx:end_idx, :] + + # Sort this chunk + logits_sort, logits_idx = chunk_logits.sort(dim=-1, descending=False) + top_k_keep_mask_chunk = None + + if need_top_k_filtering(top_k): + assert top_k is not None # Type narrowing + # Apply top-k first + top_k_index = logits_sort.size(-1) - top_k + index_tensor = torch.full( + logits_sort.shape[:-1], + top_k_index, + device=logits_sort.device, + dtype=torch.long, + ) + top_k_threshold = logits_sort.gather(-1, index_tensor.unsqueeze(-1)) + top_k_keep_mask_chunk = logits_sort >= top_k_threshold + logits_sort.masked_fill_(~top_k_keep_mask_chunk, -float("inf")) + + # Apply top-p + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = torch.cumsum(probs_sort, dim=-1) + top_p_keep_mask_chunk = probs_sum > 1 - top_p + # at least one + top_p_keep_mask_chunk[..., -1] = True + 
logits_sort.masked_fill_(~top_p_keep_mask_chunk, -float("inf")) + + # Scatter back to original order + chunk_filtered = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) + if top_k_keep_mask_chunk is not None: + chunk_mask = torch.logical_and(top_k_keep_mask_chunk, top_p_keep_mask_chunk) + else: + chunk_mask = top_p_keep_mask_chunk + chunk_mask = chunk_mask.scatter(dim=-1, index=logits_idx, src=chunk_mask) + + # Store results + filtered_logits[start_idx:end_idx, :] = chunk_filtered + keep_mask[start_idx:end_idx, :] = chunk_mask + + # Restore original shape + filtered_logits = filtered_logits.view(original_shape) + keep_mask = keep_mask.view(original_shape) + + return filtered_logits, keep_mask + + +class ApplyTopKTopP(torch.autograd.Function): + """Autograd function for top-k and top-p filtering with proper gradient handling.""" + + @staticmethod + def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class + ctx, + logits: torch.Tensor, + top_k: Optional[int], + top_p: float, + chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Apply top-k/top-p filtering and save masks for backward. + + Args: + logits: Input logits tensor of shape [*, vocab_size]. + top_k: Top-k sampling parameter. Set to -1 or None to consider all tokens. + top_p: Top-p sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens. + chunk_size: Number of tokens to process per chunk. Defaults to TOP_K_TOP_P_CHUNK_SIZE. 
+ """ + filtered_logits, keep_mask = _apply_top_k_top_p_fn( + logits, top_k, top_p, chunk_size + ) + + # Save masks for backward pass + ctx.save_for_backward(keep_mask) + + return filtered_logits, keep_mask + + @staticmethod + def backward(ctx, *grad_outputs: torch.Tensor): + """Backward pass: mask out gradients for filtered tokens.""" + grad_filtered_logits = grad_outputs[0] + (keep_mask,) = ctx.saved_tensors + + # Apply masks to gradients - masked out tokens should not receive gradients + if keep_mask is not None: + grad_filtered_logits = grad_filtered_logits.masked_fill(~keep_mask, 0.0) + + return grad_filtered_logits, None, None, None + + +def apply_top_k_top_p( + logits: torch.Tensor, + top_k: int | None, + top_p: float, + chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Apply top-k and top-p masks to the logits with proper gradient handling. + + Simplified version of VLLM's implementation for scalar parameters. + + When top_p < 1.0, sorting is required which is memory intensive for large vocab sizes. + Processing is done in chunks (controlled by chunk_size) to reduce peak memory. + + Based on VLLM's implementation: + https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py + SPDX-License-Identifier: Apache-2.0 + Copyright contributors to the vLLM project + + Args: + logits: Input logits tensor of shape [*, vocab_size]. + top_k: Top-k sampling parameter. Set to -1 to consider all tokens. + top_p: Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens. + chunk_size: Number of tokens to process per chunk. Defaults to TOP_K_TOP_P_CHUNK_SIZE. + + Returns: + filtered_logits: Filtered logits tensor with the same shape as input logits. + keep_mask: Mask tensor with the same shape as input logits, where 1 (True) indicates tokens to be + kept, 0 (False) indicates tokens to be masked. 
+ """ + if not need_top_k_or_top_p_filtering(top_k, top_p): + return logits, None + return ApplyTopKTopP.apply(logits, top_k, top_p, chunk_size) def resolve_model_class(model_name: str) -> Any: diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index 6028506f92..9e77bba1c6 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -71,8 +71,12 @@ ScoreOutputSpec, ) from nemo_rl.models.policy.utils import ( + TrainingSamplingParams, + apply_top_k_top_p, configure_dynamo_cache, get_runtime_env_for_policy_worker, + need_top_k_filtering, + need_top_p_filtering, resolve_model_class, ) from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker @@ -172,6 +176,16 @@ def __init__( if "generation" in config and config["generation"] is not None: self.is_generation_colocated = config["generation"]["colocated"]["enabled"] + if "generation" in self.cfg and self.cfg["generation"] is not None: + generation_cfg = self.cfg["generation"] + self.sampling_params = TrainingSamplingParams( + top_k=generation_cfg.get("top_k", None), + top_p=generation_cfg.get("top_p", 1.0), + temperature=generation_cfg.get("temperature", 1.0), + ) + else: + self.sampling_params = None + # Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator. # See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details. 
if not self.is_generation_colocated: @@ -477,8 +491,21 @@ def create_context_parallel_ctx( # based on https://github.com/pytorch/torchtitan/blob/cddd7dc809f36fe0ed51cdaaea0671c084d75442/torchtitan/distributed/utils.py#L178 def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor: - if "generation" in self.cfg and self.cfg["generation"] is not None: - logits.div_(self.cfg["generation"]["temperature"]) + if self.sampling_params is not None and self.sampling_params.temperature != 1.0: + logits.div_(self.sampling_params.temperature) + return logits + + def _apply_top_k_top_p_filtering(self, logits: torch.Tensor) -> torch.Tensor: + """Apply top-k and top-p filtering to the logits locally when TP is disabled.""" + if self.sampling_params is not None and ( + need_top_k_filtering(self.sampling_params.top_k) + or need_top_p_filtering(self.sampling_params.top_p) + ): + logits, _ = apply_top_k_top_p( + logits, + top_k=self.sampling_params.top_k, + top_p=self.sampling_params.top_p, + ) return logits @staticmethod @@ -789,6 +816,7 @@ def train( mb, global_valid_seqs, global_valid_toks, + sampling_params=self.sampling_params, ) del logits @@ -1079,6 +1107,7 @@ def get_logprobs( input_ids_dtensor, seq_index_tensor, chunk_size=logprob_chunk_size, + sampling_params=self.sampling_params, ) assert token_logprobs.shape[1] == seq_len - 1 @@ -1088,6 +1117,7 @@ def get_logprobs( logits, input_ids, chunk_size=logprob_chunk_size, + sampling_params=self.sampling_params, ) else: if logprob_chunk_size is not None: @@ -1105,6 +1135,10 @@ def get_logprobs( chunk_logits = logits[ :, chunk_start:chunk_end, : ].to(torch.float32) + # Apply top-k and top-p filtering + chunk_logits = self._apply_top_k_top_p_filtering( + chunk_logits + ) log_probs = torch.nn.functional.log_softmax( chunk_logits, dim=-1 ) @@ -1112,7 +1146,9 @@ def get_logprobs( log_probs = torch.cat(chunked_log_probs, dim=1) del chunked_log_probs else: + # Apply top-k and top-p filtering logits = 
logits.to(torch.float32) + logits = self._apply_top_k_top_p_filtering(logits) log_probs = torch.nn.functional.log_softmax( logits, dim=-1 ) @@ -1607,8 +1643,10 @@ def get_topk_logits( def use_reference_model(self) -> Generator[None, None, None]: """Context manager that temporarily swaps the reference model and active model. - On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references - On exit: Restores original references and re-flips cuda/cpu + On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references. + Also disables top-k/top-p filtering since the reference policy's distribution + is different from the current policy, making filtered logprobs incompatible. + On exit: Restores original references and re-flips cuda/cpu, restores sampling_params. """ with torch.no_grad(): try: @@ -1622,11 +1660,29 @@ def use_reference_model(self) -> Generator[None, None, None]: val = to_local_if_dtensor(v) val.copy_(self.reference_model_state_dict[k]) + # Temporarily disable top-k/top-p filtering for reference policy logprobs. + # The reference policy has different weights, so its top-k/top-p set is + # inherently different from the current policy. Using filtered logprobs + # would cause -inf mismatches that cannot be resolved by masking. + # Note: We keep temperature scaling since it was applied to prev_logprobs. 
+ saved_sampling_params = self.sampling_params + if saved_sampling_params is not None: + self.sampling_params = TrainingSamplingParams( + top_k=None, # Disable top-k + top_p=1.0, # Disable top-p + temperature=saved_sampling_params.temperature, # Keep temperature + ) + else: + self.sampling_params = None + # - self.model is the original reference_model, now on CUDA # - curr_state_dict is the train model, now on CPU yield finally: + # Restore sampling_params + self.sampling_params = saved_sampling_params + # Restore train model state_dict for k, v in self.model.state_dict().items(): val = to_local_if_dtensor(v) diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 4eb730e5a0..5ce42824ce 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -68,7 +68,10 @@ ScoreOutputSpec, ) from nemo_rl.models.policy.utils import ( + TrainingSamplingParams, + apply_top_k_top_p, get_runtime_env_for_policy_worker, + need_top_k_or_top_p_filtering, ) from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker from nemo_rl.models.policy.workers.patches import ( @@ -239,6 +242,16 @@ def __init__( # Initialize checkpoint manager self.checkpoint_manager: Optional[AutomodelCheckpointManager] = None + if "generation" in self.cfg and self.cfg["generation"] is not None: + generation_cfg = self.cfg["generation"] + self.sampling_params = TrainingSamplingParams( + top_k=generation_cfg.get("top_k", None), + top_p=generation_cfg.get("top_p", 1.0), + temperature=generation_cfg.get("temperature", 1.0), + ) + else: + self.sampling_params = None + # Validate configuration and prepare runtime settings runtime_config = validate_and_prepare_config( config=config, @@ -323,6 +336,24 @@ def __init__( _runtime_is_reward_model, # Duplicate, already set as _is_reward_model ) = runtime_config + def 
_apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor: + if self.sampling_params is not None and self.sampling_params.temperature != 1.0: + logits.div_(self.sampling_params.temperature) + return logits + + def _apply_top_k_top_p_filtering(self, logits: torch.Tensor) -> torch.Tensor: + """Apply top-k and top-p filtering to the logits locally when TP is disabled.""" + sampling_params = self.sampling_params + if sampling_params is not None and need_top_k_or_top_p_filtering( + sampling_params.top_k, sampling_params.top_p + ): + logits, _ = apply_top_k_top_p( + logits, + top_k=sampling_params.top_k, + top_p=sampling_params.top_p, + ) + return logits + @wrap_with_nvtx_name("dtensor_policy_worker_v2/train") def train( self, @@ -796,9 +827,24 @@ def use_reference_model(self) -> Generator[None, None, None]: # - self.model is the original reference_model, now on CUDA # - curr_state_dict is the train model, now on CPU + + # Save and adjust sampling_params for reference model + saved_sampling_params = self.sampling_params + if saved_sampling_params is not None: + self.sampling_params = TrainingSamplingParams( + top_k=None, + top_p=1.0, + temperature=saved_sampling_params.temperature, + ) + else: + self.sampling_params = None + yield finally: + # Restore sampling_params + self.sampling_params = saved_sampling_params + # Restore train model state_dict for k, v in self.model.state_dict().items(): val = to_local_if_dtensor(v) diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 798c4ea00a..0512bf7ca4 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -97,7 +97,10 @@ ColocatablePolicyInterface, LogprobOutputSpec, ) -from nemo_rl.models.policy.utils import get_runtime_env_for_policy_worker +from nemo_rl.models.policy.utils import ( + TrainingSamplingParams, + get_runtime_env_for_policy_worker, +) from 
nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker from nemo_rl.models.policy.workers.patches import apply_transformer_engine_patch from nemo_rl.utils.nsys import wrap_with_nvtx_name @@ -227,6 +230,17 @@ def __init__( self.offload_optimizer_for_logprob = ( runtime_config.offload_optimizer_for_logprob ) + + if "generation" in self.cfg and self.cfg["generation"] is not None: + generation_cfg = self.cfg["generation"] + self.sampling_params = TrainingSamplingParams( + top_k=generation_cfg.get("top_k", None), + top_p=generation_cfg.get("top_p", 1.0), + temperature=generation_cfg.get("temperature", 1.0), + ) + else: + self.sampling_params = None + self.is_generation_colocated = runtime_config.is_generation_colocated self.final_padded_vocab_size = runtime_config.final_padded_vocab_size @@ -344,7 +358,9 @@ def train( with ctx: forward_step = partial( - forward_step_arbitrary_loss, loss_fn=loss_fn, policy_cfg=self.cfg + forward_step_arbitrary_loss, + loss_fn=loss_fn, + sampling_params=self.sampling_params, ) all_mb_metrics = [] losses = [] @@ -596,8 +612,11 @@ def forward_step_fn( # Apply temperature scaling to logits for training # This matches the dtensor worker's _apply_temperature_scaling in the train method - if "generation" in self.cfg and self.cfg["generation"] is not None: - output_tensor.div_(self.cfg["generation"]["temperature"]) + if ( + self.sampling_params is not None + and self.sampling_params.temperature != 1.0 + ): + output_tensor.div_(self.sampling_params.temperature) def collection_fn(output_tensor): stc = time.time() @@ -616,6 +635,7 @@ def collection_fn(output_tensor): inference_only=True, cp_group=get_context_parallel_group(), chunk_size=logprob_chunk_size, + sampling_params=self.sampling_params, ) else: token_logprobs = from_parallel_logits_to_logprobs( @@ -626,6 +646,7 @@ def collection_fn(output_tensor): tp_group=tp_grp, inference_only=True, chunk_size=logprob_chunk_size, + sampling_params=self.sampling_params, ) # Prepend 0 
logprob for first token to maintain same sequence length as input @@ -698,6 +719,16 @@ def use_reference_model(self): # if isinstance(item, torch.Tensor): # self.model.state_dict()[name] = item.detach().to(device="cuda", non_blocking=True, copy=True) + saved_sampling_params = self.sampling_params + if saved_sampling_params is not None: + self.sampling_params = TrainingSamplingParams( + top_k=None, + top_p=1.0, + temperature=saved_sampling_params.temperature, + ) + else: + self.sampling_params = None + if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1: gc.collect() torch.cuda.empty_cache() @@ -708,6 +739,7 @@ def use_reference_model(self): finally: # Restore original references and device placement + self.sampling_params = saved_sampling_params self.model.load_state_dict(model_state_dict, strict=True) # for name, item in model_state_dict.items(): # if isinstance(item, torch.Tensor): @@ -797,8 +829,11 @@ def forward_step_fn( **multimodal_data, ) - if "generation" in self.cfg and self.cfg["generation"] is not None: - output_tensor.div_(self.cfg["generation"]["temperature"]) + if ( + self.sampling_params is not None + and self.sampling_params.temperature != 1.0 + ): + output_tensor.div_(self.sampling_params.temperature) def collection_fn(_): # Only the last PP stage produces final logits/top-k; earlier stages return empty diff --git a/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.sh b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.sh new file mode 100755 index 0000000000..1cdfad6fd3 --- /dev/null +++ b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.sh @@ -0,0 +1,44 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=500 +MAX_STEPS=500 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=180 +# ===== END CONFIG ===== + 
+exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/token_mult_prob_error"], ignore_top_p=0.01) < 1.05' \ + 'data["train/token_mult_prob_error"]["500"] < 1.1' \ + 'data["train/reward"]["500"] > 0.1' \ + 'mean(data["timing/train/total_step_time"], -6, -1) < 12.5' + + # Clean up checkpoint directory after successful run to save space. 
+ rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.sh b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.sh new file mode 100755 index 0000000000..21dd483626 --- /dev/null +++ b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.sh @@ -0,0 +1,44 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=500 +MAX_STEPS=500 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=180 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/token_mult_prob_error"], ignore_top_p=0.01) < 1.05' \ + 'data["train/token_mult_prob_error"]["500"] < 1.1' \ + 'data["train/reward"]["500"] > 0.1' \ + 'mean(data["timing/train/total_step_time"], -6, -1) < 10.5' + + # Clean up checkpoint directory after successful run to save space. 
+ rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 35b5cec0f3..095ba10062 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -23,6 +23,8 @@ tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh # Megatron tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.sh tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh +tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-sampling.sh +tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron-temp0.6.sh # Functional 32b run tests/test_suites/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8-actckpt.v3.sh diff --git a/tests/unit/distributed/test_model_utils.py b/tests/unit/distributed/test_model_utils.py index 8637ad22fe..3b7f7ecc8f 100644 --- a/tests/unit/distributed/test_model_utils.py +++ b/tests/unit/distributed/test_model_utils.py @@ -20,7 +20,9 @@ from nemo_rl.distributed.model_utils import ( ChunkedDistributedGatherLogprob, ChunkedDistributedLogprob, + ChunkedDistributedLogprobWithSampling, DistributedLogprob, + DistributedLogprobWithSampling, _compute_distributed_log_softmax, _get_tokens_on_this_cp_rank, allgather_cp_sharded_tensor, @@ -34,6 +36,9 @@ ) from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup +from nemo_rl.models.policy.utils import ( + apply_top_k_top_p, +) @ray.remote(num_gpus=1) @@ -956,3 +961,299 @@ def test_distributed_logprob_all_tests( finally: cluster.shutdown() + + +@ray.remote(num_gpus=1) +class SamplingParamsTestActor: + def __init__(self, tp_size, sharding): + self.tp_size = tp_size + self.sharding = sharding + self.env_vars = dict(os.environ) + torch.distributed.init_process_group(backend="nccl") + self.tp_group = torch.distributed.new_group(ranks=list(range(tp_size))) + + def test_top_k_top_p_filtering_forward_backward(self, top_k, top_p): + """Test top-k and top-p filtering 
logic including backward pass.""" + batch_size = 2 + seq_len = 4 + vocab_size = 100 + + torch.manual_seed(42) + logits = torch.randn( + batch_size, seq_len, vocab_size, device="cuda", requires_grad=True + ) + + filtered_logits, keep_mask = apply_top_k_top_p(logits, top_k=top_k, top_p=top_p) + + # Test 1: Verify top-k filtering + if top_k is not None: + for b in range(batch_size): + for s in range(seq_len): + topk_vals, topk_indices = torch.topk(logits[b, s], k=top_k) + topk_mask = torch.zeros( + vocab_size, dtype=torch.bool, device=logits.device + ) + topk_mask[topk_indices] = True + assert torch.all(torch.isinf(filtered_logits[b, s][~topk_mask])), ( + "Values outside top-k should be -inf" + ) + if top_p == 1.0: + assert not torch.any( + torch.isinf(filtered_logits[b, s][topk_mask]) + ), "Top-k values should not be -inf when top_p=1.0" + non_inf_count = (~torch.isinf(filtered_logits[b, s])).sum().item() + assert non_inf_count <= top_k, ( + f"Non-inf count {non_inf_count} exceeds top_k {top_k}" + ) + + # Test 2: Verify top-p filtering + if top_p < 1.0: + for b in range(batch_size): + for s in range(seq_len): + if top_k is not None: + topk_vals, topk_indices = torch.topk(logits[b, s], k=top_k) + temp_logits = torch.full_like(logits[b, s], float("-inf")) + temp_logits[topk_indices] = topk_vals + else: + temp_logits = logits[b, s] + probs = torch.nn.functional.softmax(temp_logits, dim=-1) + sorted_probs, sorted_indices = torch.sort(probs, descending=True) + cumsum_probs = torch.cumsum(sorted_probs, dim=0) + cutoff_idx = torch.where(cumsum_probs > top_p)[0] + if len(cutoff_idx) > 0: + cutoff_idx = cutoff_idx[0].item() + 1 + else: + cutoff_idx = len(sorted_probs) + kept_indices = sorted_indices[:cutoff_idx] + for idx in kept_indices: + if not torch.isinf(filtered_logits[b, s, idx]): + continue + raise AssertionError(f"Index {idx} in top-p should not be -inf") + + # Test 3: No filtering case + if top_k is None and top_p >= 1.0: + torch.testing.assert_close( + 
filtered_logits, logits.detach(), rtol=1e-5, atol=1e-5 + ) + + # Test 4: Valid probabilities + probs = torch.nn.functional.softmax(filtered_logits, dim=-1) + assert torch.all(probs >= 0) and torch.all(probs <= 1), "Invalid probabilities" + assert torch.allclose( + probs.sum(dim=-1), torch.ones(batch_size, seq_len, device="cuda"), atol=1e-5 + ), "Probabilities don't sum to 1" + + # Test 5: Verify keep_mask alignment with filtered logits + if keep_mask is not None: + non_inf_mask = ~torch.isinf(filtered_logits.detach()) + assert torch.equal(keep_mask, non_inf_mask), ( + f"keep_mask doesn't match non-inf positions in filtered_logits! " + f"Mismatch count: {(keep_mask != non_inf_mask).sum().item()} out of {keep_mask.numel()}" + ) + + # Test 6: Backward pass + torch.manual_seed(44) + output_grad = torch.randn_like(filtered_logits) + non_inf_mask = ~torch.isinf(filtered_logits.detach()) + expected_grad = output_grad * non_inf_mask.float() + filtered_logits.backward(output_grad) + actual_grad = logits.grad + torch.testing.assert_close(actual_grad, expected_grad, rtol=1e-5, atol=1e-5) + + return {"success": True, "error": None, "top_k": top_k, "top_p": top_p} + + def test_distributed_logprob_with_sampling(self, top_k, top_p, chunk_size): + """Test DistributedLogprobWithSampling and ChunkedDistributedLogprobWithSampling.""" + tp_group = self.tp_group + tp_rank = torch.distributed.get_rank(tp_group) + + batch_size = 4 + seq_len = 16 + vocab_size = 256 + vocab_part_size = vocab_size // self.tp_size + vocab_start_index = tp_rank * vocab_part_size + vocab_end_index = (tp_rank + 1) * vocab_part_size + + torch.manual_seed(42) + full_logits = torch.randn(batch_size, seq_len, vocab_size, device="cuda") + vocab_parallel_logits = ( + full_logits[:, :, vocab_start_index:vocab_end_index] + .clone() + .requires_grad_(True) + ) + + torch.manual_seed(43) + target = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") + + # === Expected computation using full logits === + 
expected_logits_filtered, _ = apply_top_k_top_p( + full_logits.clone(), top_k=top_k, top_p=top_p + ) + expected_log_probs = torch.nn.functional.log_softmax( + expected_logits_filtered, dim=-1 + ) + expected_target_logprobs = torch.gather( + expected_log_probs, -1, target.unsqueeze(-1) + ).squeeze(-1) + + # === Actual computation using distributed function === + if chunk_size is None: + actual_logprobs = DistributedLogprobWithSampling.apply( + vocab_parallel_logits, + target, + tp_group, + top_k, + top_p, + False, + ) + else: + actual_logprobs = ChunkedDistributedLogprobWithSampling.apply( + vocab_parallel_logits, + target, + tp_group, + top_k, + top_p, + chunk_size, + False, + ) + + # === Forward pass validation === + torch.testing.assert_close( + actual_logprobs, expected_target_logprobs, rtol=1e-4, atol=1e-4 + ) + + # === Backward pass validation === + torch.manual_seed(44) + output_grad = torch.randn_like(actual_logprobs) + + expected_logits_filtered_grad = full_logits.clone().requires_grad_(True) + expected_logits_filtered_after_filter, _ = apply_top_k_top_p( + expected_logits_filtered_grad, top_k=top_k, top_p=top_p + ) + expected_log_probs_grad = torch.nn.functional.log_softmax( + expected_logits_filtered_after_filter, dim=-1 + ) + expected_target_logprobs_grad = torch.gather( + expected_log_probs_grad, -1, target.unsqueeze(-1) + ).squeeze(-1) + expected_target_logprobs_grad.backward(output_grad) + expected_grad = expected_logits_filtered_grad.grad[ + :, :, vocab_start_index:vocab_end_index + ].clone() + + actual_logprobs.backward(output_grad) + actual_grad = vocab_parallel_logits.grad.clone() + torch.testing.assert_close(actual_grad, expected_grad, rtol=1e-4, atol=1e-4) + + return { + "success": True, + "error": None, + "top_k": top_k, + "top_p": top_p, + "chunk_size": chunk_size, + } + + +SAMPLING_PARAMS_TEST_ACTOR_FQN = ( + f"{SamplingParamsTestActor.__module__}.SamplingParamsTestActor" +) + + +@pytest.fixture +def register_sampling_params_test_actor(): + 
"""Register the SamplingParamsTestActor for use in tests.""" + original_registry_value = ACTOR_ENVIRONMENT_REGISTRY.get( + SAMPLING_PARAMS_TEST_ACTOR_FQN + ) + ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN] = PY_EXECUTABLES.SYSTEM + yield SAMPLING_PARAMS_TEST_ACTOR_FQN + if SAMPLING_PARAMS_TEST_ACTOR_FQN in ACTOR_ENVIRONMENT_REGISTRY: + if original_registry_value is None: + del ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN] + else: + ACTOR_ENVIRONMENT_REGISTRY[SAMPLING_PARAMS_TEST_ACTOR_FQN] = ( + original_registry_value + ) + + +@pytest.mark.parametrize("tp_size", [1, 2]) +@pytest.mark.parametrize( + "top_k,top_p", + [ + (None, 1.0), # No filtering + (10, 1.0), # Only top-k + (None, 0.9), # Only top-p + (10, 0.9), # Both top-k and top-p + ], +) +def test_sampling_params_top_k_top_p( + register_sampling_params_test_actor, tp_size, top_k, top_p +): + """Test top-k and top-p filtering logic.""" + if not torch.cuda.is_available() or torch.cuda.device_count() < tp_size: + pytest.skip( + f"Not enough GPUs available. 
Need {tp_size}, got {torch.cuda.device_count()}" + ) + cluster = RayVirtualCluster(bundle_ct_per_node_list=[tp_size], use_gpus=True) + try: + actor_fqn = register_sampling_params_test_actor + sharding = NamedSharding(layout=list(range(tp_size)), names=["tp"]) + builder = RayWorkerBuilder(actor_fqn, tp_size, sharding) + worker_group = RayWorkerGroup( + cluster=cluster, + remote_worker_builder=builder, + workers_per_node=None, + sharding_annotations=sharding, + ) + futures = worker_group.run_all_workers_single_data( + "test_top_k_top_p_filtering_forward_backward", top_k=top_k, top_p=top_p + ) + results = ray.get(futures) + for i, result in enumerate(results): + assert result["success"], f"Worker {i} failed: {result['error']}" + worker_group.shutdown(force=True) + finally: + cluster.shutdown() + + +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize( + "top_k,top_p", + [ + (10, 1.0), # Only top-k + (None, 0.9), # Only top-p + (10, 0.9), # Both top-k and top-p + ], +) +@pytest.mark.parametrize("chunk_size", [None, 4]) +def test_sampling_params_distributed_logprob( + register_sampling_params_test_actor, tp_size, top_k, top_p, chunk_size +): + """Test DistributedLogprobWithSampling and ChunkedDistributedLogprobWithSampling.""" + if not torch.cuda.is_available() or torch.cuda.device_count() < tp_size: + pytest.skip( + f"Not enough GPUs available. 
Need {tp_size}, got {torch.cuda.device_count()}" + ) + cluster = RayVirtualCluster(bundle_ct_per_node_list=[tp_size], use_gpus=True) + try: + actor_fqn = register_sampling_params_test_actor + sharding = NamedSharding(layout=list(range(tp_size)), names=["tp"]) + builder = RayWorkerBuilder(actor_fqn, tp_size, sharding) + worker_group = RayWorkerGroup( + cluster=cluster, + remote_worker_builder=builder, + workers_per_node=None, + sharding_annotations=sharding, + ) + futures = worker_group.run_all_workers_single_data( + "test_distributed_logprob_with_sampling", + top_k=top_k, + top_p=top_p, + chunk_size=chunk_size, + ) + results = ray.get(futures) + for i, result in enumerate(results): + assert result["success"], f"Worker {i} failed: {result['error']}" + worker_group.shutdown(force=True) + finally: + cluster.shutdown() diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index c27a183b5c..ee1b06045c 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -382,43 +382,6 @@ def test_vllm_missing_required_config_key(cluster): print(f"Successfully caught missing config key with error: {error_message}") -def test_vllm_top_p_top_k_validation(cluster): - """Test that top_p and top_k validation works correctly with threshold-based logic.""" - # Test that values above thresholds are allowed - config_above_thresholds = deepcopy(basic_vllm_test_config) - config_above_thresholds["top_p"] = 0.99 # Above TOP_P_THRESHOLD - config_above_thresholds["top_k"] = 8000 # Above TOP_K_THRESHOLD - - # Should not raise an error - try: - VllmGeneration(cluster, config_above_thresholds) - print("Successfully initialized with top_p=0.99 and top_k=8000") - except Exception as e: - pytest.fail(f"Should not raise error with values above thresholds: {e}") - - # Test that values below thresholds are rejected - config_below_thresholds = 
deepcopy(basic_vllm_test_config) - config_below_thresholds["top_p"] = 0.9 # Below TOP_P_THRESHOLD - - with pytest.raises(ValueError) as excinfo: - VllmGeneration(cluster, config_below_thresholds) - - error_message = str(excinfo.value) - assert "top_p sampling with values < 0.99 is not supported" in error_message - print(f"Successfully caught low top_p value with error: {error_message}") - - # Test that low top_k values are rejected - config_low_top_k = deepcopy(basic_vllm_test_config) - config_low_top_k["top_k"] = 7999 # Below TOP_K_THRESHOLD - - with pytest.raises(ValueError) as excinfo: - VllmGeneration(cluster, config_low_top_k) - - error_message = str(excinfo.value) - assert "top_k sampling with values < 8000 is not supported" in error_message - print(f"Successfully caught low top_k value with error: {error_message}") - - def test_vllm_policy_generation(policy, test_input_data, tokenizer): """Test vLLM policy generation capabilities.""" # Test generation diff --git a/tests/unit/models/generation/test_vllm_logprobs_mode.py b/tests/unit/models/generation/test_vllm_logprobs_mode.py index f43f9e323c..fb3bd591c4 100644 --- a/tests/unit/models/generation/test_vllm_logprobs_mode.py +++ b/tests/unit/models/generation/test_vllm_logprobs_mode.py @@ -167,7 +167,7 @@ def test_processed_logprobs_matches_manual_computation(): scaled_logits_batched = scaled_logits.unsqueeze(0).unsqueeze( 0 ) # [1, 1, vocab_size] - filtered_logits_batched = apply_top_k_top_p( + filtered_logits_batched, _ = apply_top_k_top_p( scaled_logits_batched, top_k=top_k, top_p=top_p ) filtered_logits = filtered_logits_batched.squeeze(0).squeeze(0) # [vocab_size] @@ -257,7 +257,7 @@ def test_apply_top_k_top_p_matches_vllm_upstream(top_k, top_p, test_name): print(f"Testing: {test_name}") # Our implementation: expects [batch, seq, vocab], takes scalar k/p - our_result = apply_top_k_top_p(logits_3d.clone(), top_k=top_k, top_p=top_p) + our_result, _ = apply_top_k_top_p(logits_3d.clone(), top_k=top_k, 
top_p=top_p) # vLLM upstream: expects [batch, vocab], takes tensor k/p with shape [batch] # Process each sequence position separately (vLLM doesn't batch over seq_len) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 11515ec661..2e71c4e2c2 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -17,6 +17,7 @@ from nemo_rl.algorithms.interfaces import LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.models.policy.utils import TrainingSamplingParams class SimpleLoss: @@ -31,6 +32,7 @@ def __call__( vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + sampling_params: Optional[TrainingSamplingParams] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: # Just return mean of logprobs as the loss for testing loss = next_token_logits.mean() @@ -55,6 +57,7 @@ def __call__( vocab_parallel_rank: Optional[int] = None, vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + sampling_params: Optional[TrainingSamplingParams] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: # logits shape: [batch_size, seq_len, vocab_size] # Get the next token logits for each position