From a99900fc1208f77715bc0b361a690ba1892ea749 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 13:18:09 -0700 Subject: [PATCH 01/38] Port chunked logprob and deferred float32 logits (WIP). Signed-off-by: Peter Jin --- nemo_rl/algorithms/loss_functions.py | 8 +- nemo_rl/distributed/model_utils.py | 190 +++++++++++++++--- nemo_rl/models/dtensor/parallelize.py | 2 + .../models/policy/dtensor_policy_worker.py | 10 +- .../models/policy/megatron_policy_worker.py | 4 + 5 files changed, 181 insertions(+), 33 deletions(-) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index ed778dc392..654eeda5bf 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -137,8 +137,6 @@ def __call__( global_normalization_factor=global_valid_toks, ).item() - next_token_logits = next_token_logits.to(torch.float32) - if vocab_parallel_group is not None: assert vocab_parallel_rank is not None, ( "vocab_parallel_rank must be provided when vocab_parallel_group is provided" @@ -159,6 +157,7 @@ def __call__( next_token_logits, data["input_ids"], seq_index=seq_index ) else: + next_token_logits = next_token_logits.to(torch.float32) next_token_logits_wo_last = next_token_logits[ :, :-1 ] # Remove last position's logits @@ -326,8 +325,6 @@ def __call__( sample_mask = data["sample_mask"] mask = token_mask * sample_mask.unsqueeze(-1) - next_token_logits = next_token_logits.to(torch.float32) - # Gather the logprobs for the actual next tokens if vocab_parallel_group is not None: assert vocab_parallel_rank is not None, ( @@ -350,6 +347,7 @@ def __call__( ) else: next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token + next_token_logits = next_token_logits.to(torch.float32) next_token_logprobs = torch.nn.functional.log_softmax( next_token_logits, dim=-1 ) @@ -581,7 +579,6 @@ def _dpo_loss( token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] - next_token_logits = next_token_logits.to(torch.float32) if vocab_parallel_group is not None: assert vocab_parallel_rank is not None, ( "vocab_parallel_rank must be provided when vocab_parallel_group is provided" @@ -603,6 +600,7 @@ def _dpo_loss( ) else: next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token + next_token_logits = next_token_logits.to(torch.float32) next_token_logprobs = torch.nn.functional.log_softmax( next_token_logits, dim=-1 ) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 5b6a2d57f2..64ed5dd41d 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -77,6 +77,8 @@ def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func masked_target = target - vocab_start_index masked_target[target_mask] = 0 + vocab_parallel_logits = vocab_parallel_logits.to(dtype=torch.float32) + log_softmax_output = _compute_distributed_log_softmax( vocab_parallel_logits, group=group ) @@ -141,6 +143,120 @@ def backward( return grad_input, None, None, None, None, None, None +class ChunkedDistributedLogprob(torch.autograd.Function): + """Custom autograd function for computing log probabilities in a distributed setting. + The log probabilities computation is sequence chunked to mitigate device memory + usage during backward pass. 
+ + Adapted from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286 + """ + + @staticmethod + def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Function.forward's type since it's always more specific than the base class + ctx: Any, + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + chunk_size: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False, + ) -> torch.Tensor: + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target - vocab_start_index + masked_target[target_mask] = 0 + + seq_size = int(vocab_parallel_logits.shape[1]) + num_chunks = (seq_size + chunk_size - 1) // chunk_size + all_log_probs = [] + + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size) + + logits = vocab_parallel_logits[:,chunk_start:chunk_end,:] + logits = logits.to(dtype=torch.float32) + + log_softmax_output = _compute_distributed_log_softmax( + logits, + group=tp_group, + ) + log_probs = log_softmax_output.clone() + + log_probs = torch.gather( + log_probs, + -1, + masked_target[:,chunk_start:chunk_end].unsqueeze(-1) + ).squeeze(-1) + log_probs[target_mask[:,chunk_start:chunk_end]] = 0.0 + + torch.distributed.all_reduce( + log_probs, + op=torch.distributed.ReduceOp.SUM, + group=tp_group, + ) + + all_log_probs.append(log_probs) + + log_probs = torch.cat(all_log_probs, dim=1) + + if not inference_only: + # only save for backward when we have inference only=False + ctx.save_for_backward(vocab_parallel_logits, target_mask, masked_target) + ctx.chunk_size = chunk_size + ctx.tp_group = tp_group + + return log_probs + + @staticmethod + def backward( + ctx: Any, + *grad_outputs: torch.Tensor, + ) -> tuple[torch.Tensor, None, None, None, None, None, None]: + grad_output = grad_outputs[0] + vocab_parallel_logits, target_mask, masked_target = ctx.saved_tensors + chunk_size = ctx.chunk_size + tp_group = ctx.tp_group + + partition_vocab_size = int(vocab_parallel_logits.shape[-1]) + seq_size = int(vocab_parallel_logits.shape[1]) + num_chunks = (seq_size + chunk_size - 1) // chunk_size + + all_grad_input = [] + + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size) + + logits = vocab_parallel_logits[:,chunk_start:chunk_end,:] + logits = logits.to(dtype=torch.float32) + + log_softmax_output = _compute_distributed_log_softmax( + logits, + group=tp_group, + ) + log_probs = log_softmax_output.clone() + softmax_output = log_softmax_output.exp_() + + # 1 if it's the chosen log prob, 0 otherwise + is_chosen = (~(target_mask[:,chunk_start:chunk_end])).unsqueeze(-1) * torch.nn.functional.one_hot( + masked_target[:,chunk_start:chunk_end], + num_classes=partition_vocab_size, + ) + + grad_input = is_chosen.float().sub_(softmax_output) + + grad_input.mul_(grad_output[:,chunk_start:chunk_end].unsqueeze(dim=-1)) + + all_grad_input.append(grad_input) + + grad_input = torch.cat(all_grad_input, dim=1) + + # if you add an argument to the forward method, then you must add a corresponding None here + return grad_input, None, None, None, None, None, None + + def dtensor_from_parallel_logits_to_logprobs( vocab_parallel_logits: torch.Tensor, target: DTensor | torch.Tensor, @@ -149,6 +265,7 @@ def 
dtensor_from_parallel_logits_to_logprobs( tp_group: torch.distributed.ProcessGroup, inference_only: bool = False, seq_index: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, ) -> torch.Tensor: """Get log probabilities from TP+CP sharded vocab logits. @@ -194,23 +311,34 @@ def dtensor_from_parallel_logits_to_logprobs( else: target = target.roll(shifts=-1, dims=-1) - probs: torch.Tensor = DistributedLogprob.apply( # type: ignore - vocab_parallel_logits, - target, - vocab_start_index, - vocab_end_index, - tp_group, - inference_only, - ).contiguous() + if chunk_size is not None: + logprobs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + chunk_size, + tp_group, + inference_only, + ).contiguous() + else: + logprobs: torch.Tensor = DistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + tp_group, + inference_only, + ).contiguous() if cp_size > 1: - # probs is sharded on the sequence dimension. + # logprobs is sharded on the sequence dimension. # Get full sequence tensor, vocab dim has been reduced already. - probs_dtensor = DTensor.from_local(probs, cp_mesh, cp_placements) - probs = probs_dtensor.full_tensor()[:, sorted_indices] - assert probs.shape == target_shape + logprobs_dtensor = DTensor.from_local(logprobs, cp_mesh, cp_placements) + logprobs = logprobs_dtensor.full_tensor()[:, sorted_indices] + assert logprobs.shape == target_shape - return probs[:, :-1] + return logprobs[:, :-1] def from_parallel_logits_to_logprobs( @@ -221,6 +349,7 @@ def from_parallel_logits_to_logprobs( tp_group: torch.distributed.ProcessGroup, inference_only: bool = False, cp_group: Optional[torch.distributed.ProcessGroup] = None, + chunk_size: Optional[int] = None, ) -> torch.Tensor: """Get log probabilities from TP+CP sharded vocab logits. 
@@ -254,25 +383,36 @@ def from_parallel_logits_to_logprobs( cp_rank = torch.distributed.get_rank(cp_group) target = _get_tokens_on_this_cp_rank(target, cp_rank, cp_size, seq_dim=1) - probs: torch.Tensor = DistributedLogprob.apply( # type: ignore - vocab_parallel_logits, - target, - vocab_start_index, - vocab_end_index, - tp_group, - inference_only, - ).contiguous() + if chunk_size is not None: + logprobs: torch.Tensor = DistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + chunk_size, + tp_group, + inference_only, + ).contiguous() + else: + logprobs: torch.Tensor = DistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + target, + vocab_start_index, + vocab_end_index, + tp_group, + inference_only, + ).contiguous() if cp_size > 1: # we need to gather the logits by context parallelism - probs = allgather_cp_sharded_tensor( - probs, cp_group, seq_dim=1 + logprobs = allgather_cp_sharded_tensor( + logprobs, cp_group, seq_dim=1 ) # , unpadded_seqlen=target.shape[1]) if pad_len > 0: - probs = probs[:, :-pad_len] + logprobs = logprobs[:, :-pad_len] - return probs[:, :-1] + return logprobs[:, :-1] def from_parallel_logits_to_logprobs_packed_sequences( diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index e2af748d71..2171a34a9e 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -620,6 +620,7 @@ def get_logprobs_from_vocab_parallel_logits( vocab_parallel_logits: DTensor, input_ids: torch.Tensor | DTensor, seq_index: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, ): """Computes log probabilities from vocabulary-parallel logits. @@ -660,4 +661,5 @@ def get_logprobs_from_vocab_parallel_logits( tp_group, inference_only=not torch.is_grad_enabled(), seq_index=seq_index, + chunk_size=chunk_size, ) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 18ae23d95b..e76254315f 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -888,6 +888,7 @@ def get_logprobs( if micro_batch_size is not None else self.cfg["logprob_batch_size"] ) + logprob_chunk_size = self.cfg.get("logprob_chunk_size", None) # dim 1 is always assumed to be the sequence dim, sanity check this here sequence_dim = 1 @@ -1045,16 +1046,19 @@ def get_logprobs( ) token_logprobs = get_logprobs_from_vocab_parallel_logits( - logits.to(torch.float32), + logits, input_ids_dtensor, seq_index_tensor, + chunk_size=logprob_chunk_size, ) assert token_logprobs.shape[1] == seq_len - 1 else: if isinstance(logits, DTensor): token_logprobs = get_logprobs_from_vocab_parallel_logits( - logits.to(torch.float32), input_ids + logits, + input_ids, + chunk_size=logprob_chunk_size, ) else: # Extract logprobs for each token in the sequence by gathering the logprob @@ -1066,7 +1070,7 @@ def get_logprobs( # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length log_probs = torch.nn.functional.log_softmax( - outputs.logits.to(torch.float32), dim=-1 + logits.to(torch.float32), dim=-1 ) next_tokens = input_ids[:, 1:] log_probs = log_probs[:, :-1] diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 35c18eb701..171868c4ee 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1088,6 +1088,7 @@ def 
collection_fn(output_tensor): stc = time.time() tp_grp = get_tensor_model_parallel_group() tp_rank = get_tensor_model_parallel_rank() + logprob_chunk_size = self.cfg.get("logprob_chunk_size", None) if self.cfg["sequence_packing"]["enabled"]: token_logprobs = from_parallel_logits_to_logprobs_packed_sequences( output_tensor, @@ -1099,6 +1100,8 @@ def collection_fn(output_tensor): group=tp_grp, inference_only=True, cp_group=get_context_parallel_group(), + # TODO(pjin): chunked logprob not implemented yet w/ seq packing. + # chunk_size=logprob_chunk_size, ) else: token_logprobs = from_parallel_logits_to_logprobs( @@ -1108,6 +1111,7 @@ def collection_fn(output_tensor): vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], tp_group=tp_grp, inference_only=True, + chunk_size=logprob_chunk_size, ) # Prepend 0 logprob for first token to maintain same sequence length as input From 37a027a6a63b1facb1720c11264d825fa626f2bb Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 13:39:09 -0700 Subject: [PATCH 02/38] Add copy of nemo.tron.model without logits float32 cast. Based on NeMo commit: 8ddf4387344c6423763ec9ee0c9a755cbb5d8d35 Signed-off-by: Peter Jin --- .../models/policy/megatron_policy_worker.py | 5 +- nemo_rl/tron/__init__.py | 0 nemo_rl/tron/model.py | 164 ++++++++++++++++++ 3 files changed, 167 insertions(+), 2 deletions(-) create mode 100644 nemo_rl/tron/__init__.py create mode 100644 nemo_rl/tron/model.py diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 171868c4ee..9b6994113c 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -124,6 +124,7 @@ get_megatron_checkpoint_dir, get_runtime_env_for_policy_worker, ) +from nemo_rl.tron.model import get_model_from_config_no_float32 TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) @@ -209,7 +210,7 @@ def re_enable_float32_expert_bias(model_module): model_post_init_fns.append(re_enable_float32_expert_bias) # Model, optimizer, and learning rate. - model = get_model_from_config( + model = get_model_from_config_no_float32( cfg.model_config, cfg.ddp_config, use_torch_fsdp2=cfg.dist_config.use_torch_fsdp2, @@ -645,7 +646,7 @@ def __init__( ref_state = GlobalState() ref_state.cfg = ref_megatron_cfg - reference_model = get_model_from_config( + reference_model = get_model_from_config_no_float32( self.megatron_cfg.model_config, self.megatron_cfg.ddp_config, use_torch_fsdp2=self.megatron_cfg.dist_config.use_torch_fsdp2, diff --git a/nemo_rl/tron/__init__.py b/nemo_rl/tron/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_rl/tron/model.py b/nemo_rl/tron/model.py new file mode 100644 index 0000000000..a3a9ee352c --- /dev/null +++ b/nemo_rl/tron/model.py @@ -0,0 +1,164 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Callable, Optional + +import torch +from megatron.core import parallel_state, tensor_parallel +from megatron.core.distributed import ( + DistributedDataParallel, + DistributedDataParallelConfig, + TorchFullyShardedDataParallel, +) +from megatron.core.enums import ModelType +from megatron.core.fp8_utils import is_float8tensor +# from megatron.core.transformer.module import Float16Module + +from nemo.collections.llm.gpt.model.base import GPTConfig +from nemo.collections.llm.t5.model.t5 import T5Config + + +def get_model_from_config_no_float32( + model_config: GPTConfig | T5Config, + ddp_config: DistributedDataParallelConfig, + overlap_param_gather_with_optimizer_step: bool = False, + use_torch_fsdp2: bool = False, + wrap_with_ddp: bool = True, + data_parallel_random_init: bool = True, + model_post_init_fns: Optional[list[Callable]] = None, +): + # This method should only be called after `init_distributed()`. + # model_provider_func is equivalent to llm.gpt.GPTConfig.configure_model() + # model_type is inferred from the model_config class + + model_type = _get_model_type(model_config) + if ( + parallel_state.get_pipeline_model_parallel_world_size() > 1 + and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None + ): + assert ( + model_type != ModelType.encoder_and_decoder + ), "Interleaved schedule not supported for model with both encoder and decoder" + model = [] + for i in range(parallel_state.get_virtual_pipeline_model_parallel_world_size()): + parallel_state.set_virtual_pipeline_model_parallel_rank(i) + # Set pre_process and post_process only after virtual rank is set. + pre_process = parallel_state.is_pipeline_first_stage() + post_process = parallel_state.is_pipeline_last_stage() + this_model = model_config.configure_model( + tokenizer=None, + pre_process=pre_process, + post_process=post_process, + ) + this_model.model_type = model_type + model.append(this_model) + else: + pre_process = parallel_state.is_pipeline_first_stage() + post_process = parallel_state.is_pipeline_last_stage() + if model_type == ModelType.encoder_and_decoder: + assert isinstance(model_config, T5Config) + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + rank = parallel_state.get_pipeline_model_parallel_rank() + first_decoder_rank = parallel_state.get_pipeline_model_parallel_decoder_start() + world_size = parallel_state.get_pipeline_model_parallel_world_size() + pre_process = rank == 0 or rank == first_decoder_rank + post_process = (rank == (first_decoder_rank - 1)) or (rank == (world_size - 1)) + model = model_config.configure_model( + tokenizer=None, + ) + else: + model = model_config.configure_model( + tokenizer=None, + pre_process=pre_process, + post_process=post_process, + ) + model.model_type = model_type + + if not isinstance(model, list): + model = [model] + + # Set tensor model parallel attributes if not set. + # Only parameters that are already tensor model parallel have these + # attributes set for them. We should make sure the default attributes + # are set for all params so the optimizer can use them. + for model_module in model: + for param in model_module.parameters(): + tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + + # Print number of parameters. 
+ if parallel_state.get_data_parallel_rank() == 0: + print( + " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format( + parallel_state.get_tensor_model_parallel_rank(), + parallel_state.get_pipeline_model_parallel_rank(), + sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model]), + ), + flush=True, + ) + + # GPU allocation. + for model_module in model: + model_module.cuda(torch.cuda.current_device()) + + # Fp16 conversion. + if model_config.fp16 or model_config.bf16: + # TODO(pjin): okay to skip, as long as logits are float32-casted elsewhere. + # model = [Float16Module(model_config, model_module) for model_module in model] + pass + + if model_post_init_fns: + for model_module in model: + for post_init_fn in model_post_init_fns: + post_init_fn(model_module) + + # The model_module.bfloat16()/model_module.half() above will call the inplace copy of TE's + # Float8Tensor, which will write an unwanted value (amax calculated from the current fp8 + # param) to its amax_history. The following logic will correct the amax_history back. + for model_module in model: + for param in model_module.parameters(): + if is_float8tensor(param) and param._fp8_meta is not None: + fp8_meta = param._fp8_meta["scaling_fwd"] + fp8_meta_index = param._fp8_meta_index + if hasattr(param, "get_high_precision_init_val"): + fp8_meta.amax_history[0][fp8_meta_index].copy_(param.get_high_precision_init_val().abs().max()) + else: + fp8_meta.amax_history[0][fp8_meta_index] = 0 + + if wrap_with_ddp: + if use_torch_fsdp2: + DP = TorchFullyShardedDataParallel + else: + DP = DistributedDataParallel + + model = [ + DP( + config=model_config, + ddp_config=ddp_config, + module=model_chunk, + # Turn off bucketing for model_chunk 2 onwards, since communication for these + # model chunks is overlapped with compute anyway. + disable_bucketing=(model_chunk_idx > 0) or overlap_param_gather_with_optimizer_step, + ) + for (model_chunk_idx, model_chunk) in enumerate(model) + ] + + # Broadcast params from data parallel src rank to other data parallel ranks. + if data_parallel_random_init: + for model_module in model: + model_module.broadcast_params() + return model + + +def _get_model_type(model_config: GPTConfig | T5Config) -> ModelType: + return ModelType.encoder_and_decoder if isinstance(model_config, T5Config) else ModelType.encoder_or_decoder From 31707dbc069741b9fd226b0dc524fb0f1fa5e338 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 13:57:40 -0700 Subject: [PATCH 03/38] Fix. Signed-off-by: Peter Jin --- nemo_rl/distributed/model_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 64ed5dd41d..e9675b3680 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -384,7 +384,7 @@ def from_parallel_logits_to_logprobs( target = _get_tokens_on_this_cp_rank(target, cp_rank, cp_size, seq_dim=1) if chunk_size is not None: - logprobs: torch.Tensor = DistributedLogprob.apply( # type: ignore + logprobs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore vocab_parallel_logits, target, vocab_start_index, From 6a445bc43c715c6954bafc697f32e26f9a614fe3 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 14:46:33 -0700 Subject: [PATCH 04/38] Ruff + doc comment. 
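The expanded docstring spells out the two memory savings: the log-prob computation walks the sequence dimension in chunks, and the float16/bfloat16 -> float32 cast happens per chunk inside the loop, so a full float32 copy of the logits is never materialized. As a rough illustration only (single GPU, no tensor parallelism, no target masking or all-reduce, illustrative names), the non-distributed version of the idea looks like:

    import torch

    def chunked_logprobs(logits, targets, chunk_size):
        # logits: [batch, seq, vocab] in bf16/fp16; targets: [batch, seq] token ids
        seq_len = logits.shape[1]
        pieces = []
        for start in range(0, seq_len, chunk_size):
            end = min(seq_len, start + chunk_size)
            # Cast only the current chunk to float32 before the softmax.
            chunk = logits[:, start:end, :].to(torch.float32)
            logp = torch.log_softmax(chunk, dim=-1)
            pieces.append(
                torch.gather(logp, -1, targets[:, start:end].unsqueeze(-1)).squeeze(-1)
            )
        return torch.cat(pieces, dim=1)

ChunkedDistributedLogprob additionally masks targets that fall outside the local vocab partition and all-reduces the gathered log-probs over the tensor-parallel group, and its backward pass recomputes the per-chunk softmax instead of saving it.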
Signed-off-by: Peter Jin --- nemo_rl/distributed/model_utils.py | 25 +++++++++++--------- nemo_rl/tron/model.py | 38 ++++++++++++++++++++++-------- 2 files changed, 42 insertions(+), 21 deletions(-) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index e9675b3680..24957536d5 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -145,8 +145,11 @@ def backward( class ChunkedDistributedLogprob(torch.autograd.Function): """Custom autograd function for computing log probabilities in a distributed setting. - The log probabilities computation is sequence chunked to mitigate device memory - usage during backward pass. + + The log probabilities computation is chunked in the sequence dimension + to mitigate GPU OOM (especially during backward pass). + In addition, logits casting from float16 or bfloat16 -> float32 is performed + inside the chunk loop to avoid materializing a whole float32 logits tensor. Adapted from https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L286 """ @@ -175,7 +178,7 @@ def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func chunk_start = chunk_idx * chunk_size chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size) - logits = vocab_parallel_logits[:,chunk_start:chunk_end,:] + logits = vocab_parallel_logits[:, chunk_start:chunk_end, :] logits = logits.to(dtype=torch.float32) log_softmax_output = _compute_distributed_log_softmax( @@ -185,11 +188,9 @@ def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func log_probs = log_softmax_output.clone() log_probs = torch.gather( - log_probs, - -1, - masked_target[:,chunk_start:chunk_end].unsqueeze(-1) + log_probs, -1, masked_target[:, chunk_start:chunk_end].unsqueeze(-1) ).squeeze(-1) - log_probs[target_mask[:,chunk_start:chunk_end]] = 0.0 + log_probs[target_mask[:, chunk_start:chunk_end]] = 0.0 torch.distributed.all_reduce( log_probs, @@ -229,7 +230,7 @@ def backward( chunk_start = chunk_idx * chunk_size chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size) - logits = vocab_parallel_logits[:,chunk_start:chunk_end,:] + logits = vocab_parallel_logits[:, chunk_start:chunk_end, :] logits = logits.to(dtype=torch.float32) log_softmax_output = _compute_distributed_log_softmax( @@ -240,14 +241,16 @@ def backward( softmax_output = log_softmax_output.exp_() # 1 if it's the chosen log prob, 0 otherwise - is_chosen = (~(target_mask[:,chunk_start:chunk_end])).unsqueeze(-1) * torch.nn.functional.one_hot( - masked_target[:,chunk_start:chunk_end], + is_chosen = (~(target_mask[:, chunk_start:chunk_end])).unsqueeze( + -1 + ) * torch.nn.functional.one_hot( + masked_target[:, chunk_start:chunk_end], num_classes=partition_vocab_size, ) grad_input = is_chosen.float().sub_(softmax_output) - grad_input.mul_(grad_output[:,chunk_start:chunk_end].unsqueeze(dim=-1)) + grad_input.mul_(grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1)) all_grad_input.append(grad_input) diff --git a/nemo_rl/tron/model.py b/nemo_rl/tron/model.py index a3a9ee352c..f625d6db66 100644 --- a/nemo_rl/tron/model.py +++ b/nemo_rl/tron/model.py @@ -47,9 +47,9 @@ def get_model_from_config_no_float32( parallel_state.get_pipeline_model_parallel_world_size() > 1 and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None ): - assert ( - model_type != ModelType.encoder_and_decoder - ), "Interleaved schedule not supported for model with both encoder and decoder" + assert model_type != 
ModelType.encoder_and_decoder, ( + "Interleaved schedule not supported for model with both encoder and decoder" + ) model = [] for i in range(parallel_state.get_virtual_pipeline_model_parallel_world_size()): parallel_state.set_virtual_pipeline_model_parallel_rank(i) @@ -70,10 +70,14 @@ def get_model_from_config_no_float32( assert isinstance(model_config, T5Config) if parallel_state.get_pipeline_model_parallel_world_size() > 1: rank = parallel_state.get_pipeline_model_parallel_rank() - first_decoder_rank = parallel_state.get_pipeline_model_parallel_decoder_start() + first_decoder_rank = ( + parallel_state.get_pipeline_model_parallel_decoder_start() + ) world_size = parallel_state.get_pipeline_model_parallel_world_size() pre_process = rank == 0 or rank == first_decoder_rank - post_process = (rank == (first_decoder_rank - 1)) or (rank == (world_size - 1)) + post_process = (rank == (first_decoder_rank - 1)) or ( + rank == (world_size - 1) + ) model = model_config.configure_model( tokenizer=None, ) @@ -94,7 +98,9 @@ def get_model_from_config_no_float32( # are set for all params so the optimizer can use them. for model_module in model: for param in model_module.parameters(): - tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes( + param + ) # Print number of parameters. if parallel_state.get_data_parallel_rank() == 0: @@ -102,7 +108,12 @@ def get_model_from_config_no_float32( " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format( parallel_state.get_tensor_model_parallel_rank(), parallel_state.get_pipeline_model_parallel_rank(), - sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model]), + sum( + [ + sum([p.nelement() for p in model_module.parameters()]) + for model_module in model + ] + ), ), flush=True, ) @@ -131,7 +142,9 @@ def get_model_from_config_no_float32( fp8_meta = param._fp8_meta["scaling_fwd"] fp8_meta_index = param._fp8_meta_index if hasattr(param, "get_high_precision_init_val"): - fp8_meta.amax_history[0][fp8_meta_index].copy_(param.get_high_precision_init_val().abs().max()) + fp8_meta.amax_history[0][fp8_meta_index].copy_( + param.get_high_precision_init_val().abs().max() + ) else: fp8_meta.amax_history[0][fp8_meta_index] = 0 @@ -148,7 +161,8 @@ def get_model_from_config_no_float32( module=model_chunk, # Turn off bucketing for model_chunk 2 onwards, since communication for these # model chunks is overlapped with compute anyway. - disable_bucketing=(model_chunk_idx > 0) or overlap_param_gather_with_optimizer_step, + disable_bucketing=(model_chunk_idx > 0) + or overlap_param_gather_with_optimizer_step, ) for (model_chunk_idx, model_chunk) in enumerate(model) ] @@ -161,4 +175,8 @@ def get_model_from_config_no_float32( def _get_model_type(model_config: GPTConfig | T5Config) -> ModelType: - return ModelType.encoder_and_decoder if isinstance(model_config, T5Config) else ModelType.encoder_or_decoder + return ( + ModelType.encoder_and_decoder + if isinstance(model_config, T5Config) + else ModelType.encoder_or_decoder + ) From a020289609cfa0d7a695a175eed009fdb4695088 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 15:05:29 -0700 Subject: [PATCH 05/38] Configurable deferring float32 logits. 
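Both knobs are opt-in and default to the previous behaviour: policy.logprob_chunk_size turns on the chunked log-prob path, and policy.megatron_cfg.deferred_fp32_logits builds the model without the float32 logits cast so the cast only happens per chunk inside the log-prob computation. The gist of the change in megatron_policy_worker.py (paraphrased from the diff below):

    if policy_cfg["megatron_cfg"].get("deferred_fp32_logits", None):
        model_builder = get_model_from_config_no_float32  # keep bf16/fp16 logits
    else:
        model_builder = get_model_from_config

The chunk size is read with self.cfg.get("logprob_chunk_size", None), so omitting either key keeps the old unchunked, eagerly-cast path.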
Signed-off-by: Peter Jin --- ...po_math_qwen3_30ba3b_megatron_tp4_32k.yaml | 110 ++++++++++++++++++ .../models/policy/megatron_policy_worker.py | 6 +- nemo_rl/tron/model.py | 1 - 3 files changed, 115 insertions(+), 2 deletions(-) create mode 100644 examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml diff --git a/examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml b/examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml new file mode 100644 index 0000000000..5ae601a97b --- /dev/null +++ b/examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml @@ -0,0 +1,110 @@ +# GRPO Algorithm Configuration +defaults: "grpo_math_1B_megatron.yaml" + +checkpointing: + enabled: True + save_period: 4 + keep_top_k: 100 + +grpo: + num_prompts_per_step: 64 + num_generations_per_prompt: 16 + val_period: 4 + val_at_start: True + # max_val_samples: 256 + # val_batch_size: 256 + val_num_generations_per_prompt: 16 + +loss_fn: + sequence_level_importance_sampling: False + sequence_level_ratio_clip: False + +policy: + model_name: "Qwen/Qwen3-30B-A3B" + tokenizer: + name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default + train_global_batch_size: 512 + train_micro_batch_size: 1 + generation_batch_size: 32 # Only used when generating using HF backend + logprob_batch_size: 1 + max_total_sequence_length: 32768 + precision: "bfloat16" + activation_checkpointing_enabled: True + logprob_chunk_size: 2048 + + dtensor_cfg: + enabled: false + + sequence_packing: + enabled: False + + optimizer: null # remove default FSDP optimizer + + scheduler: null # remove default FSDP scheduler + + megatron_cfg: + enabled: true + empty_unused_memory_level: 1 + converter_type: "LlamaForCausalLM" + tensor_model_parallel_size: 4 + pipeline_model_parallel_size: 1 + context_parallel_size: 1 + # NB(peter): AssertionError: Sequence Packing must be enabled to use Context Parallelism with MCore + # context_parallel_size: 2 + expert_tensor_parallel_size: 1 + expert_model_parallel_size: 8 + sequence_parallel: True + pipeline_dtype: ${policy.precision} + activation_checkpointing: True + deferred_fp32_logits: True + + optimizer: + optimizer: "adam" + # lr: 3.0e-7 + # min_lr: 3.0e-8 + lr: 5.0e-7 + min_lr: 5.0e-8 + weight_decay: 0.0 + # weight_decay: 0.01 + bf16: true + fp16: false + params_dtype: "float32" + + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: "constant" + lr_decay_style: "constant" + lr_decay_iters: null + # lr_warmup_iters: 13 + # lr_warmup_init: 3.0e-8 + lr_warmup_iters: 4 + lr_warmup_init: 5.0e-8 + + env_vars: + PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False" + + generation: + backend: "vllm" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + vllm_cfg: + tensor_parallel_size: 4 + gpu_memory_utilization: 0.6 + enforce_eager: false + max_model_len: ${policy.max_total_sequence_length} + +data: + max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len + prompt_file: "examples/prompts/cot.txt" + # prompt_file: null + system_prompt_file: null + # dataset_name: null + +cluster: + gpus_per_node: 8 + num_nodes: 4 diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 9b6994113c..40e4debc08 100644 --- 
a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -210,7 +210,11 @@ def re_enable_float32_expert_bias(model_module): model_post_init_fns.append(re_enable_float32_expert_bias) # Model, optimizer, and learning rate. - model = get_model_from_config_no_float32( + if policy_cfg["megatron_cfg"].get("deferred_fp32_logits", None): + model_builder = get_model_from_config_no_float32 + else: + model_builder = get_model_from_config + model = model_builder( cfg.model_config, cfg.ddp_config, use_torch_fsdp2=cfg.dist_config.use_torch_fsdp2, diff --git a/nemo_rl/tron/model.py b/nemo_rl/tron/model.py index f625d6db66..af3c66eae3 100644 --- a/nemo_rl/tron/model.py +++ b/nemo_rl/tron/model.py @@ -23,7 +23,6 @@ ) from megatron.core.enums import ModelType from megatron.core.fp8_utils import is_float8tensor -# from megatron.core.transformer.module import Float16Module from nemo.collections.llm.gpt.model.base import GPTConfig from nemo.collections.llm.t5.model.t5 import T5Config From 956051c1eb51defe52b73a93095b79004c92e4e3 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 15:11:30 -0700 Subject: [PATCH 06/38] Update docstrings. Signed-off-by: Peter Jin --- nemo_rl/distributed/model_utils.py | 2 ++ nemo_rl/models/dtensor/parallelize.py | 1 + 2 files changed, 3 insertions(+) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 24957536d5..074a1941f4 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -283,6 +283,7 @@ def dtensor_from_parallel_logits_to_logprobs( inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False. seq_index (Optional[torch.Tensor]): Sequence index tensor with shape [seq_len]. It is only provided for cp sharded logits. It represents how tensor is sharded across the sequence dimension. + chunk_size (Optional[int]): Sequence dimension chunk size for computing the log probabilities. Returns: torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. @@ -366,6 +367,7 @@ def from_parallel_logits_to_logprobs( tp_group (torch.distributed.ProcessGroup): Process group for distributed communication. inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False. cp_group (torch.distributed.ProcessGroup, optional): Context parallelism process group. Defaults to None. + chunk_size (int, optional): Sequence dimension chunk size for computing the log probabilities. Returns: torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index 2171a34a9e..d2bc82d1c3 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -634,6 +634,7 @@ def get_logprobs_from_vocab_parallel_logits( with shape [batch_size, seq_len]. seq_index (Optional[torch.Tensor]): Sequence index for the input IDs, with shape [sequence_length]. + chunk_size (Optional[int]): Sequence dimension chunk size for computing log probabilities. Returns: torch.Tensor: Log probabilities for the given input IDs. From ece804980b90962188eefa09ae0c21bc047cb60f Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 15:22:17 -0700 Subject: [PATCH 07/38] Ruff. 
Signed-off-by: Peter Jin --- nemo_rl/tron/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo_rl/tron/model.py b/nemo_rl/tron/model.py index af3c66eae3..01d6c3db0d 100644 --- a/nemo_rl/tron/model.py +++ b/nemo_rl/tron/model.py @@ -23,7 +23,6 @@ ) from megatron.core.enums import ModelType from megatron.core.fp8_utils import is_float8tensor - from nemo.collections.llm.gpt.model.base import GPTConfig from nemo.collections.llm.t5.model.t5 import T5Config From 670743fae73daf91a8bd1c67098b29360802b085 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 15:50:53 -0700 Subject: [PATCH 08/38] Basic chunking support in logprobs computation with sequence packing. Signed-off-by: Peter Jin --- nemo_rl/distributed/model_utils.py | 29 ++++++++++++++----- .../models/policy/megatron_policy_worker.py | 3 +- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 074a1941f4..de9cedaaba 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -430,6 +430,7 @@ def from_parallel_logits_to_logprobs_packed_sequences( group: torch.distributed.ProcessGroup, inference_only: bool = False, cp_group: Optional[torch.distributed.ProcessGroup] = None, + chunk_size: Optional[int] = None, ) -> torch.Tensor: """Get log probabilities from TP sharded vocab logits for packed sequences. @@ -446,6 +447,7 @@ def from_parallel_logits_to_logprobs_packed_sequences( group (torch.distributed.ProcessGroup): Process group for distributed communication. inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False. cp_group (torch.distributed.ProcessGroup, optional): Context parallelism process group. Defaults to None. + chunk_size (int, optional): Sequence dimension chunk size for computing the log probabilities. Returns: torch.Tensor: Unpacked log probabilities tensor with shape [batch_size, unpacked_seqlen-1]. @@ -479,14 +481,25 @@ def from_parallel_logits_to_logprobs_packed_sequences( vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0) # Apply distributed log probability computation - probs: torch.Tensor = DistributedLogprob.apply( # type: ignore - vocab_parallel_logits, - rolled_targets, - vocab_start_index, - vocab_end_index, - group, - inference_only, - ).contiguous() + if chunk_size is not None: + probs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + rolled_targets, + vocab_start_index, + vocab_end_index, + chunk_size, + group, + inference_only, + ).contiguous() + else: + probs: torch.Tensor = DistributedLogprob.apply( # type: ignore + vocab_parallel_logits, + rolled_targets, + vocab_start_index, + vocab_end_index, + group, + inference_only, + ).contiguous() # Remove batch dimension for filtering probs = probs.squeeze(0) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 40e4debc08..1ca0eb935e 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1105,8 +1105,7 @@ def collection_fn(output_tensor): group=tp_grp, inference_only=True, cp_group=get_context_parallel_group(), - # TODO(pjin): chunked logprob not implemented yet w/ seq packing. 
- # chunk_size=logprob_chunk_size, + chunk_size=logprob_chunk_size, ) else: token_logprobs = from_parallel_logits_to_logprobs( From f1a9d21c64359729209d2111f41e2f72a660c3ea Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 16:09:37 -0700 Subject: [PATCH 09/38] Unit test for chunked logprobs. Signed-off-by: Peter Jin --- tests/unit/distributed/test_model_utils.py | 36 ++++++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/tests/unit/distributed/test_model_utils.py b/tests/unit/distributed/test_model_utils.py index 2f8ef2011a..785b02107f 100644 --- a/tests/unit/distributed/test_model_utils.py +++ b/tests/unit/distributed/test_model_utils.py @@ -18,6 +18,7 @@ import torch from nemo_rl.distributed.model_utils import ( + ChunkedDistributedLogprob, DistributedLogprob, _compute_distributed_log_softmax, _get_tokens_on_this_cp_rank, @@ -428,8 +429,9 @@ def test_allgather_cp_sharded_tensor(register_allgather_cp_test_actor, cp_size): @ray.remote(num_gpus=1) class DistributedLogprobTestActor: - def __init__(self, tp_size): + def __init__(self, tp_size, chunk_size): self.tp_size = tp_size + self.chunk_size = chunk_size self.env_vars = dict(os.environ) torch.distributed.init_process_group(backend="nccl") self.tp_group = torch.distributed.new_group(ranks=list(range(tp_size))) @@ -455,6 +457,7 @@ def test_distributed_logprob_forward_and_backward(self): seq_len = 8 full_vocab_size = 1024 vocab_part_size = full_vocab_size // self.tp_size + chunk_size = self.chunk_size # Calculate vocab partition for this rank vocab_start_index = rank * vocab_part_size @@ -490,14 +493,25 @@ def test_distributed_logprob_forward_and_backward(self): ) # Compute using DistributedLogprob (forward only first) - distributed_log_probs_inference = DistributedLogprob.apply( - vocab_parallel_logits.clone().detach(), # Clone to avoid affecting backward test - target, - vocab_start_index, - vocab_end_index, - self.tp_group, - True, # inference_only=True for forward test - ) + if chunk_size is not None: + distributed_log_probs_inference = ChunkedDistributedLogprob.apply( + vocab_parallel_logits.clone().detach(), # Clone to avoid affecting backward test + target, + vocab_start_index, + vocab_end_index, + chunk_size, + self.tp_group, + True, # inference_only=True for forward test + ) + else: + distributed_log_probs_inference = DistributedLogprob.apply( + vocab_parallel_logits.clone().detach(), # Clone to avoid affecting backward test + target, + vocab_start_index, + vocab_end_index, + self.tp_group, + True, # inference_only=True for forward test + ) # Compare forward results torch.testing.assert_close( @@ -700,7 +714,9 @@ def register_distributed_logprob_test_actor(): ) -@pytest.mark.parametrize("tp_size", [1, 2]) +@pytest.mark.parametrize("tp_size, chunk_size", [ + (1, None), (2, None), (1, 4), (2, 4), +]) def test_distributed_logprob_all_tests( register_distributed_logprob_test_actor, tp_size ): From df7071535270b4c05a988c171c3d1f59793cf32d Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 16:14:27 -0700 Subject: [PATCH 10/38] Ruff. 
Signed-off-by: Peter Jin --- tests/unit/distributed/test_model_utils.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/unit/distributed/test_model_utils.py b/tests/unit/distributed/test_model_utils.py index 785b02107f..bbd49e68f8 100644 --- a/tests/unit/distributed/test_model_utils.py +++ b/tests/unit/distributed/test_model_utils.py @@ -714,9 +714,15 @@ def register_distributed_logprob_test_actor(): ) -@pytest.mark.parametrize("tp_size, chunk_size", [ - (1, None), (2, None), (1, 4), (2, 4), -]) +@pytest.mark.parametrize( + "tp_size, chunk_size", + [ + (1, None), + (2, None), + (1, 4), + (2, 4), + ], +) def test_distributed_logprob_all_tests( register_distributed_logprob_test_actor, tp_size ): From 637e13198946c383ce3f56616470b3fb5cbee6a1 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 16:20:52 -0700 Subject: [PATCH 11/38] Pyrefly. Signed-off-by: Peter Jin --- pyrefly.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyrefly.toml b/pyrefly.toml index 6442672edf..84168857c6 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -98,6 +98,7 @@ project-includes = [ "nemo_rl/models/policy/__init__.py", "nemo_rl/models/policy/interfaces.py", "nemo_rl/models/policy/utils.py", + "nemo_rl/tron/__init__.py", "nemo_rl/utils/__init__.py", "nemo_rl/utils/checkpoint.py", "nemo_rl/utils/config.py", From abbf796fbe610d12a63c90ef8af1b53f7e01fad2 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 16:27:02 -0700 Subject: [PATCH 12/38] Fix test. Pyrefly. Signed-off-by: Peter Jin --- pyrefly.toml | 1 + tests/unit/distributed/test_model_utils.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pyrefly.toml b/pyrefly.toml index 84168857c6..1149c4086a 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -99,6 +99,7 @@ project-includes = [ "nemo_rl/models/policy/interfaces.py", "nemo_rl/models/policy/utils.py", "nemo_rl/tron/__init__.py", + "nemo_rl/tron/model.py", "nemo_rl/utils/__init__.py", "nemo_rl/utils/checkpoint.py", "nemo_rl/utils/config.py", diff --git a/tests/unit/distributed/test_model_utils.py b/tests/unit/distributed/test_model_utils.py index bbd49e68f8..3cc0bb36f6 100644 --- a/tests/unit/distributed/test_model_utils.py +++ b/tests/unit/distributed/test_model_utils.py @@ -724,7 +724,7 @@ def register_distributed_logprob_test_actor(): ], ) def test_distributed_logprob_all_tests( - register_distributed_logprob_test_actor, tp_size + register_distributed_logprob_test_actor, tp_size, chunk_size ): """Test all DistributedLogprob functionality for a given TP size.""" # Skip if not enough GPUs @@ -740,7 +740,7 @@ def test_distributed_logprob_all_tests( # Create sharding for TP sharding = NamedSharding(layout=list(range(tp_size)), names=["tp"]) - builder = RayWorkerBuilder(actor_fqn, tp_size) + builder = RayWorkerBuilder(actor_fqn, tp_size, chunk_size) worker_group = RayWorkerGroup( cluster=cluster, @@ -750,7 +750,7 @@ def test_distributed_logprob_all_tests( ) # Test 1: Combined Forward and Backward pass - print(f"\n=== Testing TP={tp_size}: Forward & Backward Pass ===") + print(f"\n=== Testing TP={tp_size} ChunkSize={chunk_size}: Forward & Backward Pass ===") futures = worker_group.run_all_workers_single_data( "test_distributed_logprob_forward_and_backward" ) @@ -765,7 +765,7 @@ def test_distributed_logprob_all_tests( ) # Test 2: Log softmax function - print(f"\n=== Testing TP={tp_size}: Log Softmax ===") + print(f"\n=== Testing TP={tp_size} ChunkSize={chunk_size}: Log Softmax ===") futures = 
worker_group.run_all_workers_single_data( "test_distributed_log_softmax" ) @@ -778,7 +778,7 @@ def test_distributed_logprob_all_tests( # Test 3: Edge cases (only for TP=2) if tp_size == 2: - print(f"\n=== Testing TP={tp_size}: Edge Cases ===") + print(f"\n=== Testing TP={tp_size} ChunkSize={chunk_size}: Edge Cases ===") futures = worker_group.run_all_workers_single_data("test_edge_cases") results = ray.get(futures) print("Edge cases test completed successfully") From 985ba775b926c88aa7581ea1a8fe8f77368001ce Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 16:29:48 -0700 Subject: [PATCH 13/38] Ruff. Signed-off-by: Peter Jin --- tests/unit/distributed/test_model_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/distributed/test_model_utils.py b/tests/unit/distributed/test_model_utils.py index 3cc0bb36f6..371080a384 100644 --- a/tests/unit/distributed/test_model_utils.py +++ b/tests/unit/distributed/test_model_utils.py @@ -750,7 +750,9 @@ def test_distributed_logprob_all_tests( ) # Test 1: Combined Forward and Backward pass - print(f"\n=== Testing TP={tp_size} ChunkSize={chunk_size}: Forward & Backward Pass ===") + print( + f"\n=== Testing TP={tp_size} ChunkSize={chunk_size}: Forward & Backward Pass ===" + ) futures = worker_group.run_all_workers_single_data( "test_distributed_logprob_forward_and_backward" ) From 13265aa63010fc34b35ea84362977e229354565f Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 16:50:00 -0700 Subject: [PATCH 14/38] Stale comment. Signed-off-by: Peter Jin --- examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml b/examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml index 5ae601a97b..922736c5cc 100644 --- a/examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml +++ b/examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml @@ -49,8 +49,6 @@ policy: tensor_model_parallel_size: 4 pipeline_model_parallel_size: 1 context_parallel_size: 1 - # NB(peter): AssertionError: Sequence Packing must be enabled to use Context Parallelism with MCore - # context_parallel_size: 2 expert_tensor_parallel_size: 1 expert_model_parallel_size: 8 sequence_parallel: True From 9758b14040231d477084b79e708d27b77969c5b5 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 17:48:55 -0700 Subject: [PATCH 15/38] Remove unused config. Signed-off-by: Peter Jin --- examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml b/examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml index 922736c5cc..26f61d138f 100644 --- a/examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml +++ b/examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml @@ -15,10 +15,6 @@ grpo: # val_batch_size: 256 val_num_generations_per_prompt: 16 -loss_fn: - sequence_level_importance_sampling: False - sequence_level_ratio_clip: False - policy: model_name: "Qwen/Qwen3-30B-A3B" tokenizer: From ea5171541910de452bfc5f33cfa3b0ffdce9c854 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 17:50:50 -0700 Subject: [PATCH 16/38] Remove unused config. 
Signed-off-by: Peter Jin --- examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml b/examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml index 26f61d138f..d9bc9ef1cc 100644 --- a/examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml +++ b/examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml @@ -13,7 +13,6 @@ grpo: val_at_start: True # max_val_samples: 256 # val_batch_size: 256 - val_num_generations_per_prompt: 16 policy: model_name: "Qwen/Qwen3-30B-A3B" From 1670f93b9be9926b1f9d9c255a643c85fffb677a Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 17:57:15 -0700 Subject: [PATCH 17/38] Also apply to the reference model. Signed-off-by: Peter Jin --- nemo_rl/models/policy/megatron_policy_worker.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 1ca0eb935e..947584e9f8 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -650,7 +650,11 @@ def __init__( ref_state = GlobalState() ref_state.cfg = ref_megatron_cfg - reference_model = get_model_from_config_no_float32( + if self.cfg["megatron_cfg"].get("deferred_fp32_logits", None): + ref_model_builder = get_model_from_config_no_float32 + else: + ref_model_builder = get_model_from_config + reference_model = ref_model_builder( self.megatron_cfg.model_config, self.megatron_cfg.ddp_config, use_torch_fsdp2=self.megatron_cfg.dist_config.use_torch_fsdp2, From 8aef5edbeb139da877ac9c6b7c006da76d76c53a Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 19:57:23 -0700 Subject: [PATCH 18/38] Typed policy configs. Signed-off-by: Peter Jin --- nemo_rl/models/policy/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index c5637d7096..507f3f0e83 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -93,6 +93,7 @@ class MegatronConfig(TypedDict): freeze_moe_router: bool expert_tensor_parallel_size: int expert_model_parallel_size: int + deferred_fp32_logits: NotRequired[bool] optimizer: NotRequired[MegatronOptimizerConfig] scheduler: NotRequired[MegatronSchedulerConfig] @@ -138,6 +139,7 @@ class PolicyConfig(TypedDict): train_global_batch_size: int train_micro_batch_size: int logprob_batch_size: NotRequired[int] + logprob_chunk_size: NotRequired[int] generation: NotRequired[GenerationConfig] generation_batch_size: NotRequired[ int From 830debecd3538e3155f4a44afcd79101653a5fc4 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Thu, 7 Aug 2025 10:21:56 -0700 Subject: [PATCH 19/38] Bump NeMo submodule. 
Signed-off-by: Peter Jin --- .gitmodules | 2 +- 3rdparty/NeMo-workspace/NeMo | 2 +- .../llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename examples/configs/{grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml => recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml} (100%) diff --git a/.gitmodules b/.gitmodules index 09342d3495..47c96837ba 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,7 @@ [submodule "3rdparty/NeMo"] path = 3rdparty/NeMo-workspace/NeMo url = https://github.com/NVIDIA/NeMo.git - branch = zhiyul/yukih/prepare-refit-info + branch = pjin/nemorl-logprob shallow = true [submodule "3rdparty/Megatron-LM"] path = 3rdparty/Megatron-LM-workspace/Megatron-LM diff --git a/3rdparty/NeMo-workspace/NeMo b/3rdparty/NeMo-workspace/NeMo index 8ddf438734..0bf0dbce6c 160000 --- a/3rdparty/NeMo-workspace/NeMo +++ b/3rdparty/NeMo-workspace/NeMo @@ -1 +1 @@ -Subproject commit 8ddf4387344c6423763ec9ee0c9a755cbb5d8d35 +Subproject commit 0bf0dbce6c6794efc0df4d0980773c14d9632421 diff --git a/examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml similarity index 100% rename from examples/configs/grpo_math_qwen3_30ba3b_megatron_tp4_32k.yaml rename to examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml From 030ddc56e467f7faacd8618206500c5db475ad28 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Thu, 7 Aug 2025 10:25:56 -0700 Subject: [PATCH 20/38] Remove duplicated nemo.tron.model code and use the new bumped submodule. Signed-off-by: Peter Jin --- .../models/policy/megatron_policy_worker.py | 15 +- nemo_rl/tron/__init__.py | 0 nemo_rl/tron/model.py | 180 ------------------ pyrefly.toml | 2 - 4 files changed, 4 insertions(+), 193 deletions(-) delete mode 100644 nemo_rl/tron/__init__.py delete mode 100644 nemo_rl/tron/model.py diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 947584e9f8..f2494c1137 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -124,7 +124,6 @@ get_megatron_checkpoint_dir, get_runtime_env_for_policy_worker, ) -from nemo_rl.tron.model import get_model_from_config_no_float32 TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) @@ -210,17 +209,14 @@ def re_enable_float32_expert_bias(model_module): model_post_init_fns.append(re_enable_float32_expert_bias) # Model, optimizer, and learning rate. 
- if policy_cfg["megatron_cfg"].get("deferred_fp32_logits", None): - model_builder = get_model_from_config_no_float32 - else: - model_builder = get_model_from_config - model = model_builder( + model = get_model_from_config( cfg.model_config, cfg.ddp_config, use_torch_fsdp2=cfg.dist_config.use_torch_fsdp2, overlap_param_gather_with_optimizer_step=cfg.optimizer_config.overlap_param_gather_with_optimizer_step, data_parallel_random_init=cfg.rng_config.data_parallel_random_init, model_post_init_fns=model_post_init_fns, + wrap_cast_to_fp32=policy_cfg["megatron_cfg"].get("deferred_fp32_logits", None), ) if load_optimizer: optimizer, scheduler = setup_optimizer( @@ -650,16 +646,13 @@ def __init__( ref_state = GlobalState() ref_state.cfg = ref_megatron_cfg - if self.cfg["megatron_cfg"].get("deferred_fp32_logits", None): - ref_model_builder = get_model_from_config_no_float32 - else: - ref_model_builder = get_model_from_config - reference_model = ref_model_builder( + reference_model = get_model_from_config( self.megatron_cfg.model_config, self.megatron_cfg.ddp_config, use_torch_fsdp2=self.megatron_cfg.dist_config.use_torch_fsdp2, overlap_param_gather_with_optimizer_step=self.megatron_cfg.optimizer_config.overlap_param_gather_with_optimizer_step, data_parallel_random_init=self.megatron_cfg.rng_config.data_parallel_random_init, + wrap_cast_to_fp32=self.cfg["megatron_cfg"].get("deferred_fp32_logits", None), ) print("Loading the Reference Model") if ( diff --git a/nemo_rl/tron/__init__.py b/nemo_rl/tron/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/nemo_rl/tron/model.py b/nemo_rl/tron/model.py deleted file mode 100644 index 01d6c3db0d..0000000000 --- a/nemo_rl/tron/model.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable, Optional - -import torch -from megatron.core import parallel_state, tensor_parallel -from megatron.core.distributed import ( - DistributedDataParallel, - DistributedDataParallelConfig, - TorchFullyShardedDataParallel, -) -from megatron.core.enums import ModelType -from megatron.core.fp8_utils import is_float8tensor -from nemo.collections.llm.gpt.model.base import GPTConfig -from nemo.collections.llm.t5.model.t5 import T5Config - - -def get_model_from_config_no_float32( - model_config: GPTConfig | T5Config, - ddp_config: DistributedDataParallelConfig, - overlap_param_gather_with_optimizer_step: bool = False, - use_torch_fsdp2: bool = False, - wrap_with_ddp: bool = True, - data_parallel_random_init: bool = True, - model_post_init_fns: Optional[list[Callable]] = None, -): - # This method should only be called after `init_distributed()`. 
- # model_provider_func is equivalent to llm.gpt.GPTConfig.configure_model() - # model_type is inferred from the model_config class - - model_type = _get_model_type(model_config) - if ( - parallel_state.get_pipeline_model_parallel_world_size() > 1 - and parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None - ): - assert model_type != ModelType.encoder_and_decoder, ( - "Interleaved schedule not supported for model with both encoder and decoder" - ) - model = [] - for i in range(parallel_state.get_virtual_pipeline_model_parallel_world_size()): - parallel_state.set_virtual_pipeline_model_parallel_rank(i) - # Set pre_process and post_process only after virtual rank is set. - pre_process = parallel_state.is_pipeline_first_stage() - post_process = parallel_state.is_pipeline_last_stage() - this_model = model_config.configure_model( - tokenizer=None, - pre_process=pre_process, - post_process=post_process, - ) - this_model.model_type = model_type - model.append(this_model) - else: - pre_process = parallel_state.is_pipeline_first_stage() - post_process = parallel_state.is_pipeline_last_stage() - if model_type == ModelType.encoder_and_decoder: - assert isinstance(model_config, T5Config) - if parallel_state.get_pipeline_model_parallel_world_size() > 1: - rank = parallel_state.get_pipeline_model_parallel_rank() - first_decoder_rank = ( - parallel_state.get_pipeline_model_parallel_decoder_start() - ) - world_size = parallel_state.get_pipeline_model_parallel_world_size() - pre_process = rank == 0 or rank == first_decoder_rank - post_process = (rank == (first_decoder_rank - 1)) or ( - rank == (world_size - 1) - ) - model = model_config.configure_model( - tokenizer=None, - ) - else: - model = model_config.configure_model( - tokenizer=None, - pre_process=pre_process, - post_process=post_process, - ) - model.model_type = model_type - - if not isinstance(model, list): - model = [model] - - # Set tensor model parallel attributes if not set. - # Only parameters that are already tensor model parallel have these - # attributes set for them. We should make sure the default attributes - # are set for all params so the optimizer can use them. - for model_module in model: - for param in model_module.parameters(): - tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes( - param - ) - - # Print number of parameters. - if parallel_state.get_data_parallel_rank() == 0: - print( - " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format( - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_pipeline_model_parallel_rank(), - sum( - [ - sum([p.nelement() for p in model_module.parameters()]) - for model_module in model - ] - ), - ), - flush=True, - ) - - # GPU allocation. - for model_module in model: - model_module.cuda(torch.cuda.current_device()) - - # Fp16 conversion. - if model_config.fp16 or model_config.bf16: - # TODO(pjin): okay to skip, as long as logits are float32-casted elsewhere. - # model = [Float16Module(model_config, model_module) for model_module in model] - pass - - if model_post_init_fns: - for model_module in model: - for post_init_fn in model_post_init_fns: - post_init_fn(model_module) - - # The model_module.bfloat16()/model_module.half() above will call the inplace copy of TE's - # Float8Tensor, which will write an unwanted value (amax calculated from the current fp8 - # param) to its amax_history. The following logic will correct the amax_history back. 
- for model_module in model: - for param in model_module.parameters(): - if is_float8tensor(param) and param._fp8_meta is not None: - fp8_meta = param._fp8_meta["scaling_fwd"] - fp8_meta_index = param._fp8_meta_index - if hasattr(param, "get_high_precision_init_val"): - fp8_meta.amax_history[0][fp8_meta_index].copy_( - param.get_high_precision_init_val().abs().max() - ) - else: - fp8_meta.amax_history[0][fp8_meta_index] = 0 - - if wrap_with_ddp: - if use_torch_fsdp2: - DP = TorchFullyShardedDataParallel - else: - DP = DistributedDataParallel - - model = [ - DP( - config=model_config, - ddp_config=ddp_config, - module=model_chunk, - # Turn off bucketing for model_chunk 2 onwards, since communication for these - # model chunks is overlapped with compute anyway. - disable_bucketing=(model_chunk_idx > 0) - or overlap_param_gather_with_optimizer_step, - ) - for (model_chunk_idx, model_chunk) in enumerate(model) - ] - - # Broadcast params from data parallel src rank to other data parallel ranks. - if data_parallel_random_init: - for model_module in model: - model_module.broadcast_params() - return model - - -def _get_model_type(model_config: GPTConfig | T5Config) -> ModelType: - return ( - ModelType.encoder_and_decoder - if isinstance(model_config, T5Config) - else ModelType.encoder_or_decoder - ) diff --git a/pyrefly.toml b/pyrefly.toml index 1149c4086a..6442672edf 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -98,8 +98,6 @@ project-includes = [ "nemo_rl/models/policy/__init__.py", "nemo_rl/models/policy/interfaces.py", "nemo_rl/models/policy/utils.py", - "nemo_rl/tron/__init__.py", - "nemo_rl/tron/model.py", "nemo_rl/utils/__init__.py", "nemo_rl/utils/checkpoint.py", "nemo_rl/utils/config.py", From 6f014ce636578c576287abc047165039400bfbe5 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Thu, 7 Aug 2025 10:50:21 -0700 Subject: [PATCH 21/38] Remove leftover float32 cast. Ruff. Signed-off-by: Peter Jin --- nemo_rl/models/policy/megatron_policy_worker.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index f2494c1137..bf6238bab8 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -652,7 +652,9 @@ def __init__( use_torch_fsdp2=self.megatron_cfg.dist_config.use_torch_fsdp2, overlap_param_gather_with_optimizer_step=self.megatron_cfg.optimizer_config.overlap_param_gather_with_optimizer_step, data_parallel_random_init=self.megatron_cfg.rng_config.data_parallel_random_init, - wrap_cast_to_fp32=self.cfg["megatron_cfg"].get("deferred_fp32_logits", None), + wrap_cast_to_fp32=self.cfg["megatron_cfg"].get( + "deferred_fp32_logits", None + ), ) print("Loading the Reference Model") if ( @@ -1106,7 +1108,7 @@ def collection_fn(output_tensor): ) else: token_logprobs = from_parallel_logits_to_logprobs( - output_tensor.to(torch.float32), + output_tensor, target=unpacked_input_ids, vocab_start_index=tp_rank * output_tensor.shape[-1], vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], From 9e5dcd7f3eb05a2f2d6efcf520f7c0ecadbd1788 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Thu, 7 Aug 2025 11:09:04 -0700 Subject: [PATCH 22/38] Check for float32 logprobs. 
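The logprob test matrix now covers the new knobs: four baseline cases, four with the deferred float32 logits cast enabled, and four that combine deferral with a logprob chunk size of 16, across the tiny Llama and tiny Qwen2 fixtures on 2 GPUs (DP2 and TP2). In every case the logprobs returned by the worker are asserted to be float32, so neither deferral nor chunking can silently leave the output in bf16.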
Signed-off-by: Peter Jin --- .../models/policy/test_megatron_worker.py | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index 20c31324e0..fd2c499274 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional import os import tempfile @@ -40,6 +41,8 @@ def create_megatron_test_config( generation_backend: str = "megatron", sequence_parallel: bool = False, converter_type: str = "LlamaForCausalLM", + logprob_chunk_size: Optional[int] = None, + deferred_fp32_logits: Optional[bool] = None, ) -> PolicyConfig: """Create a test config for Megatron policy worker.""" return { @@ -50,6 +53,7 @@ def create_megatron_test_config( "train_micro_batch_size": 2, "learning_rate": 5e-6, "logprob_batch_size": 2, + "logprob_chunk_size": logprob_chunk_size, "precision": precision, "generation": { "backend": generation_backend, @@ -95,6 +99,7 @@ def create_megatron_test_config( "moe_router_load_balancing_type": "none", "moe_router_bias_update_rate": 0.0, "apply_rope_fusion": True, + "deferred_fp32_logits": deferred_fp32_logits, "optimizer": { "optimizer": "adam", "lr": 5.0e-6, @@ -562,9 +567,11 @@ def logprob_setup(request): """Setup and teardown specifically for logprob tests.""" # Parse parameters: (num_gpus, tp, pp, model_fixture_name) if hasattr(request, "param") and request.param is not None: - num_gpus, tp, pp, model_fixture_name = request.param + num_gpus, tp, pp, logprob_chunk_size, deferred_fp32_logits, model_fixture_name = request.param else: - num_gpus, tp, pp, model_fixture_name = 2, 1, 1, "tiny_llama_model_path" + num_gpus, tp, pp, logprob_chunk_size, deferred_fp32_logits, model_fixture_name = ( + 2, 1, 1, None, None, "tiny_llama_model_path" + ) # Get the actual model path from the requested fixture model_name = request.getfixturevalue(model_fixture_name) @@ -599,6 +606,8 @@ def logprob_setup(request): tp=tp, pp=pp, converter_type=converter_type, + logprob_chunk_size=logprob_chunk_size, + deferred_fp32_logits=deferred_fp32_logits, ) tokenizer = get_tokenizer(config["tokenizer"]) config["generation"] = configure_generation_config( @@ -647,11 +656,19 @@ def logprob_setup(request): @pytest.mark.parametrize( "logprob_setup", [ - # (num_gpus, tp, pp, model_fixture_name) - (2, 1, 1, "tiny_llama_model_path"), - (2, 2, 1, "tiny_llama_model_path"), - (2, 1, 1, "tiny_qwen2_model_path"), - (2, 2, 1, "tiny_qwen2_model_path"), + # (num_gpus, tp, pp, chunk sz, defer fp32, model_fixture_name) + (2, 1, 1, None, None, "tiny_llama_model_path"), + (2, 2, 1, None, None, "tiny_llama_model_path"), + (2, 1, 1, None, None, "tiny_qwen2_model_path"), + (2, 2, 1, None, None, "tiny_qwen2_model_path"), + (2, 1, 1, None, True, "tiny_llama_model_path"), + (2, 2, 1, None, True, "tiny_llama_model_path"), + (2, 1, 1, None, True, "tiny_qwen2_model_path"), + (2, 2, 1, None, True, "tiny_qwen2_model_path"), + (2, 1, 1, 16, True, "tiny_llama_model_path"), + (2, 2, 1, 16, True, "tiny_llama_model_path"), + (2, 1, 1, 16, True, "tiny_qwen2_model_path"), + (2, 2, 1, 16, True, "tiny_qwen2_model_path"), ], indirect=True, ids=["2gpu_dp2_llama", "2gpu_tp2_llama", "2gpu_dp2_qwen2", "2gpu_tp2_qwen2"], @@ -671,6 +688,7 @@ def 
test_megatron_policy_logprobs(logprob_setup): # Basic validation assert isinstance(policy_logprobs, torch.Tensor), "Logprobs should be a tensor" + assert policy_logprobs.dtype == torch.float32 assert policy_logprobs.shape == data.get("input_ids").shape, ( f"Logprobs shape {policy_logprobs.shape} should match input shape {data.get('input_ids').shape}" ) From b181eb0fa2161767da6dce89cb0b3e4bce3e58d9 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Thu, 7 Aug 2025 11:13:12 -0700 Subject: [PATCH 23/38] Lint. Signed-off-by: Peter Jin --- .../models/policy/test_megatron_worker.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index fd2c499274..3d3d6c1af0 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional import os import tempfile +from typing import Optional import pytest import torch @@ -567,11 +567,23 @@ def logprob_setup(request): """Setup and teardown specifically for logprob tests.""" # Parse parameters: (num_gpus, tp, pp, model_fixture_name) if hasattr(request, "param") and request.param is not None: - num_gpus, tp, pp, logprob_chunk_size, deferred_fp32_logits, model_fixture_name = request.param + ( + num_gpus, + tp, + pp, + logprob_chunk_size, + deferred_fp32_logits, + model_fixture_name, + ) = request.param else: - num_gpus, tp, pp, logprob_chunk_size, deferred_fp32_logits, model_fixture_name = ( - 2, 1, 1, None, None, "tiny_llama_model_path" - ) + ( + num_gpus, + tp, + pp, + logprob_chunk_size, + deferred_fp32_logits, + model_fixture_name, + ) = (2, 1, 1, None, None, "tiny_llama_model_path") # Get the actual model path from the requested fixture model_name = request.getfixturevalue(model_fixture_name) From 3cb90c9ec758bb5803755510cff767aba8fe0999 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Thu, 7 Aug 2025 11:40:43 -0700 Subject: [PATCH 24/38] Add example config (TODO: functional test for this config). 
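The recipe is now written out in full rather than layering overrides on top of grpo_math_1B_megatron.yaml via defaults:, so the previously inherited sections (loss_fn, dynamic_batching, env, logger, and others) appear explicitly. The parts specific to this series are the chunked logprob computation and the deferred float32 logits cast; as a rough sketch, they correspond to the following policy config fields (dict form and variable name are only for illustration):

    policy_overrides = {
        "logprob_chunk_size": 2048,        # compute logprobs 2048 tokens at a time
        "megatron_cfg": {
            # intent: skip the float32 cast of logits in the model forward and
            # cast per sequence chunk during logprob/loss computation instead
            "deferred_fp32_logits": True,
        },
    }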
Signed-off-by: Peter Jin --- ...po-math-qwen3-30ba3b-megatron-tp4-32k.yaml | 89 ++++++++++++++++--- 1 file changed, 75 insertions(+), 14 deletions(-) diff --git a/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml index d9bc9ef1cc..a678a3c9e8 100644 --- a/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml +++ b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml @@ -1,18 +1,28 @@ -# GRPO Algorithm Configuration -defaults: "grpo_math_1B_megatron.yaml" - checkpointing: - enabled: True - save_period: 4 - keep_top_k: 100 + enabled: False + save_period: 10 + keep_top_k: 1 grpo: + normalize_rewards: True + use_leave_one_out_baseline: True + max_num_steps: 1000000 num_prompts_per_step: 64 num_generations_per_prompt: 16 - val_period: 4 - val_at_start: True - # max_val_samples: 256 - # val_batch_size: 256 + val_period: 10 + val_at_start: False + max_val_samples: 256 + val_batch_size: 256 + +loss_fn: + reference_policy_kl_penalty: 0.01 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 + # (default off) loss formulation improvements (docs/guides/grpo.md#loss) + use_on_policy_kl_approximation: False + use_importance_sampling_correction: False + token_level_loss: true + ratio_clip_c: null policy: model_name: "Qwen/Qwen3-30B-A3B" @@ -28,11 +38,17 @@ policy: logprob_chunk_size: 2048 dtensor_cfg: - enabled: false + enabled: False + + dynamic_batching: + enabled: False sequence_packing: enabled: False + max_grad_norm: 1.0 + make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size} + optimizer: null # remove default FSDP optimizer scheduler: null # remove default FSDP scheduler @@ -48,6 +64,11 @@ policy: expert_model_parallel_size: 8 sequence_parallel: True pipeline_dtype: ${policy.precision} + freeze_moe_router: True + moe_router_dtype: "fp64" + moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo + moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo + apply_rope_fusion: True activation_checkpointing: True deferred_fp32_logits: True @@ -59,10 +80,19 @@ policy: min_lr: 5.0e-8 weight_decay: 0.0 # weight_decay: 0.01 - bf16: true - fp16: false + bf16: True + fp16: False params_dtype: "float32" + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1e-8 + + use_distributed_optimizer: true + use_precision_aware_optimizer: true + + clip_grad: ${policy.max_grad_norm} + scheduler: start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} @@ -73,6 +103,14 @@ policy: # lr_warmup_init: 3.0e-8 lr_warmup_iters: 4 lr_warmup_init: 5.0e-8 + + distributed_data_parallel_config: + grad_reduce_in_fp32: false + overlap_grad_reduce: true + overlap_param_gather: true + average_in_collective: true + use_custom_fsdp: false + data_parallel_sharding_strategy: "optim_grads_params" env_vars: PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False" @@ -88,7 +126,8 @@ policy: vllm_cfg: tensor_parallel_size: 4 gpu_memory_utilization: 0.6 - enforce_eager: false + # NB(pjin): https://github.com/NVIDIA-NeMo/RL/pull/857 + enforce_eager: True max_model_len: ${policy.max_total_sequence_length} data: @@ -98,6 +137,28 @@ data: system_prompt_file: null # dataset_name: null +env: + math: + num_workers: 8 + +logger: + log_dir: "logs" # Base directory for all logs + num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal + 
wandb_enabled: false + tensorboard_enabled: false + mlflow_enabled: false # Disable MLflow logging + monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard + wandb: + project: "grpo-dev" + name: "sj_megatron_1B" + tensorboard: {} + mlflow: + experiment_name: "grpo-dev" + run_name: "sj_megatron_1B" + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + cluster: gpus_per_node: 8 num_nodes: 4 From dde98d93558b5b8a54becc8157a35bfce4fc6775 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Thu, 7 Aug 2025 12:45:50 -0700 Subject: [PATCH 25/38] Add 32K max context Qwen3 30B MoE test run. Signed-off-by: Peter Jin --- ...po-math-qwen3-30ba3b-megatron-tp4-32k.yaml | 63 +++++++++---------- ...grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh | 39 ++++++++++++ 2 files changed, 70 insertions(+), 32 deletions(-) create mode 100644 tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh diff --git a/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml index a678a3c9e8..68a1a43976 100644 --- a/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml +++ b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml @@ -1,18 +1,23 @@ checkpointing: - enabled: False - save_period: 10 + enabled: True + checkpoint_dir: results/grpo-math-qwen3-30ba3b-megatron-tp4-32k + save_period: 3 keep_top_k: 1 + metric_name: val_reward + higher_is_better: True + checkpoint_must_save_by: null grpo: normalize_rewards: True use_leave_one_out_baseline: True - max_num_steps: 1000000 + max_num_steps: 3 num_prompts_per_step: 64 num_generations_per_prompt: 16 - val_period: 10 + val_period: 3 val_at_start: False max_val_samples: 256 val_batch_size: 256 + seed: 42 loss_fn: reference_policy_kl_penalty: 0.01 @@ -21,7 +26,7 @@ loss_fn: # (default off) loss formulation improvements (docs/guides/grpo.md#loss) use_on_policy_kl_approximation: False use_importance_sampling_correction: False - token_level_loss: true + token_level_loss: True ratio_clip_c: null policy: @@ -54,7 +59,7 @@ policy: scheduler: null # remove default FSDP scheduler megatron_cfg: - enabled: true + enabled: True empty_unused_memory_level: 1 converter_type: "LlamaForCausalLM" tensor_model_parallel_size: 4 @@ -74,12 +79,9 @@ policy: optimizer: optimizer: "adam" - # lr: 3.0e-7 - # min_lr: 3.0e-8 lr: 5.0e-7 min_lr: 5.0e-8 weight_decay: 0.0 - # weight_decay: 0.01 bf16: True fp16: False params_dtype: "float32" @@ -88,8 +90,8 @@ policy: adam_beta2: 0.999 adam_eps: 1e-8 - use_distributed_optimizer: true - use_precision_aware_optimizer: true + use_distributed_optimizer: True + use_precision_aware_optimizer: True clip_grad: ${policy.max_grad_norm} @@ -99,17 +101,15 @@ policy: weight_decay_incr_style: "constant" lr_decay_style: "constant" lr_decay_iters: null - # lr_warmup_iters: 13 - # lr_warmup_init: 3.0e-8 - lr_warmup_iters: 4 + lr_warmup_iters: 2 lr_warmup_init: 5.0e-8 distributed_data_parallel_config: - grad_reduce_in_fp32: false - overlap_grad_reduce: true - overlap_param_gather: true - average_in_collective: true - use_custom_fsdp: false + grad_reduce_in_fp32: False + overlap_grad_reduce: True + overlap_param_gather: True + average_in_collective: True + use_custom_fsdp: False data_parallel_sharding_strategy: "optim_grads_params" env_vars: @@ -124,37 +124,36 @@ policy: stop_token_ids: 
null stop_strings: null vllm_cfg: + async_engine: False + precision: ${policy.precision} tensor_parallel_size: 4 gpu_memory_utilization: 0.6 + max_model_len: ${policy.max_total_sequence_length} # NB(pjin): https://github.com/NVIDIA-NeMo/RL/pull/857 enforce_eager: True - max_model_len: ${policy.max_total_sequence_length} data: + dataset_name: "OpenMathInstruct-2" + shuffle: true max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len prompt_file: "examples/prompts/cot.txt" - # prompt_file: null system_prompt_file: null - # dataset_name: null env: math: num_workers: 8 logger: - log_dir: "logs" # Base directory for all logs + log_dir: logs/grpo-math-qwen3-30ba3b-megatron-tp4-32k num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal - wandb_enabled: false - tensorboard_enabled: false - mlflow_enabled: false # Disable MLflow logging - monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard + wandb_enabled: True + tensorboard_enabled: True + mlflow_enabled: False # Disable MLflow logging + monitor_gpus: False # If true, will monitor GPU usage and log to wandb and/or tensorboard wandb: - project: "grpo-dev" - name: "sj_megatron_1B" + project: nemo-rl + name: "grpo-math-qwen3-30ba3b-megatron-tp4-32k" tensorboard: {} - mlflow: - experiment_name: "grpo-dev" - run_name: "sj_megatron_1B" gpu_monitoring: collection_interval: 10 # How often to collect GPU usage metrics (in seconds) flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) diff --git a/tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh b/tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh new file mode 100644 index 0000000000..993d541871 --- /dev/null +++ b/tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh @@ -0,0 +1,39 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=4 +STEPS_PER_RUN=3 +MAX_STEPS=3 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=240 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo_math.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/token_mult_prob_error"]) < 1.1' \ + 'data["train/token_mult_prob_error"]["$MAX_STEPS"] < 1.1' +fi From 651fcdb37f623428572c3b70075016d583d72fc8 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Fri, 8 Aug 2025 14:07:09 -0700 Subject: [PATCH 26/38] Fix deferred fp32 config. 
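The builder flag wrap_cast_to_fp32 means "wrap the model so its output is cast to float32", so it has to be the negation of deferred_fp32_logits: the previous code passed the config value straight through, which enabled the eager cast exactly when the recipe asked to defer it. Since dict.get returns None when the key is unset, the negated expression also keeps the old default (eager float32 cast) for configs that never mention the flag. A small truth-table sketch (dicts here are only for illustration):

    megatron_cfg = {}                                # flag not set
    assert (not megatron_cfg.get("deferred_fp32_logits", None)) is True    # keep eager cast

    megatron_cfg = {"deferred_fp32_logits": True}    # deferral requested
    assert (not megatron_cfg.get("deferred_fp32_logits", None)) is False   # skip the wrapper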
Signed-off-by: Peter Jin --- nemo_rl/models/policy/megatron_policy_worker.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index bf6238bab8..2034c98d19 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -216,7 +216,9 @@ def re_enable_float32_expert_bias(model_module): overlap_param_gather_with_optimizer_step=cfg.optimizer_config.overlap_param_gather_with_optimizer_step, data_parallel_random_init=cfg.rng_config.data_parallel_random_init, model_post_init_fns=model_post_init_fns, - wrap_cast_to_fp32=policy_cfg["megatron_cfg"].get("deferred_fp32_logits", None), + wrap_cast_to_fp32=( + not policy_cfg["megatron_cfg"].get("deferred_fp32_logits", None) + ), ) if load_optimizer: optimizer, scheduler = setup_optimizer( @@ -652,8 +654,8 @@ def __init__( use_torch_fsdp2=self.megatron_cfg.dist_config.use_torch_fsdp2, overlap_param_gather_with_optimizer_step=self.megatron_cfg.optimizer_config.overlap_param_gather_with_optimizer_step, data_parallel_random_init=self.megatron_cfg.rng_config.data_parallel_random_init, - wrap_cast_to_fp32=self.cfg["megatron_cfg"].get( - "deferred_fp32_logits", None + wrap_cast_to_fp32=( + not self.cfg["megatron_cfg"].get("deferred_fp32_logits", None) ), ) print("Loading the Reference Model") From 65bbd9be085b2488a5d8471269150f3c8c64817d Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Fri, 8 Aug 2025 14:23:03 -0700 Subject: [PATCH 27/38] Rename deferred_fp32_logits => defer_fp32_logits. Signed-off-by: Peter Jin --- .../llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml | 2 +- nemo_rl/models/policy/__init__.py | 2 +- nemo_rl/models/policy/megatron_policy_worker.py | 4 ++-- tests/unit/models/policy/test_megatron_worker.py | 10 +++++----- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml index 68a1a43976..19d5c981dd 100644 --- a/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml +++ b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml @@ -75,7 +75,7 @@ policy: moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo apply_rope_fusion: True activation_checkpointing: True - deferred_fp32_logits: True + defer_fp32_logits: True optimizer: optimizer: "adam" diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index 507f3f0e83..872fff35ff 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -93,7 +93,7 @@ class MegatronConfig(TypedDict): freeze_moe_router: bool expert_tensor_parallel_size: int expert_model_parallel_size: int - deferred_fp32_logits: NotRequired[bool] + defer_fp32_logits: NotRequired[bool] optimizer: NotRequired[MegatronOptimizerConfig] scheduler: NotRequired[MegatronSchedulerConfig] diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 26c4797598..b842769488 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -217,7 +217,7 @@ def re_enable_float32_expert_bias(model_module): data_parallel_random_init=cfg.rng_config.data_parallel_random_init, model_post_init_fns=model_post_init_fns, wrap_cast_to_fp32=( - not policy_cfg["megatron_cfg"].get("deferred_fp32_logits", None) 
+ not policy_cfg["megatron_cfg"].get("defer_fp32_logits", None) ), ) if load_optimizer: @@ -665,7 +665,7 @@ def __init__( overlap_param_gather_with_optimizer_step=self.megatron_cfg.optimizer_config.overlap_param_gather_with_optimizer_step, data_parallel_random_init=self.megatron_cfg.rng_config.data_parallel_random_init, wrap_cast_to_fp32=( - not self.cfg["megatron_cfg"].get("deferred_fp32_logits", None) + not self.cfg["megatron_cfg"].get("defer_fp32_logits", None) ), ) print("Loading the Reference Model") diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index 3d3d6c1af0..5b403108a6 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -42,7 +42,7 @@ def create_megatron_test_config( sequence_parallel: bool = False, converter_type: str = "LlamaForCausalLM", logprob_chunk_size: Optional[int] = None, - deferred_fp32_logits: Optional[bool] = None, + defer_fp32_logits: Optional[bool] = None, ) -> PolicyConfig: """Create a test config for Megatron policy worker.""" return { @@ -99,7 +99,7 @@ def create_megatron_test_config( "moe_router_load_balancing_type": "none", "moe_router_bias_update_rate": 0.0, "apply_rope_fusion": True, - "deferred_fp32_logits": deferred_fp32_logits, + "defer_fp32_logits": defer_fp32_logits, "optimizer": { "optimizer": "adam", "lr": 5.0e-6, @@ -572,7 +572,7 @@ def logprob_setup(request): tp, pp, logprob_chunk_size, - deferred_fp32_logits, + defer_fp32_logits, model_fixture_name, ) = request.param else: @@ -581,7 +581,7 @@ def logprob_setup(request): tp, pp, logprob_chunk_size, - deferred_fp32_logits, + defer_fp32_logits, model_fixture_name, ) = (2, 1, 1, None, None, "tiny_llama_model_path") @@ -619,7 +619,7 @@ def logprob_setup(request): pp=pp, converter_type=converter_type, logprob_chunk_size=logprob_chunk_size, - deferred_fp32_logits=deferred_fp32_logits, + defer_fp32_logits=defer_fp32_logits, ) tokenizer = get_tokenizer(config["tokenizer"]) config["generation"] = configure_generation_config( From edd00efbd4e8680a0ccce91eafcda79c0c40e463 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Fri, 8 Aug 2025 15:29:28 -0700 Subject: [PATCH 28/38] chmod +x Signed-off-by: Peter Jin --- tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh diff --git a/tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh b/tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh old mode 100644 new mode 100755 From a68dfa2957ba03da0819c2896a4b7434d00cbe47 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Fri, 8 Aug 2025 17:26:58 -0700 Subject: [PATCH 29/38] Missing config. 
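With no defaults: base to fall back on, the vLLM section of the recipe has to be complete on its own; this fills in the pipeline_parallel_size and colocated-generation keys that were missing from the rewritten config.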
Signed-off-by: Peter Jin --- .../llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml index 19d5c981dd..7249d4ad06 100644 --- a/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml +++ b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml @@ -127,10 +127,16 @@ policy: async_engine: False precision: ${policy.precision} tensor_parallel_size: 4 + pipeline_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: ${policy.max_total_sequence_length} # NB(pjin): https://github.com/NVIDIA-NeMo/RL/pull/857 enforce_eager: True + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null data: dataset_name: "OpenMathInstruct-2" From a0044adbfe9197a3fe6a78aae83026e03b392f4f Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Fri, 8 Aug 2025 17:47:44 -0700 Subject: [PATCH 30/38] More missing config. Signed-off-by: Peter Jin --- .../recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml index 7249d4ad06..ed376c5037 100644 --- a/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml +++ b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml @@ -69,6 +69,8 @@ policy: expert_model_parallel_size: 8 sequence_parallel: True pipeline_dtype: ${policy.precision} + num_layers_in_first_pipeline_stage: null + num_layers_in_last_pipeline_stage: null freeze_moe_router: True moe_router_dtype: "fp64" moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo From 8de985dbc1440eca8171cb6fb60fa4f6a96b62df Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Fri, 8 Aug 2025 20:07:09 -0700 Subject: [PATCH 31/38] Using updated NeMo branch. 
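Three related fixes: the .gitmodules entry had a full GitHub URL pasted into the branch field where a bare branch name belongs (now pjin/ashors/rl-qwen3-export), the NeMo submodule is bumped to the matching commit, and the model-builder keyword argument becomes wrap_cast_model_output_to_fp32 to match the get_model_from_config signature on that branch.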
Signed-off-by: Peter Jin --- .gitmodules | 2 +- 3rdparty/NeMo-workspace/NeMo | 2 +- nemo_rl/models/policy/megatron_policy_worker.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.gitmodules b/.gitmodules index 205497f6df..d6e8586781 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,7 @@ [submodule "3rdparty/NeMo"] path = 3rdparty/NeMo-workspace/NeMo url = https://github.com/NVIDIA/NeMo.git - branch = https://github.com/NVIDIA/NeMo/tree/pjin/ashors/rl-qwen3-export + branch = pjin/ashors/rl-qwen3-export shallow = true [submodule "3rdparty/Megatron-LM"] path = 3rdparty/Megatron-LM-workspace/Megatron-LM diff --git a/3rdparty/NeMo-workspace/NeMo b/3rdparty/NeMo-workspace/NeMo index 44d9aea478..5c42641e34 160000 --- a/3rdparty/NeMo-workspace/NeMo +++ b/3rdparty/NeMo-workspace/NeMo @@ -1 +1 @@ -Subproject commit 44d9aea4782a186a1ebbbe98352fefa1a8f92735 +Subproject commit 5c42641e344a487c7ca5b253a7483f0af8ef40e6 diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index b842769488..3291f4b4c7 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -216,7 +216,7 @@ def re_enable_float32_expert_bias(model_module): overlap_param_gather_with_optimizer_step=cfg.optimizer_config.overlap_param_gather_with_optimizer_step, data_parallel_random_init=cfg.rng_config.data_parallel_random_init, model_post_init_fns=model_post_init_fns, - wrap_cast_to_fp32=( + wrap_cast_model_output_to_fp32=( not policy_cfg["megatron_cfg"].get("defer_fp32_logits", None) ), ) @@ -664,7 +664,7 @@ def __init__( use_torch_fsdp2=self.megatron_cfg.dist_config.use_torch_fsdp2, overlap_param_gather_with_optimizer_step=self.megatron_cfg.optimizer_config.overlap_param_gather_with_optimizer_step, data_parallel_random_init=self.megatron_cfg.rng_config.data_parallel_random_init, - wrap_cast_to_fp32=( + wrap_cast_model_output_to_fp32=( not self.cfg["megatron_cfg"].get("defer_fp32_logits", None) ), ) From c5c83baf0e112757d20fe1174a3ef93926702472 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Sat, 9 Aug 2025 14:23:43 -0700 Subject: [PATCH 32/38] More missing config. Signed-off-by: Peter Jin --- .../recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml index ed376c5037..c4a0848fe1 100644 --- a/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml +++ b/examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml @@ -13,6 +13,7 @@ grpo: max_num_steps: 3 num_prompts_per_step: 64 num_generations_per_prompt: 16 + max_rollout_turns: 1 val_period: 3 val_at_start: False max_val_samples: 256 From da2f305ecc6463e0448f79448b8675e83580c739 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Tue, 12 Aug 2025 11:05:55 -0700 Subject: [PATCH 33/38] Lint and minor refactor. 
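The "minor refactor" is a sequence-chunked branch for the non-vocab-parallel logprob path in the DTensor worker: when logprob_chunk_size is set, the float32 cast and log-softmax run one sequence chunk at a time instead of over the full [batch, seq, vocab] logits tensor at once. Log-softmax normalizes over the vocab dimension only, so chunking along the sequence dimension gives identical results. A minimal self-contained sketch of the idea (function name is ours; the worker code additionally handles the vocab-parallel case above this branch):

    import torch

    def chunked_next_token_logprobs(
        logits: torch.Tensor,      # [batch, seq, vocab], e.g. bfloat16
        input_ids: torch.Tensor,   # [batch, seq]
        chunk_size: int,
    ) -> torch.Tensor:
        pieces = []
        for start in range(0, logits.shape[1], chunk_size):
            # cast and normalize one sequence chunk at a time
            chunk = logits[:, start : start + chunk_size, :].to(torch.float32)
            pieces.append(torch.nn.functional.log_softmax(chunk, dim=-1))
        log_probs = torch.cat(pieces, dim=1)       # float32 [batch, seq, vocab]
        next_tokens = input_ids[:, 1:]
        # logprob of token t+1 under the distribution predicted at position t
        return log_probs[:, :-1].gather(
            dim=-1, index=next_tokens.unsqueeze(-1)
        ).squeeze(-1)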
Signed-off-by: Peter Jin --- .../models/policy/dtensor_policy_worker.py | 45 ++++++++++++++----- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 72391713f5..78d12d9169 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -1051,21 +1051,42 @@ def get_logprobs( input_ids, chunk_size=logprob_chunk_size, ) + elif logprob_chunk_size is not None: + logits_seq_len = int(logits.shape[1]) + num_chunks = ( + logits_seq_len + logprob_chunk_size - 1 + ) // logprob_chunk_size + chunked_log_probs = [] + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * logprob_chunk_size + chunk_end = min( + logits_seq_len, (chunk_idx + 1) * logprob_chunk_size + ) + logits = logits[:, chunk_start:chunk_end, :].to( + torch.float32 + ) + log_probs = torch.nn.functional.log_softmax( + logits, dim=-1 + ) + chunked_log_probs.append(log_probs) + log_probs = torch.cat(chunked_log_probs, dim=1) + del chunked_log_probs else: - # Extract logprobs for each token in the sequence by gathering the logprob - # corresponding to the next token at each position - # Input shapes: - # log_probs: [batch_size, sequence_length, vocab_size] - logits for each position - # token_ids: [batch_size, sequence_length] - actual tokens - # Output shape: [batch_size, sequence_length] - logprob of each token given previous - # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length logits = logits.to(torch.float32) log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - next_tokens = input_ids[:, 1:] - log_probs = log_probs[:, :-1] - token_logprobs = log_probs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) + # Extract logprobs for each token in the sequence by gathering the logprob + # corresponding to the next token at each position + # Input shapes: + # log_probs: [batch_size, sequence_length, vocab_size] - logits for each position + # token_ids: [batch_size, sequence_length] - actual tokens + # Output shape: [batch_size, sequence_length] - logprob of each token given previous + # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length + next_tokens = input_ids[:, 1:] + log_probs = log_probs[:, :-1] + token_logprobs = log_probs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) + del log_probs del outputs, logits From cd5b02a623e512ca5c1b6a6b73e61eb99ba7f5f7 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Tue, 12 Aug 2025 11:09:05 -0700 Subject: [PATCH 34/38] Fix. 
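The chunk loop introduced in the previous commit rebound the name logits to the first float32 slice, so later iterations would have sliced an already-sliced tensor; the per-chunk slice is renamed chunk_logits so the original logits tensor stays intact for the remaining chunks.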
Signed-off-by: Peter Jin --- nemo_rl/models/policy/dtensor_policy_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 78d12d9169..958afd36ba 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -1062,11 +1062,11 @@ def get_logprobs( chunk_end = min( logits_seq_len, (chunk_idx + 1) * logprob_chunk_size ) - logits = logits[:, chunk_start:chunk_end, :].to( + chunk_logits = logits[:, chunk_start:chunk_end, :].to( torch.float32 ) log_probs = torch.nn.functional.log_softmax( - logits, dim=-1 + chunk_logits, dim=-1 ) chunked_log_probs.append(log_probs) log_probs = torch.cat(chunked_log_probs, dim=1) From 81fb8e19fa46d13b7dd13d7b78af2ac6b1cfb7ca Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Tue, 12 Aug 2025 11:23:07 -0700 Subject: [PATCH 35/38] Unnecessary clone. Signed-off-by: Peter Jin --- nemo_rl/distributed/model_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index de9cedaaba..4ec0664a7f 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -181,11 +181,10 @@ def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func logits = vocab_parallel_logits[:, chunk_start:chunk_end, :] logits = logits.to(dtype=torch.float32) - log_softmax_output = _compute_distributed_log_softmax( + log_probs = _compute_distributed_log_softmax( logits, group=tp_group, ) - log_probs = log_softmax_output.clone() log_probs = torch.gather( log_probs, -1, masked_target[:, chunk_start:chunk_end].unsqueeze(-1) From ef9d3d5be9a2f0fcb4fcf1d56a238baa5fc8fc4a Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Tue, 12 Aug 2025 13:38:09 -0700 Subject: [PATCH 36/38] Remove clone + exp_ with just exp. Signed-off-by: Peter Jin --- nemo_rl/distributed/model_utils.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 4ec0664a7f..29cc5eb6b7 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -79,11 +79,8 @@ def forward( # pyrefly: ignore[bad-override] Always ignore torch.autograd.Func vocab_parallel_logits = vocab_parallel_logits.to(dtype=torch.float32) - log_softmax_output = _compute_distributed_log_softmax( - vocab_parallel_logits, group=group - ) - log_probs = log_softmax_output.clone() - softmax_output = log_softmax_output.exp_() + log_probs = _compute_distributed_log_softmax(vocab_parallel_logits, group=group) + softmax_output = log_probs.exp() log_probs = torch.gather(log_probs, -1, masked_target.unsqueeze(-1)).squeeze(-1) log_probs[target_mask] = 0.0 @@ -232,12 +229,11 @@ def backward( logits = vocab_parallel_logits[:, chunk_start:chunk_end, :] logits = logits.to(dtype=torch.float32) - log_softmax_output = _compute_distributed_log_softmax( + softmax_output = _compute_distributed_log_softmax( logits, group=tp_group, ) - log_probs = log_softmax_output.clone() - softmax_output = log_softmax_output.exp_() + softmax_output = softmax_output.exp() # 1 if it's the chosen log prob, 0 otherwise is_chosen = (~(target_mask[:, chunk_start:chunk_end])).unsqueeze( From 3d38161cb7cf5c0ebfa6c2f93db0cc13df9bbab9 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 13 Aug 2025 12:18:30 -0700 Subject: [PATCH 37/38] Set HF_HUB_OFFLINE=1 for github CI. 
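The unit-test container already mounts a pre-populated Hugging Face cache (HF_HOME and HF_DATASETS_CACHE under /home/TestData/nemo-rl); exporting HF_HUB_OFFLINE=1 tells huggingface_hub to resolve models and datasets from that cache instead of reaching the Hub from the CI runners.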
Signed-off-by: Peter Jin --- .github/actions/test-template/action.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/actions/test-template/action.yml b/.github/actions/test-template/action.yml index 3e16304fcf..98a0fbcfcc 100644 --- a/.github/actions/test-template/action.yml +++ b/.github/actions/test-template/action.yml @@ -162,6 +162,7 @@ runs: --shm-size=64g \ --env TRANSFORMERS_OFFLINE=0 \ --env HYDRA_FULL_ERROR=1 \ + --env HF_HUB_OFFLINE=1 \ --env HF_HOME=/home/TestData/nemo-rl/hf_home \ --env HF_DATASETS_CACHE=/home/TestData/nemo-rl/hf_datasets_cache \ --env NEMO_RL_REPO_DIR=/opt/nemo-rl \ From 0f0de7d3c68c124259749f16ce239712ecd51941 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 13 Aug 2025 14:26:16 -0700 Subject: [PATCH 38/38] Fix test. Signed-off-by: Peter Jin --- tests/unit/models/policy/test_megatron_worker.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index d14f7a9382..cd287a6370 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -675,7 +675,20 @@ def logprob_setup(request): (2, 2, 1, 16, True, "tiny_qwen2_model_path"), ], indirect=True, - ids=["2gpu_dp2_llama", "2gpu_tp2_llama", "2gpu_dp2_qwen2", "2gpu_tp2_qwen2"], + ids=[ + "2gpu_dp2_llama", + "2gpu_tp2_llama", + "2gpu_dp2_qwen2", + "2gpu_tp2_qwen2", + "2gpu_dp2_deferfp32_llama", + "2gpu_tp2_deferfp32_llama", + "2gpu_dp2_deferfp32_qwen2", + "2gpu_tp2_deferfp32_qwen2", + "2gpu_dp2_chunked_deferfp32_llama", + "2gpu_tp2_chunked_deferfp32_llama", + "2gpu_dp2_chunked_deferfp32_qwen2", + "2gpu_tp2_chunked_deferfp32_qwen2", + ], ) def test_megatron_policy_logprobs(logprob_setup): """Test Megatron policy logprob computation."""