Skip to content
21 changes: 20 additions & 1 deletion docs/source/paper_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ You can partially reproduce the ScaleRL recipe using the [`GRPOTrainer`] with th
```python
from trl import GRPOConfig

config = GRPOConfig(
training_args = GRPOConfig(
loss_type="cispo",
epsilon_high=5.0,
num_generations=16,
Expand All @@ -367,6 +367,25 @@ config = GRPOConfig(
)
```

### Soft Adaptive Policy Optimization

**📜 Paper**: https://huggingface.co/papers/2511.20347

Soft Adaptive Policy Optimization (SAPO) replaces hard clipping with a smooth, temperature-controlled gate that adaptively attenuates off-policy updates while preserving useful learning signals. Compared with GSPO and GRPO, SAPO is both sequence-coherent and token-adaptive. Like GSPO, SAPO maintains sequence-level coherence, but its soft gating forms a continuous trust region that avoids the brittle hard clipping band used in GSPO.

To reproduce the paper's setting, use this configuration:

```python
from trl import GRPOConfig

training_args = GRPOConfig(
loss_type="sapo",
sapo_temperature_pos=1.0, # default value
sapo_temperature_neg=1.05, # default value
scale_rewards="group",
...
)
```


## Direct Policy Optimization
Expand Down
9 changes: 8 additions & 1 deletion tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,14 @@ def test_training(self, config_name):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo"])
def test_get_sapo_token_loss(self):
    """Check the SAPO soft gate value at ratio 1 for several temperatures, and its ordering across ratios."""
    sample_token_importance_ratio = torch.ones((2, 4))
    # At ratio 1: sigmoid(temp * (1 - 1)) * 4 / temp = 0.5 * 4 / temp = 2 / temp.
    # Using temperatures other than 1.0 is what actually exercises the `4 / temp` scaling.
    for temperature in (1.0, 1.05, 2.0):
        sapo_token_loss = GRPOTrainer.get_sapo_token_loss(sample_token_importance_ratio, temperature)
        expected_sapo_token_loss = torch.full_like(sample_token_importance_ratio, 2.0 / temperature)
        torch.testing.assert_close(sapo_token_loss, expected_sapo_token_loss)

    # The gate is a scaled sigmoid of the ratio, so it must be strictly increasing in the ratio.
    increasing_ratios = torch.tensor([[0.5, 1.0, 2.0]])
    gated = GRPOTrainer.get_sapo_token_loss(increasing_ratios, 1.0)
    assert torch.all(gated[:, 1:] > gated[:, :-1])

@pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "sapo"])
def test_training_loss_types(self, loss_type):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

