Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions trl/experimental/gold/gold_trainer.py
Original file line number Diff line number Diff line change
@@ -1845,6 +1845,7 @@ def _move_model_to_vllm(self):
if self.vllm_mode == "colocate" and self.vllm_enable_sleep_mode:
empty_cache()
self.vllm_engine.wake_up(tags=["weights"])
self.vllm_engine.collective_rpc("reload_weights")
Comment thread
qgallouedec marked this conversation as resolved.

if is_peft_model(self.model):
# With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as
Expand Down
1 change: 1 addition & 0 deletions trl/experimental/openenv/utils.py
Original file line number Diff line number Diff line change
@@ -101,6 +101,7 @@ def generate_rollout_completions(

if trainer.args.vllm_enable_sleep_mode:
trainer.llm.wake_up(tags=["kv_cache"])
trainer.llm.collective_rpc("reload_weights")

with profiling_context(trainer, "vLLM.generate_rollout"):
if as_chat:
Expand Down
1 change: 1 addition & 0 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
@@ -1165,6 +1165,7 @@ def _generate_single_turn(self, prompts: list):
# wake up colocated vLLM instances if needed
torch.cuda.empty_cache() # required to avoid OOM in some cases
self.llm.wake_up(tags=["weights"])
self.llm.collective_rpc("reload_weights")

# First, update the vLLM weights if needed
if self.state.global_step != self._last_loaded_step:
Expand Down
1 change: 1 addition & 0 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
@@ -1006,6 +1006,7 @@ def _generate_single_turn(self, prompts: list):
# wake up colocated vLLM instances if needed
torch.cuda.empty_cache() # required to avoid OOM in some cases
self.llm.wake_up(tags=["weights"])
self.llm.collective_rpc("reload_weights")

# First, update the vLLM weights if needed
if self.state.global_step != self._last_loaded_step:
Expand Down
Loading