Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions trl/experimental/gold/gold_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1845,6 +1845,8 @@ def _move_model_to_vllm(self):
if self.vllm_mode == "colocate" and self.vllm_enable_sleep_mode:
empty_cache()
self.vllm_engine.wake_up(tags=["weights"])
# Work around for https://github.com/vllm-project/vllm/issues/29341
self.vllm_engine.collective_rpc("reload_weights")
Comment thread
qgallouedec marked this conversation as resolved.

if is_peft_model(self.model):
# With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as
Expand Down
2 changes: 2 additions & 0 deletions trl/experimental/openenv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def generate_rollout_completions(

if trainer.args.vllm_enable_sleep_mode:
trainer.llm.wake_up(tags=["kv_cache"])
# Work around for https://github.com/vllm-project/vllm/issues/29341
trainer.llm.collective_rpc("reload_weights")

with profiling_context(trainer, "vLLM.generate_rollout"):
if as_chat:
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,6 +1165,8 @@ def _generate_single_turn(self, prompts: list):
# wake up colocated vLLM instances if needed
torch.cuda.empty_cache() # required to avoid OOM in some cases
self.llm.wake_up(tags=["weights"])
# Work around for https://github.com/vllm-project/vllm/issues/29341
self.llm.collective_rpc("reload_weights")

# First, update the vLLM weights if needed
if self.state.global_step != self._last_loaded_step:
Expand Down
2 changes: 2 additions & 0 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,8 @@ def _generate_single_turn(self, prompts: list):
# wake up colocated vLLM instances if needed
torch.cuda.empty_cache() # required to avoid OOM in some cases
self.llm.wake_up(tags=["weights"])
# Work around for https://github.com/vllm-project/vllm/issues/29341
self.llm.collective_rpc("reload_weights")

# First, update the vLLM weights if needed
if self.state.global_step != self._last_loaded_step:
Expand Down
Loading