From dd1e45c668fa636e38fac2753f878adcb550efbe Mon Sep 17 00:00:00 2001 From: 0x404 <871206929@qq.com> Date: Fri, 9 May 2025 11:03:54 +0000 Subject: [PATCH 1/2] fix: model won't be updated on first training step - Fix LR scheduler step timing to properly record and apply learning rate - Correct warmup scheduler implementation to maintain constant rate after warmup - Increase learning rate in test script for better checkpoint validation --- tests/e2e/ppo_trainer/run_function_reward.sh | 3 ++- verl/utils/torch_functional.py | 4 +++- verl/workers/fsdp_workers.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/e2e/ppo_trainer/run_function_reward.sh b/tests/e2e/ppo_trainer/run_function_reward.sh index 0c842e86ffe..eb6c24a152f 100644 --- a/tests/e2e/ppo_trainer/run_function_reward.sh +++ b/tests/e2e/ppo_trainer/run_function_reward.sh @@ -73,7 +73,8 @@ python3 -m verl.trainer.main_ppo \ data.max_prompt_length="${MAX_PROMPT_LEN}" \ data.max_response_length="${MAX_RESPONSE_LEN}" \ actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.actor.optim.lr=1e-6 \ + # We only run one step, make lr a bit higher so that our model weight will be updated more, used for checkpoint testing + actor_rollout_ref.actor.optim.lr=1e-4 \ actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py index 2d42b25d97a..167e46d0f13 100644 --- a/verl/utils/torch_functional.py +++ b/verl/utils/torch_functional.py @@ -436,7 +436,9 @@ def get_constant_schedule_with_warmup( last_epoch: int = -1, ): def lr_lambda(current_step): - return min(1, float(current_step) / float(max(1, num_warmup_steps))) + if current_step < num_warmup_steps: + return float(current_step) / float(max(1.0, num_warmup_steps)) + return 1.0 return LambdaLR(optimizer, 
lr_lambda, last_epoch) diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 92850a3bbf4..e0f55544114 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -567,9 +567,9 @@ def update_actor(self, data: DataProto): metrics["perf/max_memory_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3) metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) - self.actor_lr_scheduler.step() lr = self.actor_lr_scheduler.get_last_lr()[0] metrics["actor/lr"] = lr + self.actor_lr_scheduler.step() # TODO: here, we should return all metrics output = DataProto(meta_info={"metrics": metrics}) From e2375cc0a36c5cd513c2587533ffcab5ff773f12 Mon Sep 17 00:00:00 2001 From: 0x404 <871206929@qq.com> Date: Fri, 9 May 2025 14:29:23 +0000 Subject: [PATCH 2/2] just modify assert_close tolerance; don't change e2e lr --- scripts/model_merger.py | 2 +- tests/e2e/ppo_trainer/run_function_reward.sh | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/scripts/model_merger.py b/scripts/model_merger.py index 590f4508c04..213c1a3e07f 100644 --- a/scripts/model_merger.py +++ b/scripts/model_merger.py @@ -113,7 +113,7 @@ def test_fsdp_state_dict( collected_dtype = collected_state_dict[key].dtype assert original_dtype == collected_dtype, f"Dtype mismatch for key '{key}': original {original_dtype} vs collected {collected_dtype}" - torch.testing.assert_close(original_state_dict[key], collected_state_dict[key], atol=1e-4, rtol=1e-4) + torch.testing.assert_close(original_state_dict[key], collected_state_dict[key], atol=1e-6, rtol=1e-6) print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.") return True diff --git a/tests/e2e/ppo_trainer/run_function_reward.sh b/tests/e2e/ppo_trainer/run_function_reward.sh index eb6c24a152f..0c842e86ffe 100644 --- a/tests/e2e/ppo_trainer/run_function_reward.sh +++ b/tests/e2e/ppo_trainer/run_function_reward.sh @@ -73,8 +73,7 @@ python3 
-m verl.trainer.main_ppo \ data.max_prompt_length="${MAX_PROMPT_LEN}" \ data.max_response_length="${MAX_RESPONSE_LEN}" \ actor_rollout_ref.model.path="${MODEL_PATH}" \ - # We only run one step, make lr a bit higher so that our model weight will be updated more, used for checkpoint testing - actor_rollout_ref.actor.optim.lr=1e-4 \ + actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \