Skip to content

[recipe] feat: add Single-stream Policy Optimization (SPO) algo.#4176

Closed
dzh19990407 wants to merge 1 commit intoverl-project:mainfrom
dzh19990407:251104_SPO_PR
Closed

[recipe] feat: add Single-stream Policy Optimization (SPO) algo.#4176
dzh19990407 wants to merge 1 commit intoverl-project:mainfrom
dzh19990407:251104_SPO_PR

Conversation

@dzh19990407
Copy link

What does this PR do?

Add concise overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: New features, no similar PRs.
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

image

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

Here's a complete example combining all preprocessing and training steps:

# Step 1: Split dataset into subsets
python recipe/spo/estimate_offline_values/split_dapo_into_subsets.py \
    --dataset open-r1/DAPO-Math-17k-Processed \
    --output_dir DAPO-Math-17k-Processed_Splits \
    --num_subsets 5

# Step 2: Generate offline value estimates for each subset
for i in {0..4}; do
    OUTPUT_DIR=spo_verl_pr \
    DATA_FILE=DAPO-Math-17k-Processed_Splits/subset_${i}.parquet \
    MODEL_PATH=Qwen/Qwen3-8B \
    EXP_NAME=offline_value_estimation_subset_${i} \
    sh recipe/spo/estimate_offline_values/eval.sh
done

# Step 3: Merge offline value estimates
python recipe/spo/estimate_offline_values/merge_offline_values.py \
    --input_dir spo_verl_pr/offline_value_estimation \
    --output_file DAPO-Math-17k-Processed_Splits/offline_values.json

# Step 4: Train with SPO
OUTPUT_DIR=spo_verl_pr \
TRAIN_DATA_DIR=DAPO-Math-17k-Processed_Splits \
MODEL_PATH=Qwen/Qwen3-8B \
EXP_NAME=spo_training \
METHOD=SPO \
OFFLINE_VALUES=DAPO-Math-17k-Processed_Splits/offline_values.json \
sh recipe/spo/train.sh

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

This PR implements Single-stream Policy Optimization (SPO), an efficient reinforcement learning algorithm that
reduces GPU memory consumption by 8x compared to traditional GRPO while maintaining comparable performance. The
implementation is built on top of the VERL framework and includes a complete pipeline for math problem solving with
tool-augmented reasoning.

High-Level Design

SPO introduces two key innovations over standard GRPO (Group Relative Policy Optimization):

  1. Single-Response Generation: Generates 1 response per prompt instead of 8, reducing memory requirements from 768 to
    96 tokens per batch
  2. Sampling with Offline Values: Uses pretrained model estimates to intelligently select prompts for training,
    maintaining sample efficiency

The architecture consists of three main components:

  • Offline Value Estimation Pipeline: Preprocesses training data to estimate response quality
  • SPO Training Loop: Implements Thompson Sampling-based prompt selection and advantage estimation
  • Tool-Augmented Agent: Multi-turn reasoning with Python code execution capability

Specific Code Changes

1. Core SPO Training Implementation (spo_ray_trainer.py:1160-1496)

Sampling with Adaptive Weighting:

  • Maintains Beta distributions α and β for each prompt to model success probability
  • Samples prompts proportionally to uncertainty: weight ∝ √(p̂(1-p̂)) + ε
  • Updates distributions using adaptive decay factor ρ based on KL divergence
# Weighted sampling (lines 1184-1230)
prompt2phat = {k: α[k]/(α[k]+β[k]) for k in prompts}
prompt2weight = {k: √(*(1-)) + 0.05 for k in prompts}
selected_prompts = np.random.choice(prompts, size=batch_size, p=weights)

# Advantage estimation (lines 1346-1378)
advantages = reward - p_hats  # SPO advantage
advantages = (advantages - mean) / (std + 1e-8)  # Normalize

# Distribution updates (lines 1474-1489)
ρ = 2^(-D/D_half)  # Decay based on KL divergence
α_new = ρ*α + reward
β_new = ρ*β + (1-reward)

Key Parameters:

  • offline_N=8: Number of offline samples used for prior estimation
  • rho.type="kl": Adaptive decay based on policy drift
  • clip_lower=0.875: Minimum decay factor to prevent over-reliance on old data
  1. Offline Value Estimation Pipeline

Data Preprocessing (estimate_offline_values/split_dapo_into_subsets.py):

  • Splits large datasets (17k samples) into 5 subsets for parallel processing
  • Outputs .parquet files for distributed evaluation

