diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml
index 7a8a651a54..600dbfc41c 100644
--- a/examples/configs/grpo_math_1B_megatron.yaml
+++ b/examples/configs/grpo_math_1B_megatron.yaml
@@ -116,7 +116,7 @@ policy:
       weight_decay_incr_style: "constant"
       lr_decay_style: "constant"
       lr_decay_iters: null
-      lr_warmup_iters: 50
+      lr_warmup_iters: 13
       lr_warmup_init: 5.0e-7
 
     distributed_data_parallel_config:
diff --git a/examples/configs/grpo_math_70B_megatron.yaml b/examples/configs/grpo_math_70B_megatron.yaml
index a7ba2c8a52..1aaad35659 100644
--- a/examples/configs/grpo_math_70B_megatron.yaml
+++ b/examples/configs/grpo_math_70B_megatron.yaml
@@ -49,7 +49,7 @@ policy:
       weight_decay_incr_style: "constant"
       lr_decay_style: "constant"
       lr_decay_iters: null
-      lr_warmup_iters: 50
+      lr_warmup_iters: 13
       lr_warmup_init: 3.0e-8
 
   generation:
diff --git a/examples/configs/grpo_math_qwen30ba3b_megatron.yaml b/examples/configs/grpo_math_qwen30ba3b_megatron.yaml
index 915babbf5c..8ebd93e7a1 100644
--- a/examples/configs/grpo_math_qwen30ba3b_megatron.yaml
+++ b/examples/configs/grpo_math_qwen30ba3b_megatron.yaml
@@ -52,7 +52,7 @@ policy:
       weight_decay_incr_style: "constant"
       lr_decay_style: "constant"
       lr_decay_iters: null
-      lr_warmup_iters: 50
+      lr_warmup_iters: 13
       lr_warmup_init: 3.0e-8
 
     env_vars:
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 7daa6de019..89ab2ec973 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -857,13 +857,8 @@ def train(
                 num_zeros_in_grad
             )
 
-            # Update learning rate.
             if update_successful:
-                increment = total_dataset_size.item()
-                self.scheduler.step(increment=increment)
                 skipped_iter = 0
-                curr_lr = self.scheduler.get_lr(self.optimizer.param_groups[0])
-                curr_wd = self.scheduler.get_wd()
             else:
                 skipped_iter = 1
 
@@ -880,6 +875,8 @@ def train(
                 for k in x.keys():
                     loss_metrics[k] = x[k] / num_global_batches
                 gb_loss_metrics.append(loss_metrics)
+                curr_lr = self.scheduler.get_lr(self.optimizer.param_groups[0])
+                curr_wd = self.scheduler.get_wd()
                 loss_metrics["lr"] = curr_lr
                 loss_metrics["wd"] = curr_wd
                 loss_metrics["grad_norm"] = grad_norm
@@ -905,6 +902,10 @@ def train(
             all_mb_metrics.extend(gb_loss_metrics)
             losses.append(torch.tensor(mb_losses).sum().item())
 
+        if not eval_mode:
+            # take one LR step every rollout batch
+            self.scheduler.step(increment=1)
+
         # Aggregate metrics across all microbatches
         mb_metrics = defaultdict(list)
         for m in all_mb_metrics:
diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py
index a23c1b5559..35508a5e12 100644
--- a/tests/unit/models/policy/test_megatron_worker.py
+++ b/tests/unit/models/policy/test_megatron_worker.py
@@ -777,8 +777,8 @@ def test_megatron_reference_policy_functionality():
     )
 
     config = create_megatron_test_config()
-    config["megatron_cfg"]["optimizer"]["lr"] = 1e-3  # Increase from 5e-6 to 1e-3
-    config["megatron_cfg"]["optimizer"]["min_lr"] = 1e-4  # Increase min_lr as well
+    config["megatron_cfg"]["optimizer"]["lr"] = 1e-2  # Increase from 5e-6 to 1e-2
+    config["megatron_cfg"]["optimizer"]["min_lr"] = 1e-3  # Increase min_lr as well
 
     tokenizer = get_tokenizer(config["tokenizer"])
     config["generation"] = configure_generation_config(config["generation"], tokenizer)
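
To illustrate the stepping semantics this diff moves to, here is a minimal, self-contained sketch. `ToyWarmupScheduler` is a hypothetical stand-in for Megatron's optimizer scheduler (not the actual class): the point is only that `scheduler.step(increment=1)` is now called once per rollout batch, so `lr_warmup_iters: 13` counts 13 `train()` calls, independent of how many samples each rollout batch contains.

```python
# Minimal sketch of "one LR step per rollout batch" semantics.
# ToyWarmupScheduler is a hypothetical illustration, not the Megatron scheduler.


class ToyWarmupScheduler:
    """Linear warmup from lr_warmup_init to lr over lr_warmup_iters steps."""

    def __init__(self, lr: float, lr_warmup_init: float, lr_warmup_iters: int):
        self.lr = lr
        self.lr_warmup_init = lr_warmup_init
        self.lr_warmup_iters = lr_warmup_iters
        self.num_steps = 0

    def step(self, increment: int = 1) -> None:
        self.num_steps += increment

    def get_lr(self) -> float:
        if self.num_steps >= self.lr_warmup_iters:
            return self.lr
        frac = self.num_steps / self.lr_warmup_iters
        return self.lr_warmup_init + frac * (self.lr - self.lr_warmup_init)


# Values from grpo_math_1B_megatron.yaml above (lr itself is assumed here).
scheduler = ToyWarmupScheduler(lr=5e-6, lr_warmup_init=5e-7, lr_warmup_iters=13)

for rollout_batch in range(15):
    # ... optimizer steps over all microbatches of this rollout batch ...
    scheduler.step(increment=1)  # one LR step per rollout batch
    print(f"batch {rollout_batch:2d}: lr = {scheduler.get_lr():.2e}")
```

Running this shows the LR reaching its target after exactly 13 rollout batches, whereas stepping by `total_dataset_size` (the removed behavior) would tie warmup length to batch size.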