Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/configs/grpo_math_1B_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ policy:
weight_decay_incr_style: "constant"
lr_decay_style: "constant"
lr_decay_iters: null
lr_warmup_iters: 50
lr_warmup_iters: 13
lr_warmup_init: 5.0e-7

distributed_data_parallel_config:
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/grpo_math_70B_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ policy:
weight_decay_incr_style: "constant"
lr_decay_style: "constant"
lr_decay_iters: null
lr_warmup_iters: 50
lr_warmup_iters: 13
lr_warmup_init: 3.0e-8

generation:
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/grpo_math_qwen30ba3b_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ policy:
weight_decay_incr_style: "constant"
lr_decay_style: "constant"
lr_decay_iters: null
lr_warmup_iters: 50
lr_warmup_iters: 13
lr_warmup_init: 3.0e-8

env_vars:
Expand Down
11 changes: 6 additions & 5 deletions nemo_rl/models/policy/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,13 +857,8 @@ def train(
num_zeros_in_grad
)

# Update learning rate.
if update_successful:
increment = total_dataset_size.item()
self.scheduler.step(increment=increment)
skipped_iter = 0
curr_lr = self.scheduler.get_lr(self.optimizer.param_groups[0])
curr_wd = self.scheduler.get_wd()
else:
skipped_iter = 1

Expand All @@ -880,6 +875,8 @@ def train(
for k in x.keys():
loss_metrics[k] = x[k] / num_global_batches
gb_loss_metrics.append(loss_metrics)
curr_lr = self.scheduler.get_lr(self.optimizer.param_groups[0])
curr_wd = self.scheduler.get_wd()
loss_metrics["lr"] = curr_lr
loss_metrics["wd"] = curr_wd
loss_metrics["grad_norm"] = grad_norm
Expand All @@ -905,6 +902,10 @@ def train(
all_mb_metrics.extend(gb_loss_metrics)
losses.append(torch.tensor(mb_losses).sum().item())

if not eval_mode:
# take one LR step every rollout batch
self.scheduler.step(increment=1)

# Aggregate metrics across all microbatches
mb_metrics = defaultdict(list)
for m in all_mb_metrics:
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/models/policy/test_megatron_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,8 +777,8 @@ def test_megatron_reference_policy_functionality():
)

config = create_megatron_test_config()
config["megatron_cfg"]["optimizer"]["lr"] = 1e-3 # Increase from 5e-6 to 1e-3
config["megatron_cfg"]["optimizer"]["min_lr"] = 1e-4 # Increase min_lr as well
config["megatron_cfg"]["optimizer"]["lr"] = 1e-2 # Increase from 5e-6 to 1e-2
config["megatron_cfg"]["optimizer"]["min_lr"] = 1e-3 # Increase min_lr as well

tokenizer = get_tokenizer(config["tokenizer"])
config["generation"] = configure_generation_config(config["generation"], tokenizer)
Expand Down
Loading