Value Generation (estimate_offline_values/eval.sh):

  • Runs pretrained model on each subset to generate offline rewards
  • Uses same reward function as training for consistency
  • Stores results in validation_data/0.jsonl format

Merging (estimate_offline_values/merge_offline_values.py:44-134):

  • Aggregates scores from all subsets by prompt
  • Subsamples to max 8 scores per prompt using random selection
  • Outputs offline_values.json mapping prompts to score lists
  1. Custom Dataset and Reward Function (spo_retool.py)

CustomRLHFDataset (lines 57-126):

  • Processes multiple math datasets: AIME 2024/2025, DAPO-Math-17k, BeyondAIME
  • Appends answer format requirement: \boxed{answer} to all prompts
  • Maps datasets to consistent schema with data_source, reward_model, agent_name

compute_score (lines 128-151):

  • Validates reasoning format: exactly one tag, no code blocks after thinking
  • Extracts boxed answer using math_dapo.compute_score with strict verification
  • Returns binary reward (0 or 1) and predicted answer
  1. Tool-Augmented Agent Loop (agent_loop/spo_tool_agent_loop.py)

State Machine Architecture (lines 156-168):

  • PENDING: Prepares initial prompt with tool instructions
  • GENERATING: LLM generates reasoning/code
  • PROCESSING_TOOLS: Executes Python code in sandbox
  • TERMINATED: Completes trajectory

Multi-Turn Code Execution (lines 253-278):

  • Parses code blocks from LLM output using ... markers
  • Executes in isolated sandbox with timeout/security controls
  • Returns output wrapped in ... tags
  • Updates response_mask: 1 for LLM tokens, 0 for tool outputs

Key Features:

  • Stateless sandbox: each execution is independent
  • Truncation handling: limits tool outputs to max_tool_response_length (configurable)
  • Error recovery: gracefully handles execution failures
  1. Training Configuration (train.sh)

GRPO Configuration:
train_batch_size=128 # 8 responses × 16 prompts
ppo_mini_batch_size=16
gen_batch_size=128
n_resp_per_prompt=8

SPO Configuration:
train_batch_size=1024 # 1 response × 1024 prompts (8× larger)
ppo_mini_batch_size=128 # 8× larger mini-batches
gen_batch_size=14000 # Large batch for efficient generation
n_resp_per_prompt=1 # Single response per prompt
offline_values=path/to/offline_values.json

  1. Additional Modifications

Agent Loop Manager (spo_agent_loop.py:663-813):

  • Manages async rollout workers for parallel trajectory generation
  • Implements LLM server load balancing with sticky sessions
  • Computes performance metrics: generation time, tool call latency

Configuration System (config/spo_trainer.yaml):

  • Extends base PPO config with SPO-specific parameters
  • Configurable rho decay strategies: constant or KL-based
  • Integrated with Hydra for flexible experiment management

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the Single-stream Policy Optimization (SPO) algorithm, a memory-efficient alternative to GRPO. The changes are extensive, adding new recipes, configurations, documentation, and core logic for SPO. My review focuses on correctness and potential bugs in the new implementation. I've identified a few critical and high-severity issues, including a logic error in handling parallel tool calls, brittle data parsing, a hardcoded hyperparameter, and a risky code modification in the sandbox tool. Addressing these will improve the robustness and maintainability of the new SPO recipe.

Comment on lines +33 to +51
async def execute(self, instance_id: str, code: str, **kwargs) -> tuple[str, float, dict]:
# NOTE: some script may not explicitly print result, we need to add a print statement to the end of the script
lines = code.split("\n")
for i, line in reversed(list(enumerate(lines))):
if line == "":
continue
if not lines[i].startswith("print"):
lines[i] = f"print({line})"
break
code = "\n".join(lines)

timeout = self.default_timeout
language = self.default_language
if not isinstance(code, str):
code = str(code)

result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language)
# sandbox has no score or metrics, use Nones
return result, None, None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The execute method of CustomSandboxFusionTool modifies the user-provided code by attempting to add a print() statement to the last non-empty line. This is a very brittle approach and can lead to SyntaxError for valid Python code. For example, if the last line is an assignment (x = 1), it will be transformed into print(x = 1), which is invalid. This will cause the tool execution to fail unexpectedly. The model should be responsible for generating code that explicitly prints its output.

