-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Move rollout_func from _generate_single_turn to _generate
#5232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f10285e
7d2bb67
3b356ac
82c4508
3ea2fcf
445f4ba
8c6c88d
f617b2d
eaffd67
f3f6a5d
d417543
4b927d6
029fc1f
20b4039
b8e3912
07181cb
6ff1e56
9f340e4
d138be7
f033e63
5a1f609
1eb3540
d3f7971
319d52a
d5e1906
4ccadcf
0558dc9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1216,25 +1216,6 @@ def _generate_single_turn(self, prompts: list): | |
| device = self.accelerator.device | ||
| mode = "train" if self.model.training else "eval" | ||
|
|
||
| if self.rollout_func is not None: | ||
| # Keep vLLM weights in sync for custom rollouts that rely on vLLM utilities. | ||
| if self.use_vllm and self.state.global_step != self._last_loaded_step: | ||
| with profiling_context(self, "sync_weights"): | ||
| self.vllm_generation.sync_weights() | ||
| self._last_loaded_step = self.state.global_step | ||
|
|
||
| # Pass prompts to rollout_func preserving structured messages. | ||
| # Chat templating must happen inside rollout_func, at the backend boundary, so that | ||
| # multimodal content (images, typed content blocks) is not lost before rollout logic runs. | ||
| output = self.rollout_func(prompts, self) | ||
| required_keys = {"prompt_ids", "completion_ids", "logprobs"} | ||
| missing_keys = required_keys - output.keys() | ||
| if missing_keys: | ||
| missing_keys_list = sorted(missing_keys) | ||
| raise ValueError(f"rollout_func must return keys {missing_keys_list} in its output dict.") | ||
| extra_fields = {k: v for k, v in output.items() if k not in required_keys} | ||
| return output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields | ||
|
|
||
| # Generate completions using either vLLM or regular generation | ||
| if self.use_vllm: | ||
| # Sync weights if training step changed | ||
|
|
@@ -1521,7 +1502,26 @@ def _generate(self, prompts: list): | |
| # Copy the prompts to avoid modifying the original list | ||
| prompts = copy.deepcopy(prompts) | ||
|
|
||
| prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts) | ||
| if self.rollout_func is not None: | ||
| # Keep vLLM weights in sync for custom rollouts that rely on vLLM utilities. | ||
| if self.use_vllm and self.state.global_step != self._last_loaded_step: | ||
| with profiling_context(self, "sync_weights"): | ||
| self.vllm_generation.sync_weights() | ||
| self._last_loaded_step = self.state.global_step | ||
|
|
||
| # Pass prompts to rollout_func preserving structured messages. | ||
| # Chat templating must happen inside rollout_func, at the backend boundary, so that | ||
| # multimodal content (images, typed content blocks) is not lost before rollout logic runs. | ||
| output = self.rollout_func(prompts, self) | ||
| required_keys = {"prompt_ids", "completion_ids", "logprobs"} | ||
| missing_keys = required_keys - output.keys() | ||
| if missing_keys: | ||
| missing_keys_list = sorted(missing_keys) | ||
| raise ValueError(f"rollout_func must return keys {missing_keys_list} in its output dict.") | ||
| extra_fields = {k: v for k, v in output.items() if k not in required_keys} | ||
| prompt_ids, completion_ids, logprobs = output["prompt_ids"], output["completion_ids"], output["logprobs"] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
After moving […] Useful? React with 👍 / 👎. |
||
| else: | ||
| prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts) | ||
|
|
||
| # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. | ||
| if is_conversational({"prompt": prompts[0]}): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.