Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def __init__(
logprobs_mode="processed_logprobs",
)
if self.args.vllm_enable_sleep_mode:
- self.llm.sleep(level=1)
+ self.llm.sleep(level=2)
else:
raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.")

Expand Down Expand Up @@ -1098,7 +1098,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
# wake up colocated vLLM instances if needed
torch.cuda.empty_cache() # required to avoid OOM in some cases
- self.llm.wake_up()
+ self.llm.wake_up(tags=["weights"])

# First, update the vLLM weights if needed
if self.state.global_step != self._last_loaded_step:
Expand Down Expand Up @@ -1206,6 +1206,9 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
else:
vllm_inputs = all_prompts_text

+ if self.args.vllm_enable_sleep_mode:
+ self.llm.wake_up(tags=["kv_cache"])
+
with profiling_context(self, "vLLM.generate"):
all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)

Expand All @@ -1231,7 +1234,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
logprobs = all_logprobs

if self.args.vllm_enable_sleep_mode:
- self.llm.sleep(level=1)
+ self.llm.sleep(level=2)

elif self.use_transformers_paged:
# Re-process inputs for paged generation if needed
Expand Down
9 changes: 6 additions & 3 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ def decode(example, tokenizer):
enable_sleep_mode=self.args.vllm_enable_sleep_mode,
)
if self.args.vllm_enable_sleep_mode:
- self.llm.sleep(level=1)
+ self.llm.sleep(level=2)
else:
raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.")

Expand Down Expand Up @@ -1094,7 +1094,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
# wake up colocated vLLM instances if needed
torch.cuda.empty_cache() # required to avoid OOM in some cases
- self.llm.wake_up()
+ self.llm.wake_up(tags=["weights"])

# First, update the vLLM weights if needed
if self.state.global_step != self._last_loaded_step:
Expand Down Expand Up @@ -1200,6 +1200,9 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
else:
vllm_inputs = all_prompts_text

+ if self.args.vllm_enable_sleep_mode:
+ self.llm.wake_up(tags=["kv_cache"])
+
Comment thread
qgallouedec marked this conversation as resolved.
with profiling_context(self, "vLLM.generate"):
all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)

Expand All @@ -1218,7 +1221,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
completion_ids = all_completion_ids

if self.args.vllm_enable_sleep_mode:
- self.llm.sleep(level=1)
+ self.llm.sleep(level=2)

elif self.use_transformers_paged:
# Re-process inputs for paged generation if needed
Expand Down
Loading