Suggested change
async def execute(self, instance_id: str, code: str, **kwargs) -> tuple[str, float, dict]:
# NOTE: some script may not explicitly print result, we need to add a print statement to the end of the script
lines = code.split("\n")
for i, line in reversed(list(enumerate(lines))):
if line == "":
continue
if not lines[i].startswith("print"):
lines[i] = f"print({line})"
break
code = "\n".join(lines)
timeout = self.default_timeout
language = self.default_language
if not isinstance(code, str):
code = str(code)
result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language)
# sandbox has no score or metrics, use Nones
return result, None, None
async def execute(self, instance_id: str, code: str, **kwargs) -> tuple[str, float, dict]:
timeout = self.default_timeout
language = self.default_language
if not isinstance(code, str):
code = str(code)
result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language)
# sandbox has no score or metrics, use Nones
return result, None, None

Comment on lines +282 to +284
response_ids = await self.loop.run_in_executor(
None, lambda: self.tokenizer.encode(responses[0].text or "", add_special_tokens=False)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation only processes the response from the first tool call when multiple tools are executed in parallel. The responses list contains results from all parallel tool calls, but only responses[0] is used to generate response_ids. This will lead to incorrect behavior when max_parallel_calls is greater than one, as the outputs from other concurrent tool calls will be ignored. To fix this, you should process the responses from all tool calls by concatenating their text.

Suggested change
response_ids = await self.loop.run_in_executor(
None, lambda: self.tokenizer.encode(responses[0].text or "", add_special_tokens=False)
)
all_responses_text = "".join([r.text for r in responses if r.text])
response_ids = await self.loop.run_in_executor(
None, lambda: self.tokenizer.encode(all_responses_text, add_special_tokens=False)
)

Comment on lines +88 to +93
try:
key = item["input"].split("user\n")[-1].split("\nassistant")[0].strip()
merged_prompt_to_scores[key].append(item["score"])
except (KeyError, IndexError) as e:
print(f"Warning: Failed to parse item: {e}")
continue
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The prompt extraction logic is brittle as it relies on split("user\n") and split("\nassistant"). This can easily break if the input format changes, for example, if there are extra newlines or variations in the template. Using a regular expression would be more robust. Please also add import re at the top of the file.

Suggested change
try:
key = item["input"].split("user\n")[-1].split("\nassistant")[0].strip()
merged_prompt_to_scores[key].append(item["score"])
except (KeyError, IndexError) as e:
print(f"Warning: Failed to parse item: {e}")
continue
try:
# Using regex for more robust prompt extraction.
match = re.search(r"user\n(.*?)\nassistant", item["input"], re.DOTALL)
if not match:
print(f"Warning: Could not extract prompt from item: {item.get('input', 'NO_INPUT_KEY')}")
continue
key = match.group(1).strip()
merged_prompt_to_scores[key].append(item["score"])
except KeyError as e:
print(f"Warning: Failed to parse item, missing key: {e}")
continue

kl = (old_log_probs - cur_log_probs).abs()
D = masked_mean(kl, response_mask, axis=-1) # (M,)
rho_metrics["spo/D"] = D.mean().item()
D_half = torch.as_tensor(0.06, dtype=D.dtype, device=D.device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The value 0.06 for D_half is hardcoded. This is a key hyperparameter for the KL-based rho decay and should be configurable rather than a magic number. Hardcoding it makes it difficult to tune and experiment with different values. Please make it a configurable parameter, for example by adding D_half: 0.06 to the rho config in recipe/spo/config/spo_trainer.yaml and reading it from there.

Suggested change
D_half = torch.as_tensor(0.06, dtype=D.dtype, device=D.device)
D_half = torch.as_tensor(self.config.trainer.spo.rho.get("D_half", 0.06), dtype=D.dtype, device=D.device)

@dzh19990407
Copy link
Author

@wuxibin89 @PeterSH6 @vermouth1992 Could you please review this PR? I have adopted the previous suggestions and reduced the code.

@wuxibin89
Copy link
Collaborator

@dzh19990407 Hi, thanks for your contribution. We're moving recipe to a separate project verl-project/verl-recipe, could you submit a PR to this project? #4283

@dzh19990407
Copy link
Author

dzh19990407 commented Nov 25, 2025

@dzh19990407 Hi, thanks for your contribution. We're moving recipe to a separate project verl-project/verl-recipe, could you submit a PR to this project? #4283

Sure, happy to! I’ll submit a PR to verl-project/verl-recipe.

@wuxibin89 wuxibin89 closed this Nov 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants