diff --git a/.github/actions/test-template/action.yml b/.github/actions/test-template/action.yml
index 3e16304fcf..98a0fbcfcc 100644
--- a/.github/actions/test-template/action.yml
+++ b/.github/actions/test-template/action.yml
@@ -162,6 +162,7 @@ runs:
           --shm-size=64g \
           --env TRANSFORMERS_OFFLINE=0 \
           --env HYDRA_FULL_ERROR=1 \
+          --env HF_HUB_OFFLINE=1 \
           --env HF_HOME=/home/TestData/nemo-rl/hf_home \
           --env HF_DATASETS_CACHE=/home/TestData/nemo-rl/hf_datasets_cache \
           --env NEMO_RL_REPO_DIR=/opt/nemo-rl \
diff --git a/.gitmodules b/.gitmodules
index fbfc974332..d6e8586781 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,7 +1,7 @@
 [submodule "3rdparty/NeMo"]
     path = 3rdparty/NeMo-workspace/NeMo
     url = https://github.com/NVIDIA/NeMo.git
-    branch = https://github.com/NVIDIA/NeMo/tree/ashors/rl-qwen3-export
+    branch = pjin/ashors/rl-qwen3-export
     shallow = true
 [submodule "3rdparty/Megatron-LM"]
     path = 3rdparty/Megatron-LM-workspace/Megatron-LM
diff --git a/3rdparty/NeMo-workspace/NeMo b/3rdparty/NeMo-workspace/NeMo
index aaefedd1d1..5c42641e34 160000
--- a/3rdparty/NeMo-workspace/NeMo
+++ b/3rdparty/NeMo-workspace/NeMo
@@ -1 +1 @@
-Subproject commit aaefedd1d13f4ccd5cd06a19e06f1df33589a235
+Subproject commit 5c42641e344a487c7ca5b253a7483f0af8ef40e6
diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml
index 5d3daff3aa..b797afee17 100644
--- a/examples/configs/grpo_math_1B.yaml
+++ b/examples/configs/grpo_math_1B.yaml
@@ -41,6 +41,7 @@ policy:
   logprob_batch_size: 4
   max_total_sequence_length: 512
   precision: "bfloat16"
+  logprob_chunk_size: null

   dtensor_cfg:
     enabled: true
@@ -53,6 +54,65 @@ policy:

   megatron_cfg:
     enabled: false
+    empty_unused_memory_level: 0
+    activation_checkpointing: false
+    converter_type: "Qwen2ForCausalLM"
+    tensor_model_parallel_size: 1
+    expert_tensor_parallel_size: 1
+    expert_model_parallel_size: 1
+    pipeline_model_parallel_size: 1
+    num_layers_in_first_pipeline_stage: null
+    num_layers_in_last_pipeline_stage: null
+    context_parallel_size: 1
+    pipeline_dtype: ${policy.precision}
+    sequence_parallel: false
+    freeze_moe_router: true
+    moe_router_dtype: "fp64"
+    moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
+    moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
+    #gives ~20% training perf speedup with sequence packing
+    apply_rope_fusion: True
+    defer_fp32_logits: null
+
+    optimizer:
+      optimizer: "adam"
+      lr: 5.0e-6
+      min_lr: 5.0e-7
+      weight_decay: 0.01
+      bf16: true
+      fp16: false
+      params_dtype: "float32"
+
+      #adam
+      adam_beta1: 0.9
+      adam_beta2: 0.999
+      adam_eps: 1e-8
+
+      #sgd
+      sgd_momentum: 0.9
+
+      #distributed optimizer
+      use_distributed_optimizer: true
+      use_precision_aware_optimizer: true
+
+      clip_grad: ${policy.max_grad_norm}
+
+    scheduler:
+      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
+      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
+      weight_decay_incr_style: "constant"
+      lr_decay_style: "constant"
+      lr_decay_iters: null
+      lr_warmup_iters: 13
+      lr_warmup_init: 5.0e-7
+
+    distributed_data_parallel_config:
+      grad_reduce_in_fp32: false
+      overlap_grad_reduce: true
+      overlap_param_gather: true
+      average_in_collective: true
+      use_custom_fsdp: false
+      data_parallel_sharding_strategy: "optim_grads_params"

   # See docs/design-docs/sequence-packing-and-dynamic-batching.md
   # for more details on dynamic batching and sequence packing.
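The new `policy.logprob_chunk_size` knob above (default `null`) controls whether logprob computation is chunked along the sequence dimension, so that only one sequence chunk of logits is cast to float32 at a time instead of the whole `[batch, seq, vocab]` tensor. A minimal, single-GPU sketch of that idea is shown below; the helper name and shapes are illustrative only, not the repo's API.

```python
import torch


def chunked_next_token_logprobs(
    logits: torch.Tensor,      # [batch, seq, vocab], bf16/fp16 straight from the model
    input_ids: torch.Tensor,   # [batch, seq]
    chunk_size: int,
) -> torch.Tensor:
    """Return logprobs of each next token, casting only one chunk to fp32 at a time."""
    _, seq_len, _ = logits.shape
    out = []
    # Position t predicts token t+1, so the last useful logit is at seq_len - 2.
    for start in range(0, seq_len - 1, chunk_size):
        end = min(seq_len - 1, start + chunk_size)
        chunk = logits[:, start:end, :].to(torch.float32)  # fp32 only for this slice
        logp = torch.nn.functional.log_softmax(chunk, dim=-1)
        targets = input_ids[:, start + 1 : end + 1]        # shifted-by-one targets
        out.append(logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1))
    return torch.cat(out, dim=1)  # [batch, seq_len - 1]
```

With `logprob_chunk_size: null` the unchunked path is kept; the 32k recipe below sets it to 2048, trading a few extra log-softmax launches for a much smaller peak float32 footprint.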
diff --git a/examples/configs/grpo_math_qwen30ba3b_megatron.yaml b/examples/configs/grpo_math_qwen30ba3b_megatron.yaml
index e7cae09858..3040d20ffc 100644
--- a/examples/configs/grpo_math_qwen30ba3b_megatron.yaml
+++ b/examples/configs/grpo_math_qwen30ba3b_megatron.yaml
@@ -56,9 +56,6 @@ policy:
       lr_warmup_iters: 13
       lr_warmup_init: 3.0e-8

-    env_vars:
-      PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False"
-
   generation:
     backend: "vllm"
     max_new_tokens: ${policy.max_total_sequence_length}
diff --git a/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml
new file mode 100644
index 0000000000..fddd7726c1
--- /dev/null
+++ b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml
@@ -0,0 +1,168 @@
+checkpointing:
+  enabled: True
+  checkpoint_dir: results/grpo-math-qwen3-30ba3b-megatron-tp4-32k
+  save_period: 3
+  keep_top_k: 1
+  metric_name: val_reward
+  higher_is_better: True
+  checkpoint_must_save_by: null
+
+grpo:
+  normalize_rewards: True
+  use_leave_one_out_baseline: True
+  max_num_steps: 3
+  num_prompts_per_step: 64
+  num_generations_per_prompt: 16
+  max_rollout_turns: 1
+  val_period: 3
+  val_at_start: False
+  max_val_samples: 256
+  val_batch_size: 256
+  seed: 42
+
+loss_fn:
+  reference_policy_kl_penalty: 0.01
+  ratio_clip_min: 0.2
+  ratio_clip_max: 0.2
+  # (default off) loss formulation improvements (docs/guides/grpo.md#loss)
+  use_on_policy_kl_approximation: False
+  use_importance_sampling_correction: False
+  token_level_loss: True
+  ratio_clip_c: null
+
+policy:
+  model_name: "Qwen/Qwen3-30B-A3B"
+  tokenizer:
+    name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
+  train_global_batch_size: 512
+  train_micro_batch_size: 1
+  generation_batch_size: 32 # Only used when generating using HF backend
+  logprob_batch_size: 1
+  max_total_sequence_length: 32768
+  precision: "bfloat16"
+  logprob_chunk_size: 2048
+
+  dtensor_cfg:
+    enabled: False
+
+  dynamic_batching:
+    enabled: False
+
+  sequence_packing:
+    enabled: False
+
+  max_grad_norm: 1.0
+  make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size}
+
+  optimizer: null # remove default FSDP optimizer
+
+  scheduler: null # remove default FSDP scheduler
+
+  megatron_cfg:
+    enabled: True
+    empty_unused_memory_level: 1
+    converter_type: "LlamaForCausalLM"
+    tensor_model_parallel_size: 4
+    pipeline_model_parallel_size: 1
+    context_parallel_size: 1
+    expert_tensor_parallel_size: 1
+    expert_model_parallel_size: 8
+    sequence_parallel: True
+    pipeline_dtype: ${policy.precision}
+    num_layers_in_first_pipeline_stage: null
+    num_layers_in_last_pipeline_stage: null
+    freeze_moe_router: True
+    moe_router_dtype: "fp64"
+    moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
+    moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
+    apply_rope_fusion: True
+    activation_checkpointing: True
+    defer_fp32_logits: True
+
+    optimizer:
+      optimizer: "adam"
+      lr: 5.0e-7
+      min_lr: 5.0e-8
+      weight_decay: 0.0
+      bf16: True
+      fp16: False
+      params_dtype: "float32"
+
+      adam_beta1: 0.9
+      adam_beta2: 0.999
+      adam_eps: 1e-8
+
+      use_distributed_optimizer: True
+      use_precision_aware_optimizer: True
+
+      clip_grad: ${policy.max_grad_norm}
+
+    scheduler:
+      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
+      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
+      weight_decay_incr_style: "constant"
+      lr_decay_style: "constant"
+      lr_decay_iters: null
+      lr_warmup_iters: 2
+      lr_warmup_init: 5.0e-8
+
+    distributed_data_parallel_config:
+      grad_reduce_in_fp32: False
+      overlap_grad_reduce: True
+      overlap_param_gather: True
+      average_in_collective: True
+      use_custom_fsdp: False
+      data_parallel_sharding_strategy: "optim_grads_params"
+
+  generation:
+    backend: "vllm"
+    max_new_tokens: ${policy.max_total_sequence_length}
+    temperature: 1.0
+    top_p: 1.0
+    top_k: null
+    stop_token_ids: null
+    stop_strings: null
+    vllm_cfg:
+      async_engine: False
+      precision: ${policy.precision}
+      tensor_parallel_size: 4
+      pipeline_parallel_size: 1
+      gpu_memory_utilization: 0.6
+      max_model_len: ${policy.max_total_sequence_length}
+      # NB(pjin): https://github.com/NVIDIA-NeMo/RL/pull/857
+      enforce_eager: True
+    colocated:
+      enabled: true
+      resources:
+        gpus_per_node: null
+        num_nodes: null
+
+data:
+  dataset_name: "OpenMathInstruct-2"
+  shuffle: true
+  max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
+  prompt_file: "examples/prompts/cot.txt"
+  system_prompt_file: null
+
+env:
+  math:
+    num_workers: 8
+
+logger:
+  log_dir: logs/grpo-math-qwen3-30ba3b-megatron-tp4-32k
+  num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
+  wandb_enabled: True
+  tensorboard_enabled: True
+  mlflow_enabled: False # Disable MLflow logging
+  monitor_gpus: False # If true, will monitor GPU usage and log to wandb and/or tensorboard
+  wandb:
+    project: nemo-rl
+    name: "grpo-math-qwen3-30ba3b-megatron-tp4-32k"
+  tensorboard: {}
+  gpu_monitoring:
+    collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
+    flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)
+
+cluster:
+  gpus_per_node: 8
+  num_nodes: 4
diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py
index d0a7b05be7..ff942d32d5 100644
--- a/nemo_rl/algorithms/loss_functions.py
+++ b/nemo_rl/algorithms/loss_functions.py
@@ -137,8 +137,6 @@ def __call__(
             global_normalization_factor=global_valid_toks,
         ).item()

-        next_token_logits = next_token_logits.to(torch.float32)
-
         if vocab_parallel_group is not None:
             assert vocab_parallel_rank is not None, (
                 "vocab_parallel_rank must be provided when vocab_parallel_group is provided"
@@ -159,6 +157,7 @@ def __call__(
                 next_token_logits, data["input_ids"], seq_index=seq_index
             )
         else:
+            next_token_logits = next_token_logits.to(torch.float32)
            next_token_logits_wo_last = next_token_logits[
                :, :-1
            ]  # Remove last position's logits
@@ -327,8 +326,6 @@ def __call__(
         mask = token_mask * sample_mask.unsqueeze(-1)

         seq_index = data.get("seq_index", None)
-        next_token_logits = next_token_logits.to(torch.float32)
-
         # Gather the logprobs for the actual next tokens
         if vocab_parallel_group is not None:
             assert vocab_parallel_rank is not None, (
@@ -351,6 +348,7 @@ def __call__(
             )
         else:
             next_tokens = data["input_ids"][:, 1:].cuda()  # Skip first token
+            next_token_logits = next_token_logits.to(torch.float32)
             next_token_logprobs = torch.nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )
@@ -583,7 +581,6 @@ def _dpo_loss(
         sample_mask = data["sample_mask"]
         seq_index = data.get("seq_index", None)

-        next_token_logits = next_token_logits.to(torch.float32)
         if vocab_parallel_group is not None:
             assert vocab_parallel_rank is not None, (
                 "vocab_parallel_rank must be provided when vocab_parallel_group is provided"
@@ -605,6 +602,7 @@ def _dpo_loss(
             )
         else:
             next_tokens = data["input_ids"][:, 1:].cuda()  # Skip first token
+            next_token_logits = next_token_logits.to(torch.float32)
             next_token_logprobs = torch.nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )
diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py
index 5b6a2d57f2..29cc5eb6b7 100644
--- a/nemo_rl/distributed/model_utils.py
+++ b/nemo_rl/distributed/model_utils.py
@@ -77,11 +77,10 @@ def forward(  # pyrefly: ignore[bad-override]  Always ignore torch.autograd.Func
         masked_target = target - vocab_start_index
         masked_target[target_mask] = 0

-        log_softmax_output = _compute_distributed_log_softmax(
-            vocab_parallel_logits, group=group
-        )
-        log_probs = log_softmax_output.clone()
-        softmax_output = log_softmax_output.exp_()
+        vocab_parallel_logits = vocab_parallel_logits.to(dtype=torch.float32)
+
+        log_probs = _compute_distributed_log_softmax(vocab_parallel_logits, group=group)
+        softmax_output = log_probs.exp()

         log_probs = torch.gather(log_probs, -1, masked_target.unsqueeze(-1)).squeeze(-1)
         log_probs[target_mask] = 0.0
@@ -141,6 +140,121 @@ def backward(
         return grad_input, None, None, None, None, None, None


+class ChunkedDistributedLogprob(torch.autograd.Function):
+    """Custom autograd function for computing log probabilities in a distributed setting.
+
+    The log probabilities computation is chunked in the sequence dimension
+    to mitigate GPU OOM (especially during backward pass).
+    In addition, logits casting from float16 or bfloat16 -> float32 is performed
+    inside the chunk loop to avoid materializing a whole float32 logits tensor.
+
+    Adapted from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286
+    """
+
+    @staticmethod
+    def forward(  # pyrefly: ignore[bad-override]  Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class
+        ctx: Any,
+        vocab_parallel_logits: torch.Tensor,
+        target: torch.Tensor,
+        vocab_start_index: int,
+        vocab_end_index: int,
+        chunk_size: int,
+        tp_group: torch.distributed.ProcessGroup,
+        inference_only: bool = False,
+    ) -> torch.Tensor:
+        # Create a mask of valid vocab ids (1 means it needs to be masked).
+        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+        masked_target = target - vocab_start_index
+        masked_target[target_mask] = 0
+
+        seq_size = int(vocab_parallel_logits.shape[1])
+        num_chunks = (seq_size + chunk_size - 1) // chunk_size
+        all_log_probs = []
+
+        for chunk_idx in range(num_chunks):
+            chunk_start = chunk_idx * chunk_size
+            chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size)
+
+            logits = vocab_parallel_logits[:, chunk_start:chunk_end, :]
+            logits = logits.to(dtype=torch.float32)
+
+            log_probs = _compute_distributed_log_softmax(
+                logits,
+                group=tp_group,
+            )
+
+            log_probs = torch.gather(
+                log_probs, -1, masked_target[:, chunk_start:chunk_end].unsqueeze(-1)
+            ).squeeze(-1)
+            log_probs[target_mask[:, chunk_start:chunk_end]] = 0.0
+
+            torch.distributed.all_reduce(
+                log_probs,
+                op=torch.distributed.ReduceOp.SUM,
+                group=tp_group,
+            )
+
+            all_log_probs.append(log_probs)
+
+        log_probs = torch.cat(all_log_probs, dim=1)
+
+        if not inference_only:
+            # only save for backward when we have inference only=False
+            ctx.save_for_backward(vocab_parallel_logits, target_mask, masked_target)
+            ctx.chunk_size = chunk_size
+            ctx.tp_group = tp_group
+
+        return log_probs
+
+    @staticmethod
+    def backward(
+        ctx: Any,
+        *grad_outputs: torch.Tensor,
+    ) -> tuple[torch.Tensor, None, None, None, None, None, None]:
+        grad_output = grad_outputs[0]
+        vocab_parallel_logits, target_mask, masked_target = ctx.saved_tensors
+        chunk_size = ctx.chunk_size
+        tp_group = ctx.tp_group
+
+        partition_vocab_size = int(vocab_parallel_logits.shape[-1])
+        seq_size = int(vocab_parallel_logits.shape[1])
+        num_chunks = (seq_size + chunk_size - 1) // chunk_size
+
+        all_grad_input = []
+
+        for chunk_idx in range(num_chunks):
+            chunk_start = chunk_idx * chunk_size
+            chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size)
+
+            logits = vocab_parallel_logits[:, chunk_start:chunk_end, :]
+            logits = logits.to(dtype=torch.float32)
+
+            softmax_output = _compute_distributed_log_softmax(
+                logits,
+                group=tp_group,
+            )
+            softmax_output = softmax_output.exp()
+
+            # 1 if it's the chosen log prob, 0 otherwise
+            is_chosen = (~(target_mask[:, chunk_start:chunk_end])).unsqueeze(
+                -1
+            ) * torch.nn.functional.one_hot(
+                masked_target[:, chunk_start:chunk_end],
+                num_classes=partition_vocab_size,
+            )
+
+            grad_input = is_chosen.float().sub_(softmax_output)
+
+            grad_input.mul_(grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1))
+
+            all_grad_input.append(grad_input)
+
+        grad_input = torch.cat(all_grad_input, dim=1)
+
+        # if you add an argument to the forward method, then you must add a corresponding None here
+        return grad_input, None, None, None, None, None, None
+
+
 def dtensor_from_parallel_logits_to_logprobs(
     vocab_parallel_logits: torch.Tensor,
     target: DTensor | torch.Tensor,
@@ -149,6 +263,7 @@ def dtensor_from_parallel_logits_to_logprobs(
     tp_group: torch.distributed.ProcessGroup,
     inference_only: bool = False,
     seq_index: Optional[torch.Tensor] = None,
+    chunk_size: Optional[int] = None,
 ) -> torch.Tensor:
     """Get log probabilities from TP+CP sharded vocab logits.

@@ -163,6 +278,7 @@ def dtensor_from_parallel_logits_to_logprobs(
         inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False.
         seq_index (Optional[torch.Tensor]): Sequence index tensor with shape [seq_len]. It is only provided for cp sharded logits.
             It represents how tensor is sharded across the sequence dimension.
+        chunk_size (Optional[int]): Sequence dimension chunk size for computing the log probabilities.

     Returns:
         torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1].
@@ -194,23 +310,34 @@ def dtensor_from_parallel_logits_to_logprobs(
     else:
         target = target.roll(shifts=-1, dims=-1)

-    probs: torch.Tensor = DistributedLogprob.apply(  # type: ignore
-        vocab_parallel_logits,
-        target,
-        vocab_start_index,
-        vocab_end_index,
-        tp_group,
-        inference_only,
-    ).contiguous()
+    if chunk_size is not None:
+        logprobs: torch.Tensor = ChunkedDistributedLogprob.apply(  # type: ignore
+            vocab_parallel_logits,
+            target,
+            vocab_start_index,
+            vocab_end_index,
+            chunk_size,
+            tp_group,
+            inference_only,
+        ).contiguous()
+    else:
+        logprobs: torch.Tensor = DistributedLogprob.apply(  # type: ignore
+            vocab_parallel_logits,
+            target,
+            vocab_start_index,
+            vocab_end_index,
+            tp_group,
+            inference_only,
+        ).contiguous()

     if cp_size > 1:
-        # probs is sharded on the sequence dimension.
+        # logprobs is sharded on the sequence dimension.
         # Get full sequence tensor, vocab dim has been reduced already.
-        probs_dtensor = DTensor.from_local(probs, cp_mesh, cp_placements)
-        probs = probs_dtensor.full_tensor()[:, sorted_indices]
-        assert probs.shape == target_shape
+        logprobs_dtensor = DTensor.from_local(logprobs, cp_mesh, cp_placements)
+        logprobs = logprobs_dtensor.full_tensor()[:, sorted_indices]
+        assert logprobs.shape == target_shape

-    return probs[:, :-1]
+    return logprobs[:, :-1]


 def from_parallel_logits_to_logprobs(
@@ -221,6 +348,7 @@ def from_parallel_logits_to_logprobs(
     tp_group: torch.distributed.ProcessGroup,
     inference_only: bool = False,
     cp_group: Optional[torch.distributed.ProcessGroup] = None,
+    chunk_size: Optional[int] = None,
 ) -> torch.Tensor:
     """Get log probabilities from TP+CP sharded vocab logits.

@@ -234,6 +362,7 @@ def from_parallel_logits_to_logprobs(
         tp_group (torch.distributed.ProcessGroup): Process group for distributed communication.
         inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False.
         cp_group (torch.distributed.ProcessGroup, optional): Context parallelism process group. Defaults to None.
+        chunk_size (int, optional): Sequence dimension chunk size for computing the log probabilities.

     Returns:
         torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1].
@@ -254,25 +383,36 @@ def from_parallel_logits_to_logprobs(
         cp_rank = torch.distributed.get_rank(cp_group)
         target = _get_tokens_on_this_cp_rank(target, cp_rank, cp_size, seq_dim=1)

-    probs: torch.Tensor = DistributedLogprob.apply(  # type: ignore
-        vocab_parallel_logits,
-        target,
-        vocab_start_index,
-        vocab_end_index,
-        tp_group,
-        inference_only,
-    ).contiguous()
+    if chunk_size is not None:
+        logprobs: torch.Tensor = ChunkedDistributedLogprob.apply(  # type: ignore
+            vocab_parallel_logits,
+            target,
+            vocab_start_index,
+            vocab_end_index,
+            chunk_size,
+            tp_group,
+            inference_only,
+        ).contiguous()
+    else:
+        logprobs: torch.Tensor = DistributedLogprob.apply(  # type: ignore
+            vocab_parallel_logits,
+            target,
+            vocab_start_index,
+            vocab_end_index,
+            tp_group,
+            inference_only,
+        ).contiguous()

     if cp_size > 1:
         # we need to gather the logits by context parallelism
-        probs = allgather_cp_sharded_tensor(
-            probs, cp_group, seq_dim=1
+        logprobs = allgather_cp_sharded_tensor(
+            logprobs, cp_group, seq_dim=1
         )  # , unpadded_seqlen=target.shape[1])
         if pad_len > 0:
-            probs = probs[:, :-pad_len]
+            logprobs = logprobs[:, :-pad_len]

-    return probs[:, :-1]
+    return logprobs[:, :-1]


 def from_parallel_logits_to_logprobs_packed_sequences(
@@ -285,6 +425,7 @@ def from_parallel_logits_to_logprobs_packed_sequences(
     group: torch.distributed.ProcessGroup,
     inference_only: bool = False,
     cp_group: Optional[torch.distributed.ProcessGroup] = None,
+    chunk_size: Optional[int] = None,
 ) -> torch.Tensor:
     """Get log probabilities from TP sharded vocab logits for packed sequences.

@@ -301,6 +442,7 @@ def from_parallel_logits_to_logprobs_packed_sequences(
         group (torch.distributed.ProcessGroup): Process group for distributed communication.
         inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False.
         cp_group (torch.distributed.ProcessGroup, optional): Context parallelism process group. Defaults to None.
+        chunk_size (int, optional): Sequence dimension chunk size for computing the log probabilities.

     Returns:
         torch.Tensor: Unpacked log probabilities tensor with shape [batch_size, unpacked_seqlen-1].
@@ -334,14 +476,25 @@ def from_parallel_logits_to_logprobs_packed_sequences(
         vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0)

     # Apply distributed log probability computation
-    probs: torch.Tensor = DistributedLogprob.apply(  # type: ignore
-        vocab_parallel_logits,
-        rolled_targets,
-        vocab_start_index,
-        vocab_end_index,
-        group,
-        inference_only,
-    ).contiguous()
+    if chunk_size is not None:
+        probs: torch.Tensor = ChunkedDistributedLogprob.apply(  # type: ignore
+            vocab_parallel_logits,
+            rolled_targets,
+            vocab_start_index,
+            vocab_end_index,
+            chunk_size,
+            group,
+            inference_only,
+        ).contiguous()
+    else:
+        probs: torch.Tensor = DistributedLogprob.apply(  # type: ignore
+            vocab_parallel_logits,
+            rolled_targets,
+            vocab_start_index,
+            vocab_end_index,
+            group,
+            inference_only,
+        ).contiguous()

     # Remove batch dimension for filtering
     probs = probs.squeeze(0)
diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py
index 25ecaf8051..10f66966f7 100644
--- a/nemo_rl/models/dtensor/parallelize.py
+++ b/nemo_rl/models/dtensor/parallelize.py
@@ -687,6 +687,7 @@ def get_logprobs_from_vocab_parallel_logits(
     vocab_parallel_logits: DTensor,
     input_ids: torch.Tensor | DTensor,
     seq_index: Optional[torch.Tensor] = None,
+    chunk_size: Optional[int] = None,
 ):
     """Computes log probabilities from vocabulary-parallel logits.
@@ -700,6 +701,7 @@ def get_logprobs_from_vocab_parallel_logits(
             with shape [batch_size, seq_len].
         seq_index (Optional[torch.Tensor]): Sequence index for the input IDs,
             with shape [sequence_length].
+        chunk_size (Optional[int]): Sequence dimension chunk size for computing log probabilities.

     Returns:
         torch.Tensor: Log probabilities for the given input IDs.
@@ -727,4 +729,5 @@ def get_logprobs_from_vocab_parallel_logits(
         tp_group,
         inference_only=not torch.is_grad_enabled(),
         seq_index=seq_index,
+        chunk_size=chunk_size,
     )
diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py
index c5637d7096..872fff35ff 100644
--- a/nemo_rl/models/policy/__init__.py
+++ b/nemo_rl/models/policy/__init__.py
@@ -93,6 +93,7 @@ class MegatronConfig(TypedDict):
     freeze_moe_router: bool
     expert_tensor_parallel_size: int
     expert_model_parallel_size: int
+    defer_fp32_logits: NotRequired[bool]

     optimizer: NotRequired[MegatronOptimizerConfig]
     scheduler: NotRequired[MegatronSchedulerConfig]
@@ -138,6 +139,7 @@ class PolicyConfig(TypedDict):
     train_global_batch_size: int
     train_micro_batch_size: int
     logprob_batch_size: NotRequired[int]
+    logprob_chunk_size: NotRequired[int]
     generation: NotRequired[GenerationConfig]
     generation_batch_size: NotRequired[
         int
diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py
index 62bd12cfe3..b5acafbf71 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker.py
@@ -879,6 +879,7 @@ def get_logprobs(
             if micro_batch_size is not None
             else self.cfg["logprob_batch_size"]
         )
+        logprob_chunk_size = self.cfg.get("logprob_chunk_size", None)

         # dim 1 is always assumed to be the sequence dim, sanity check this here
         sequence_dim = 1
@@ -1035,21 +1036,48 @@ def get_logprobs(
                             placements=[Shard(sequence_dim), Shard(-1)],
                         )

-                        logits = logits.to(torch.float32)
                         token_logprobs = get_logprobs_from_vocab_parallel_logits(
                             logits,
                             input_ids_dtensor,
                             seq_index_tensor,
+                            chunk_size=logprob_chunk_size,
                         )
                         assert token_logprobs.shape[1] == seq_len - 1
                     else:
                         if isinstance(logits, DTensor):
-                            logits = logits.to(torch.float32)
                             token_logprobs = get_logprobs_from_vocab_parallel_logits(
-                                logits, input_ids
+                                logits,
+                                input_ids,
+                                chunk_size=logprob_chunk_size,
                             )
                         else:
+                            if logprob_chunk_size is not None:
+                                logits_seq_len = int(logits.shape[1])
+                                num_chunks = (
+                                    logits_seq_len + logprob_chunk_size - 1
+                                ) // logprob_chunk_size
+                                chunked_log_probs = []
+                                for chunk_idx in range(num_chunks):
+                                    chunk_start = chunk_idx * logprob_chunk_size
+                                    chunk_end = min(
+                                        logits_seq_len,
+                                        (chunk_idx + 1) * logprob_chunk_size,
+                                    )
+                                    chunk_logits = logits[
+                                        :, chunk_start:chunk_end, :
+                                    ].to(torch.float32)
+                                    log_probs = torch.nn.functional.log_softmax(
+                                        chunk_logits, dim=-1
+                                    )
+                                    chunked_log_probs.append(log_probs)
+                                log_probs = torch.cat(chunked_log_probs, dim=1)
+                                del chunked_log_probs
+                            else:
+                                logits = logits.to(torch.float32)
+                                log_probs = torch.nn.functional.log_softmax(
+                                    logits, dim=-1
+                                )
                             # Extract logprobs for each token in the sequence by gathering the logprob
                             # corresponding to the next token at each position
                             # Input shapes:
@@ -1057,13 +1085,12 @@ def get_logprobs(
                             #   token_ids: [batch_size, sequence_length] - actual tokens
                             # Output shape: [batch_size, sequence_length] - logprob of each token given previous
                             # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length
-                            logits = outputs.logits.to(torch.float32)
-                            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                             next_tokens = input_ids[:, 1:]
                             log_probs = log_probs[:, :-1]
                             token_logprobs = log_probs.gather(
                                 dim=-1, index=next_tokens.unsqueeze(-1)
                             ).squeeze(-1)
+                            del log_probs

                 del outputs, logits
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 3a94ffa98c..4ee48901bc 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -217,6 +217,9 @@ def re_enable_float32_expert_bias(model_module):
         overlap_param_gather_with_optimizer_step=cfg.optimizer_config.overlap_param_gather_with_optimizer_step,
         data_parallel_random_init=cfg.rng_config.data_parallel_random_init,
         model_post_init_fns=model_post_init_fns,
+        wrap_cast_model_output_to_fp32=(
+            not policy_cfg["megatron_cfg"].get("defer_fp32_logits", None)
+        ),
     )
     if load_optimizer:
         optimizer, scheduler = setup_optimizer(
@@ -662,6 +665,9 @@ def __init__(
             use_torch_fsdp2=self.megatron_cfg.dist_config.use_torch_fsdp2,
             overlap_param_gather_with_optimizer_step=self.megatron_cfg.optimizer_config.overlap_param_gather_with_optimizer_step,
             data_parallel_random_init=self.megatron_cfg.rng_config.data_parallel_random_init,
+            wrap_cast_model_output_to_fp32=(
+                not self.cfg["megatron_cfg"].get("defer_fp32_logits", None)
+            ),
         )
         print("Loading the Reference Model")
         if (
@@ -1113,6 +1119,7 @@ def collection_fn(output_tensor):
             stc = time.time()
             tp_grp = get_tensor_model_parallel_group()
             tp_rank = get_tensor_model_parallel_rank()
+            logprob_chunk_size = self.cfg.get("logprob_chunk_size", None)
             if self.cfg["sequence_packing"]["enabled"]:
                 token_logprobs = from_parallel_logits_to_logprobs_packed_sequences(
                     output_tensor,
@@ -1124,15 +1131,17 @@ def collection_fn(output_tensor):
                     group=tp_grp,
                     inference_only=True,
                     cp_group=get_context_parallel_group(),
+                    chunk_size=logprob_chunk_size,
                 )
             else:
                 token_logprobs = from_parallel_logits_to_logprobs(
-                    output_tensor.to(torch.float32),
+                    output_tensor,
                     target=unpacked_input_ids,
                     vocab_start_index=tp_rank * output_tensor.shape[-1],
                     vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1],
                     tp_group=tp_grp,
                     inference_only=True,
+                    chunk_size=logprob_chunk_size,
                 )

             # Prepend 0 logprob for first token to maintain same sequence length as input
diff --git a/tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh b/tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh
new file mode 100755
index 0000000000..993d541871
--- /dev/null
+++ b/tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+source $SCRIPT_DIR/common.env
+
+# ===== BEGIN CONFIG =====
+NUM_NODES=4
+STEPS_PER_RUN=3
+MAX_STEPS=3
+NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
+NUM_MINUTES=240
+# ===== END CONFIG =====
+
+exit_if_max_steps_reached
+
+# Run the experiment
+cd $PROJECT_ROOT
+uv run examples/run_grpo_math.py \
+    --config $CONFIG_PATH \
+    grpo.max_num_steps=$MAX_STEPS \
+    logger.log_dir=$LOG_DIR \
+    logger.wandb_enabled=True \
+    logger.wandb.project=nemo-rl \
+    logger.wandb.name=$EXP_NAME \
+    logger.monitor_gpus=True \
+    logger.tensorboard_enabled=True \
+    checkpointing.enabled=True \
+    checkpointing.checkpoint_dir=$CKPT_DIR \
+    $@ \
+    2>&1 | tee $RUN_LOG
+
+# Convert tensorboard logs to json
+uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+
+# Only run metrics if the target step is reached
+if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
+    uv run tests/check_metrics.py $JSON_METRICS \
+        'mean(data["train/token_mult_prob_error"]) < 1.1' \
+        'data["train/token_mult_prob_error"]["$MAX_STEPS"] < 1.1'
+fi
diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt
index 07c3eb5b9c..106a18738a 100644
--- a/tests/test_suites/nightly.txt
+++ b/tests/test_suites/nightly.txt
@@ -18,6 +18,9 @@ tests/test_suites/llm/grpo-deepscaler-1.5b-16K.sh
 tests/test_suites/llm/grpo-deepscaler-1.5b-24K.sh
 tests/test_suites/llm/grpo-deepscaler-1.5b-8K.sh

+# GRPO math test run (32K context mcore)
+tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh
+
 #######
 # SFT #
 #######
diff --git a/tests/unit/distributed/test_model_utils.py b/tests/unit/distributed/test_model_utils.py
index 2f8ef2011a..371080a384 100644
--- a/tests/unit/distributed/test_model_utils.py
+++ b/tests/unit/distributed/test_model_utils.py
@@ -18,6 +18,7 @@
 import torch

 from nemo_rl.distributed.model_utils import (
+    ChunkedDistributedLogprob,
     DistributedLogprob,
     _compute_distributed_log_softmax,
     _get_tokens_on_this_cp_rank,
@@ -428,8 +429,9 @@ def test_allgather_cp_sharded_tensor(register_allgather_cp_test_actor, cp_size):

 @ray.remote(num_gpus=1)
 class DistributedLogprobTestActor:
-    def __init__(self, tp_size):
+    def __init__(self, tp_size, chunk_size):
         self.tp_size = tp_size
+        self.chunk_size = chunk_size
         self.env_vars = dict(os.environ)
         torch.distributed.init_process_group(backend="nccl")
         self.tp_group = torch.distributed.new_group(ranks=list(range(tp_size)))
@@ -455,6 +457,7 @@ def test_distributed_logprob_forward_and_backward(self):
         seq_len = 8
         full_vocab_size = 1024
         vocab_part_size = full_vocab_size // self.tp_size
+        chunk_size = self.chunk_size

         # Calculate vocab partition for this rank
         vocab_start_index = rank * vocab_part_size
@@ -490,14 +493,25 @@ def test_distributed_logprob_forward_and_backward(self):
         )

         # Compute using DistributedLogprob (forward only first)
-        distributed_log_probs_inference = DistributedLogprob.apply(
-            vocab_parallel_logits.clone().detach(),  # Clone to avoid affecting backward test
-            target,
-            vocab_start_index,
-            vocab_end_index,
-            self.tp_group,
-            True,  # inference_only=True for forward test
-        )
+        if chunk_size is not None:
+            distributed_log_probs_inference = ChunkedDistributedLogprob.apply(
+                vocab_parallel_logits.clone().detach(),  # Clone to avoid affecting backward test
+                target,
+                vocab_start_index,
+                vocab_end_index,
+                chunk_size,
+                self.tp_group,
+                True,  # inference_only=True for forward test
+            )
+        else:
+            distributed_log_probs_inference = DistributedLogprob.apply(
+                vocab_parallel_logits.clone().detach(),  # Clone to avoid affecting backward test
+                target,
+                vocab_start_index,
+                vocab_end_index,
+                self.tp_group,
+                True,  # inference_only=True for forward test
+            )

         # Compare forward results
         torch.testing.assert_close(
@@ -700,9 +714,17 @@ def register_distributed_logprob_test_actor():
     )


-@pytest.mark.parametrize("tp_size", [1, 2])
+@pytest.mark.parametrize(
+    "tp_size, chunk_size",
+    [
+        (1, None),
+        (2, None),
+        (1, 4),
+        (2, 4),
+    ],
+)
 def test_distributed_logprob_all_tests(
-    register_distributed_logprob_test_actor, tp_size
+    register_distributed_logprob_test_actor, tp_size, chunk_size
 ):
     """Test all DistributedLogprob functionality for a given TP size."""
     # Skip if not enough GPUs
@@ -718,7 +740,7 @@ def test_distributed_logprob_all_tests(

     # Create sharding for TP
     sharding = NamedSharding(layout=list(range(tp_size)), names=["tp"])
-    builder = RayWorkerBuilder(actor_fqn, tp_size)
+    builder = RayWorkerBuilder(actor_fqn, tp_size, chunk_size)

     worker_group = RayWorkerGroup(
         cluster=cluster,
@@ -728,7 +750,9 @@ def test_distributed_logprob_all_tests(
     )

     # Test 1: Combined Forward and Backward pass
-    print(f"\n=== Testing TP={tp_size}: Forward & Backward Pass ===")
+    print(
+        f"\n=== Testing TP={tp_size} ChunkSize={chunk_size}: Forward & Backward Pass ==="
+    )
     futures = worker_group.run_all_workers_single_data(
         "test_distributed_logprob_forward_and_backward"
     )
@@ -743,7 +767,7 @@ def test_distributed_logprob_all_tests(
     )

     # Test 2: Log softmax function
-    print(f"\n=== Testing TP={tp_size}: Log Softmax ===")
+    print(f"\n=== Testing TP={tp_size} ChunkSize={chunk_size}: Log Softmax ===")
     futures = worker_group.run_all_workers_single_data(
         "test_distributed_log_softmax"
     )
@@ -756,7 +780,7 @@ def test_distributed_logprob_all_tests(

     # Test 3: Edge cases (only for TP=2)
     if tp_size == 2:
-        print(f"\n=== Testing TP={tp_size}: Edge Cases ===")
+        print(f"\n=== Testing TP={tp_size} ChunkSize={chunk_size}: Edge Cases ===")
         futures = worker_group.run_all_workers_single_data("test_edge_cases")
         results = ray.get(futures)
         print("Edge cases test completed successfully")
diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py
index 5b16b0b28a..cd287a6370 100644
--- a/tests/unit/models/policy/test_megatron_worker.py
+++ b/tests/unit/models/policy/test_megatron_worker.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 import tempfile
+from typing import Optional

 import pytest
 import torch
@@ -40,6 +41,8 @@ def create_megatron_test_config(
     generation_backend: str = "megatron",
     sequence_parallel: bool = False,
     converter_type: str = "LlamaForCausalLM",
+    logprob_chunk_size: Optional[int] = None,
+    defer_fp32_logits: Optional[bool] = None,
 ) -> PolicyConfig:
     """Create a test config for Megatron policy worker."""
     return {
@@ -50,6 +53,7 @@ def create_megatron_test_config(
         "train_micro_batch_size": 2,
         "learning_rate": 5e-6,
         "logprob_batch_size": 2,
+        "logprob_chunk_size": logprob_chunk_size,
         "precision": precision,
         "generation": {
             "backend": generation_backend,
@@ -95,6 +99,7 @@ def create_megatron_test_config(
             "moe_router_load_balancing_type": "none",
             "moe_router_bias_update_rate": 0.0,
             "apply_rope_fusion": True,
+            "defer_fp32_logits": defer_fp32_logits,
             "optimizer": {
                 "optimizer": "adam",
                 "lr": 5.0e-6,
@@ -554,9 +559,23 @@ def logprob_setup(request):
     """Setup and teardown specifically for logprob tests."""
     # Parse parameters: (num_gpus, tp, pp, model_fixture_name)
     if hasattr(request, "param") and request.param is not None:
-        num_gpus, tp, pp, model_fixture_name = request.param
+        (
+            num_gpus,
+            tp,
+            pp,
+            logprob_chunk_size,
+            defer_fp32_logits,
+            model_fixture_name,
+        ) = request.param
     else:
-        num_gpus, tp, pp, model_fixture_name = 2, 1, 1, "tiny_llama_model_path"
+        (
+            num_gpus,
+            tp,
+            pp,
+            logprob_chunk_size,
+            defer_fp32_logits,
+            model_fixture_name,
+        ) = (2, 1, 1, None, None, "tiny_llama_model_path")

     # Get the actual model path from the requested fixture
     model_name = request.getfixturevalue(model_fixture_name)
@@ -591,6 +610,8 @@ def logprob_setup(request):
         tp=tp,
         pp=pp,
         converter_type=converter_type,
+        logprob_chunk_size=logprob_chunk_size,
+        defer_fp32_logits=defer_fp32_logits,
     )
     tokenizer = get_tokenizer(config["tokenizer"])
     config["generation"] = configure_generation_config(
@@ -639,14 +660,35 @@ def logprob_setup(request):
 @pytest.mark.parametrize(
     "logprob_setup",
     [
-        # (num_gpus, tp, pp, model_fixture_name)
-        (2, 1, 1, "tiny_llama_model_path"),
-        (2, 2, 1, "tiny_llama_model_path"),
-        (2, 1, 1, "tiny_qwen2_model_path"),
-        (2, 2, 1, "tiny_qwen2_model_path"),
+        # (num_gpus, tp, pp, chunk sz, defer fp32, model_fixture_name)
+        (2, 1, 1, None, None, "tiny_llama_model_path"),
+        (2, 2, 1, None, None, "tiny_llama_model_path"),
+        (2, 1, 1, None, None, "tiny_qwen2_model_path"),
+        (2, 2, 1, None, None, "tiny_qwen2_model_path"),
+        (2, 1, 1, None, True, "tiny_llama_model_path"),
+        (2, 2, 1, None, True, "tiny_llama_model_path"),
+        (2, 1, 1, None, True, "tiny_qwen2_model_path"),
+        (2, 2, 1, None, True, "tiny_qwen2_model_path"),
+        (2, 1, 1, 16, True, "tiny_llama_model_path"),
+        (2, 2, 1, 16, True, "tiny_llama_model_path"),
+        (2, 1, 1, 16, True, "tiny_qwen2_model_path"),
+        (2, 2, 1, 16, True, "tiny_qwen2_model_path"),
     ],
     indirect=True,
-    ids=["2gpu_dp2_llama", "2gpu_tp2_llama", "2gpu_dp2_qwen2", "2gpu_tp2_qwen2"],
+    ids=[
+        "2gpu_dp2_llama",
+        "2gpu_tp2_llama",
+        "2gpu_dp2_qwen2",
+        "2gpu_tp2_qwen2",
+        "2gpu_dp2_deferfp32_llama",
+        "2gpu_tp2_deferfp32_llama",
+        "2gpu_dp2_deferfp32_qwen2",
+        "2gpu_tp2_deferfp32_qwen2",
+        "2gpu_dp2_chunked_deferfp32_llama",
+        "2gpu_tp2_chunked_deferfp32_llama",
+        "2gpu_dp2_chunked_deferfp32_qwen2",
+        "2gpu_tp2_chunked_deferfp32_qwen2",
+    ],
 )
 def test_megatron_policy_logprobs(logprob_setup):
     """Test Megatron policy logprob computation."""
@@ -663,6 +705,7 @@ def test_megatron_policy_logprobs(logprob_setup):

     # Basic validation
     assert isinstance(policy_logprobs, torch.Tensor), "Logprobs should be a tensor"
+    assert policy_logprobs.dtype == torch.float32
     assert policy_logprobs.shape == data.get("input_ids").shape, (
         f"Logprobs shape {policy_logprobs.shape} should match input shape {data.get('input_ids').shape}"
     )
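The chunked backward pass added in `nemo_rl/distributed/model_utils.py` above (`ChunkedDistributedLogprob.backward`) recomputes the softmax per chunk and forms the gradient of the gathered log-softmax as `one_hot(target) - softmax`, scaled by the incoming gradient. The snippet below is a self-contained check of that identity against autograd in plain PyTorch (no TP group involved; the helper name is illustrative only):

```python
import torch


def chunk_grad(chunk_logits: torch.Tensor, targets: torch.Tensor, grad_out: torch.Tensor) -> torch.Tensor:
    """Gradient of gathered log-softmax w.r.t. one chunk of logits.

    chunk_logits: [batch, chunk, vocab]; targets: [batch, chunk]; grad_out: [batch, chunk].
    """
    softmax = torch.softmax(chunk_logits.float(), dim=-1)
    one_hot = torch.nn.functional.one_hot(targets, num_classes=chunk_logits.shape[-1]).float()
    # d log_softmax(x)[t] / d x = one_hot(t) - softmax(x), scaled by the upstream grad.
    return (one_hot - softmax) * grad_out.unsqueeze(-1)


torch.manual_seed(0)
x = torch.randn(2, 4, 16, requires_grad=True)
t = torch.randint(0, 16, (2, 4))
g = torch.randn(2, 4)

# Reference: let autograd differentiate the gathered log-softmax directly.
logp = torch.nn.functional.log_softmax(x, dim=-1).gather(-1, t.unsqueeze(-1)).squeeze(-1)
(logp * g).sum().backward()

torch.testing.assert_close(chunk_grad(x.detach(), t, g), x.grad)
```

This is the same per-chunk quantity the patch computes with `is_chosen.float().sub_(softmax_output)` before multiplying by `grad_output`, just without the tensor-parallel masking of out-of-partition vocab ids.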