Expand Down
30 changes: 30 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,12 @@ class GRPOConfig(TrainingArguments):
specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
When used with `loss_type='cispo'`, this corresponds to the ε_max param specified in the [ScaleRL
paper](https://arxiv.org/pdf/2510.13786) and the recommended value is `5.0`.
sapo_temperature_neg (`float`, *optional*, defaults to `1.05`):
Temperature for tokens with non-positive advantage scores used in the `sapo` loss function. This parameter
is introduced in the [Soft Adaptive Policy Optimization paper](https://huggingface.co/papers/2511.20347).
sapo_temperature_pos (`float`, *optional*, defaults to `1.0`):
Temperature for tokens with positive advantage scores used in the `sapo` loss function. This parameter is
introduced in the [Soft Adaptive Policy Optimization paper](https://huggingface.co/papers/2511.20347).
importance_sampling_level (`str`, *optional*, defaults to `"token"`):
Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"`
keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the
Expand Down Expand Up @@ -211,6 +217,10 @@ class GRPOConfig(TrainingArguments):
clipped weights are then multiplied with the advantages and policy model's log probs. Individual token
losses are aggregated by normalizing with the number of active tokens in the global accumulated batch.
This method was introduced in the [MiniMax-M1 paper](https://huggingface.co/papers/2506.13585).
- `"sapo"`: Soft Adaptive Policy Optimization loss, as introduced in the [Soft Adaptive Policy Optimization
Comment thread
pramodith marked this conversation as resolved.
paper](https://huggingface.co/papers/2511.20347). Replaces hard clipping with a smooth,
temperature-controlled gate that adaptively attenuates off-policy updates while preserving useful
learning signals.
mask_truncated_completions (`bool`, *optional*, defaults to `False`):
When enabled, truncated completions are excluded from the loss calculation, preventing them from being
incorrectly penalized and introducing noise during training. According to the
Expand Down Expand Up @@ -583,6 +593,22 @@ class GRPOConfig(TrainingArguments):
"[ScaleRL paper](https://huggingface.co/papers/2510.13786) and the recommended value is `5.0`."
},
)
sapo_temperature_neg: float = field(
default=1.05,
metadata={
"help": "Temperature for tokens with non-positive advantage scores used in the `sapo` loss function. "
"This parameter is introduced in the [Soft Adaptive Policy Optimization "
"paper](https://huggingface.co/papers/2511.20347)."
},
)
sapo_temperature_pos: float = field(
default=1.0,
metadata={
"help": "Temperature for tokens with positive advantage scores used in the `sapo` loss function. "
"This parameter is introduced in the [Soft Adaptive Policy Optimization "
"paper](https://huggingface.co/papers/2511.20347)."
},
)
importance_sampling_level: str = field(
default="token",
metadata={
Expand Down Expand Up @@ -634,6 +660,10 @@ class GRPOConfig(TrainingArguments):
"Individual token losses are aggregated by normalizing with the number of active tokens in "
"the global accumulated batch. This method was introduced in the "
"[MiniMax-M1 paper](https://huggingface.co/papers/2506.13585)."
"'sapo': Soft Adaptive Policy Optimization loss, as introduced in the "
Comment thread
pramodith marked this conversation as resolved.
"[Soft Adaptive Policy Optimization paper](https://huggingface.co/papers/2511.20347). "
"Replaces hard clipping with a smooth, temperature-controlled gate that adaptively attenuates "
"off-policy updates while preserving useful learning signals."
},
)
mask_truncated_completions: bool = field(
Expand Down
24 changes: 23 additions & 1 deletion trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,11 @@ def __init__(
"Iterable datasets are not yet supported in GRPOTrainer. Please use a standard dataset instead."
)

if args.loss_type == "sapo" and (args.sapo_temperature_neg is None or args.sapo_temperature_pos is None):
raise ValueError(
"When using `sapo` loss, both `sapo_temperature_neg` and `sapo_temperature_pos` must be set."
)

# Multi-step
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
self.epsilon_low = args.epsilon
Expand Down Expand Up @@ -1831,6 +1836,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
else:
return self._compute_loss(model, inputs)

@staticmethod
def get_sapo_token_loss(unclipped_token_loss: torch.Tensor, temperature: float) -> torch.Tensor:
    """
    Apply the SAPO soft gate element-wise: `sigmoid(temperature * (x - 1)) * 4 / temperature`.

    Args:
        unclipped_token_loss (`torch.Tensor`):
            Values to gate. The callers visible in this change pass the importance-sampling ratio here
            (not an already-computed loss), so the gate is centered on a ratio of `1`.
        temperature (`float`):
            Strictly positive temperature of the gate.

    Returns:
        `torch.Tensor`: Gated values with the same shape as `unclipped_token_loss`. At an input of `1` the
        result is `2 / temperature`.

    Raises:
        `ValueError`: If `temperature` is not strictly positive.
    """
    # A zero temperature divides by zero (inf/nan loss) and a negative one flips the sign of the gate;
    # fail loudly instead of silently training on a corrupted loss.
    if temperature <= 0:
        raise ValueError(f"`temperature` must be strictly positive for the `sapo` loss, got {temperature}.")

    gate = torch.sigmoid(temperature * (unclipped_token_loss - 1))
    # The 4 / temperature factor cancels sigmoid'(0) * temperature = temperature / 4, so the slope of the
    # gate at an input of 1 is exactly 1 for every temperature.
    return gate * 4 / temperature

def _compute_loss(self, model, inputs):
# Compute the per-token log probabilities for the model
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
Expand Down Expand Up @@ -1908,6 +1920,16 @@ def _compute_loss(self, model, inputs):
per_token_loss1 = coef_1 * advantages
per_token_loss2 = coef_2 * advantages
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
elif self.loss_type == "sapo":
per_token_loss = torch.empty_like(coef_1)
positive_advantages_mask = advantages.repeat([1, coef_1.shape[1]]) > 0
per_token_loss[positive_advantages_mask] = self.get_sapo_token_loss(
coef_1[positive_advantages_mask], self.args.sapo_temperature_pos
)
per_token_loss[~positive_advantages_mask] = self.get_sapo_token_loss(
coef_1[~positive_advantages_mask], self.args.sapo_temperature_neg
)
per_token_loss = -per_token_loss * advantages
else:
raise ValueError(f"Unknown loss type: {self.loss_type}")

Expand All @@ -1920,7 +1942,7 @@ def _compute_loss(self, model, inputs):
if self.beta != 0.0:
per_token_loss = per_token_loss + self.beta * per_token_kl

if self.loss_type == "grpo":
if self.loss_type in ["grpo", "sapo"]:
loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
loss = loss / self.current_gradient_accumulation_steps
elif self.loss_type == "bnpo":
Expand Down
Loading