From fe2be4b5729349815237a65da70bf7789d8cfbf8 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Fri, 27 Jun 2025 09:12:30 -0700 Subject: [PATCH] fix: correct mcore dtype + assertion on activation_func Signed-off-by: Terry Kong --- .../models/policy/megatron_policy_worker.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index e0bd4373be..3b6ce13e30 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -459,15 +459,27 @@ def __init__( ) model_cfg.bf16 = self.dtype == torch.bfloat16 model_cfg.fp16 = self.dtype == torch.float16 - model_cfg.params_dtype = dtype_map[ - self.cfg["megatron_cfg"]["optimizer"]["params_dtype"] - ] # FP32 for amp + if model_cfg.fp16: + assert not model_cfg.bf16, "fp16 and bf16 cannot be used together" + model_cfg.params_dtype = torch.float16 + elif model_cfg.bf16: + assert not model_cfg.fp16, "fp16 and bf16 cannot be used together" + model_cfg.params_dtype = torch.bfloat16 + else: + model_cfg.params_dtype = torch.float32 model_cfg.pipeline_dtype = dtype_map[self.cfg["megatron_cfg"]["pipeline_dtype"]] model_cfg.parallel_output = True if self.cfg["megatron_cfg"]["activation_checkpointing"]: model_cfg.activations_checkpoint_granularity = "full" model_cfg.activations_checkpoint_method = "uniform" model_cfg.activations_checkpoint_num_layers = 1 + if not model_cfg.gated_linear_unit: + assert model_cfg.activation_func is not None, ( + "activation_func must be set if not using gated_linear_unit. This likely " + "indicates an issue in configuration conversion (e.g. activation func was " + "a lambda and couldn't be serialized). This is based on this check " + "https://github.com/NVIDIA/Megatron-LM/blob/1ab876ddc4c1893c76f26d775226a8d1dcdfb3d2/megatron/core/transformer/mlp.py#L174." + ) checkpoint_config = CheckpointConfig( save_interval=100,