diff --git a/rllm/trainer/verl/agent_sdk_trainer.py b/rllm/trainer/verl/agent_sdk_trainer.py index 0d59b7c19..7880f6f76 100644 --- a/rllm/trainer/verl/agent_sdk_trainer.py +++ b/rllm/trainer/verl/agent_sdk_trainer.py @@ -175,6 +175,7 @@ def fit_agent(self): self.global_steps = 0 self._load_checkpoint() + self.checkpoint_manager.update_weights(self.global_steps) start_time = time.time() if self.config.trainer.get("val_before_train", True): @@ -213,6 +214,7 @@ def fit_agent(self): with marked_timer("step", timing_raw): # generate trajectories final_gen_batch_output = self.generate_trajectories(batch=new_batch, timing_raw=timing_raw) + self.checkpoint_manager.sleep_replicas() # need to repeat to make shape match repeat_counts = final_gen_batch_output.meta_info["repeat_counts"] @@ -473,6 +475,9 @@ def fit_agent(self): with marked_timer("save_checkpoint", timing_raw, color="green"): self._save_checkpoint() + # update weights from trainer to rollout + with marked_timer("update_weights", timing_raw, color="red"): + self.checkpoint_manager.update_weights(self.global_steps) # Visualize some sample trajectories if batch is not None and len(batch) > 0: # Randomly select a few samples to visualize