Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/configs/grpo_math_1B_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ policy:
weight_decay_incr_style: "constant"
lr_decay_style: "constant"
lr_decay_iters: null
- lr_warmup_iters: 50
+ lr_warmup_iters: 13
lr_warmup_init: 5.0e-7

distributed_data_parallel_config:
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/grpo_math_70B_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ policy:
weight_decay_incr_style: "constant"
lr_decay_style: "constant"
lr_decay_iters: null
- lr_warmup_iters: 50
+ lr_warmup_iters: 13
lr_warmup_init: 3.0e-8

generation:
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/grpo_math_qwen30ba3b_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ policy:
weight_decay_incr_style: "constant"
lr_decay_style: "constant"
lr_decay_iters: null
- lr_warmup_iters: 50
+ lr_warmup_iters: 13
lr_warmup_init: 3.0e-8

env_vars:
Expand Down
10 changes: 6 additions & 4 deletions nemo_rl/models/policy/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,11 +859,7 @@ def train(

# Update learning rate.
Comment thread
ashors1 marked this conversation as resolved.
Outdated
if update_successful:
- increment = total_dataset_size.item()
- self.scheduler.step(increment=increment)
skipped_iter = 0
- curr_lr = self.scheduler.get_lr(self.optimizer.param_groups[0])
- curr_wd = self.scheduler.get_wd()
else:
skipped_iter = 1

Expand All @@ -880,6 +876,8 @@ def train(
for k in x.keys():
loss_metrics[k] = x[k] / num_global_batches
gb_loss_metrics.append(loss_metrics)
+ curr_lr = self.scheduler.get_lr(self.optimizer.param_groups[0])
+ curr_wd = self.scheduler.get_wd()
loss_metrics["lr"] = curr_lr
loss_metrics["wd"] = curr_wd
loss_metrics["grad_norm"] = grad_norm
Expand All @@ -905,6 +903,10 @@ def train(
all_mb_metrics.extend(gb_loss_metrics)
losses.append(torch.tensor(mb_losses).sum().item())

+ if not eval_mode:
+ # take one LR step every rollout batch
+ self.scheduler.step(increment=1)

# Aggregate metrics across all microbatches
mb_metrics = defaultdict(list)
for m in all_mb_metrics:
Expand Down
Loading