[recipe] feat: add Single-stream Policy Optimization (SPO) algo. #4176

dzh19990407 wants to merge 1 commit into verl-project:main
Conversation
Code Review
This pull request introduces the Single-stream Policy Optimization (SPO) algorithm, a memory-efficient alternative to GRPO. The changes are extensive, adding new recipes, configurations, documentation, and core logic for SPO. My review focuses on correctness and potential bugs in the new implementation. I've identified a few critical and high-severity issues, including a logic error in handling parallel tool calls, brittle data parsing, a hardcoded hyperparameter, and a risky code modification in the sandbox tool. Addressing these will improve the robustness and maintainability of the new SPO recipe.
```python
async def execute(self, instance_id: str, code: str, **kwargs) -> tuple[str, float, dict]:
    # NOTE: some script may not explicitly print result, we need to add a print statement to the end of the script
    lines = code.split("\n")
    for i, line in reversed(list(enumerate(lines))):
        if line == "":
            continue
        if not lines[i].startswith("print"):
            lines[i] = f"print({line})"
        break
    code = "\n".join(lines)

    timeout = self.default_timeout
    language = self.default_language
    if not isinstance(code, str):
        code = str(code)

    result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language)
    # sandbox has no score or metrics, use Nones
    return result, None, None
```
The execute method of CustomSandboxFusionTool modifies the user-provided code by attempting to wrap the last non-empty line in a print() call. This is a brittle approach and fails for valid Python code. For example, if the last line is an assignment (x = 1), it is transformed into print(x = 1), which passes x as a keyword argument and raises a TypeError at runtime, while statements such as import os become outright SyntaxErrors. Either way, the tool execution fails unexpectedly. The model should be responsible for generating code that explicitly prints its output.
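If automatic result printing is kept at all, only a bare trailing expression can be safely wrapped. A standalone sketch of that idea (not the PR's code; `maybe_add_print` is a hypothetical helper):

```python
import ast

def maybe_add_print(code: str) -> str:
    # Wrap the final statement in print(...) only when it is a bare
    # expression; assignments, imports, etc. are left untouched, so the
    # rewrite can never produce invalid calls like print(x = 1).
    tree = ast.parse(code)
    if tree.body and isinstance(tree.body[-1], ast.Expr):
        call = ast.Call(func=ast.Name(id="print", ctx=ast.Load()),
                        args=[tree.body[-1].value], keywords=[])
        tree.body[-1] = ast.Expr(value=call)
        ast.fix_missing_locations(tree)
    return ast.unparse(tree)

print(maybe_add_print("x = 1"))         # left as-is: assignment, not an expression
print(maybe_add_print("x = 1\nx + 1"))  # last line becomes print(x + 1)
```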
Suggested change:

```diff
 async def execute(self, instance_id: str, code: str, **kwargs) -> tuple[str, float, dict]:
-    # NOTE: some script may not explicitly print result, we need to add a print statement to the end of the script
-    lines = code.split("\n")
-    for i, line in reversed(list(enumerate(lines))):
-        if line == "":
-            continue
-        if not lines[i].startswith("print"):
-            lines[i] = f"print({line})"
-        break
-    code = "\n".join(lines)
     timeout = self.default_timeout
     language = self.default_language
     if not isinstance(code, str):
         code = str(code)
     result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language)
     # sandbox has no score or metrics, use Nones
     return result, None, None
```
```python
response_ids = await self.loop.run_in_executor(
    None, lambda: self.tokenizer.encode(responses[0].text or "", add_special_tokens=False)
)
```
The current implementation only processes the response from the first tool call when multiple tools are executed in parallel. The responses list contains results from all parallel tool calls, but only responses[0] is used to generate response_ids. This will lead to incorrect behavior when max_parallel_calls is greater than one, as the outputs from other concurrent tool calls will be ignored. To fix this, you should process the responses from all tool calls by concatenating their text.
Suggested change:

```diff
-response_ids = await self.loop.run_in_executor(
-    None, lambda: self.tokenizer.encode(responses[0].text or "", add_special_tokens=False)
-)
+all_responses_text = "".join([r.text for r in responses if r.text])
+response_ids = await self.loop.run_in_executor(
+    None, lambda: self.tokenizer.encode(all_responses_text, add_special_tokens=False)
+)
```
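The semantics of the suggested join can be checked in isolation; `SimpleNamespace` stands in for the tool-response objects (an assumption about their shape — only the `.text` attribute is used here):

```python
from types import SimpleNamespace

# Results of three parallel tool calls; one returned no text.
responses = [SimpleNamespace(text="out-a\n"),
             SimpleNamespace(text=None),
             SimpleNamespace(text="out-b\n")]

# Encoding only responses[0] would drop "out-b\n" entirely; joining first
# preserves every parallel call's output (and skips empty/None text).
all_responses_text = "".join([r.text for r in responses if r.text])
print(repr(all_responses_text))  # 'out-a\nout-b\n'
```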
```python
try:
    key = item["input"].split("user\n")[-1].split("\nassistant")[0].strip()
    merged_prompt_to_scores[key].append(item["score"])
except (KeyError, IndexError) as e:
    print(f"Warning: Failed to parse item: {e}")
    continue
```
The prompt extraction logic is brittle as it relies on split("user\n") and split("\nassistant"). This can easily break if the input format changes, for example, if there are extra newlines or variations in the template. Using a regular expression would be more robust. Please also add import re at the top of the file.
Suggested change:

```diff
 try:
-    key = item["input"].split("user\n")[-1].split("\nassistant")[0].strip()
+    # Using regex for more robust prompt extraction.
+    match = re.search(r"user\n(.*?)\nassistant", item["input"], re.DOTALL)
+    if not match:
+        print(f"Warning: Could not extract prompt from item: {item.get('input', 'NO_INPUT_KEY')}")
+        continue
+    key = match.group(1).strip()
     merged_prompt_to_scores[key].append(item["score"])
-except (KeyError, IndexError) as e:
-    print(f"Warning: Failed to parse item: {e}")
+except KeyError as e:
+    print(f"Warning: Failed to parse item, missing key: {e}")
     continue
```
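A quick standalone check of the regex on a hypothetical chat-template string (the exact template is an assumption; only the `user\n … \nassistant` framing is taken from the original code):

```python
import re

sample = "system\nYou are helpful.\nuser\n\nWhat is 2+2?\n\nassistant\n4"

# re.DOTALL lets the lazy group span multi-line prompts; extra blank lines
# around the prompt are absorbed by .strip() on the captured group.
match = re.search(r"user\n(.*?)\nassistant", sample, re.DOTALL)
key = match.group(1).strip()
print(repr(key))  # 'What is 2+2?'
```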
```python
kl = (old_log_probs - cur_log_probs).abs()
D = masked_mean(kl, response_mask, axis=-1)  # (M,)
rho_metrics["spo/D"] = D.mean().item()
D_half = torch.as_tensor(0.06, dtype=D.dtype, device=D.device)
```
The value 0.06 for D_half is hardcoded. This is a key hyperparameter for the KL-based rho decay and should be configurable rather than a magic number. Hardcoding it makes it difficult to tune and experiment with different values. Please make it a configurable parameter, for example by adding D_half: 0.06 to the rho config in recipe/spo/config/spo_trainer.yaml and reading it from there.
Suggested change:

```diff
-D_half = torch.as_tensor(0.06, dtype=D.dtype, device=D.device)
+D_half = torch.as_tensor(self.config.trainer.spo.rho.get("D_half", 0.06), dtype=D.dtype, device=D.device)
```
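For intuition on what `D_half` controls, assuming a half-life-style decay (the name suggests this, but the snippet above does not confirm the exact curve), the dependence of rho on the per-sequence KL estimate D might look like:

```python
def rho_decay(D: float, rho_0: float = 1.0, D_half: float = 0.06) -> float:
    # Half-life reading of D_half: rho is halved each time the sequence-level
    # KL estimate D grows by D_half. Illustrative only; the actual decay is
    # implemented in spo_ray_trainer.py and may use a different form.
    return rho_0 * 2.0 ** (-D / D_half)

print(rho_decay(0.0))   # 1.0  (no divergence from the old policy)
print(rho_decay(0.06))  # 0.5  (one half-life)
print(rho_decay(0.12))  # 0.25 (two half-lives)
```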
@wuxibin89 @PeterSH6 @vermouth1992 Could you please review this PR? I have adopted the previous suggestions and reduced the code.

@dzh19990407 Hi, thanks for your contribution. We're moving recipes to a separate project, verl-project/verl-recipe. Could you submit a PR to that project? #4283

Sure, happy to! I'll submit a PR to verl-project/verl-recipe.
What does this PR do?
Checklist Before Starting
- Title format: `[{modules}] {type}: {description}` (this will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`; multiple modules are written like `[megatron, fsdp, doc]`
- `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If the PR contains breaking changes, add `[BREAKING]` to the beginning of the title, e.g. `[BREAKING][fsdp, megatron] feat: dynamic batching`

Test
API and Usage Example
Here's a complete example combining all preprocessing and training steps:
Design & Code Changes
High-Level Design
SPO introduces two key innovations over standard GRPO (Group Relative Policy Optimization):

- 96 tokens per batch
- maintaining sample efficiency

The architecture consists of three main components:
Specific Code Changes
1. Core SPO Training Implementation (`spo_ray_trainer.py:1160-1496`)

Sampling with Adaptive Weighting: weight ∝ √(p̂(1-p̂)) + ε·ρ, based on KL divergence.

Key Parameters:
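The adaptive weighting line above can be sketched directly; ε, the argument names, and the additive combination are read from the summary formula, not from the trainer's actual code in `spo_ray_trainer.py`:

```python
import math

def sample_weight(p_hat: float, rho: float, eps: float = 0.01) -> float:
    # sqrt(p_hat * (1 - p_hat)) peaks at p_hat = 0.5, so prompts the policy
    # solves about half the time carry the most learning signal; eps * rho
    # keeps a KL-driven floor so no prompt's weight collapses to zero.
    return math.sqrt(p_hat * (1.0 - p_hat)) + eps * rho

for p in (0.0, 0.5, 0.9):
    print(p, round(sample_weight(p, rho=1.0), 3))
```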
- Data Preprocessing (`estimate_offline_values/split_dapo_into_subsets.py`)
- Value Generation (`estimate_offline_values/eval.sh`)
- Merging (`estimate_offline_values/merge_offline_values.py:44-134`)
- CustomRLHFDataset (lines 57-126)
- compute_score (lines 128-151)
- State Machine Architecture (lines 156-168)
- Multi-Turn Code Execution (lines 253-278): ... markers

Key Features:
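Looping back to the merging step (`merge_offline_values.py`) listed above, its core reduction can be sketched as follows; the field names are simplified assumptions (the real script extracts the prompt text from a chat-template string):

```python
from collections import defaultdict
from statistics import mean

def merge_offline_values(items):
    # items: records like {"prompt": ..., "score": ...} collected across
    # eval runs. Scores for the same prompt are averaged into one offline
    # value per prompt -- the baseline SPO uses in place of a GRPO group.
    prompt_to_scores = defaultdict(list)
    for item in items:
        prompt_to_scores[item["prompt"]].append(item["score"])
    return {p: mean(s) for p, s in prompt_to_scores.items()}

values = merge_offline_values([
    {"prompt": "q1", "score": 1.0},
    {"prompt": "q1", "score": 0.0},
    {"prompt": "q2", "score": 1.0},
])
print(values)  # {'q1': 0.5, 'q2': 1.0}
```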
GRPO Configuration:

```
train_batch_size=128      # 8 responses × 16 prompts
ppo_mini_batch_size=16
gen_batch_size=128
n_resp_per_prompt=8
```

SPO Configuration:

```
train_batch_size=1024     # 1 response × 1024 prompts (8× larger)
ppo_mini_batch_size=128   # 8× larger mini-batches
gen_batch_size=14000      # Large batch for efficient generation
n_resp_per_prompt=1       # Single response per prompt
offline_values=path/to/offline_values.json
```
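Given these settings (`n_resp_per_prompt=1` plus a precomputed `offline_values` file), the role the offline values play can be illustrated with a hypothetical sketch; this is not the trainer's actual advantage code:

```python
def spo_advantage(reward: float, offline_value: float) -> float:
    # With n_resp_per_prompt=1 there is no response group to normalize
    # against as in GRPO, so the precomputed offline value for the prompt
    # stands in as the baseline. Hypothetical signature: the real advantage
    # computation lives in spo_ray_trainer.py and is not reproduced here.
    return reward - offline_value

print(spo_advantage(1.0, 0.25))  # solved a prompt the policy usually fails
print(spo_advantage(0.0, 0.75))  # failed a prompt the policy usually solves
```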
Agent Loop Manager (`spo_agent_loop.py:663-813`):

Configuration System (`config/spo_trainer.yaml`):
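A hypothetical sketch of how the `D_half` knob discussed in the review could be surfaced in `config/spo_trainer.yaml` (key nesting and comment are assumptions):

```yaml
trainer:
  spo:
    rho:
      D_half: 0.06   # KL scale at which rho has decayed to half; tune per task
```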
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- Request CI approval in the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)