diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index b1c2eb1a1fb..fbd118435ca 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -358,7 +358,7 @@ You can partially reproduce the ScaleRL recipe using the [`GRPOTrainer`] with th ```python from trl import GRPOConfig -config = GRPOConfig( +training_args = GRPOConfig( loss_type="cispo", epsilon_high=5.0, num_generations=16, @@ -367,6 +367,25 @@ config = GRPOConfig( ) ``` +### Soft Adaptive Policy Optimization + +**📜 Paper**: https://huggingface.co/papers/2511.20347 + +Soft Adaptive Policy Optimization (SAPO), replaces hard clipping with a smooth, temperature-controlled gate that adaptively attenuates off-policy updates while preserving useful learning signals. Compared with GSPO and GRPO, SAPO is both sequence-coherent and token-adaptive. Like GSPO, SAPO maintains sequence-level coherence, but its soft gating forms a continuous trust region that avoids the brittle hard clipping band used in GSPO. + +To reproduce the paper's setting, use this configuration: + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + loss_type="sapo", + sapo_temperature_pos=1.0, # default value + sapo_temperature_neg=1.05, # default value + scale_rewards="group", + ... +) +``` ## Direct Policy Optimization diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index baaea524b2c..1caaa3c4e6e 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -167,7 +167,14 @@ def test_training(self, config_name): new_param = trainer.model.get_parameter(n) assert not torch.equal(param, new_param), f"Parameter {n} has not changed." - @pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo"]) + def test_get_sapo_token_loss(self): + sample_token_importance_ratio = torch.ones((2, 4)) + sapo_token_loss = GRPOTrainer.get_sapo_token_loss(sample_token_importance_ratio, 1.0) + # sigmoid(temp * (1-1)) * 4/temp = 0.5 * 4 = 2 + expected_sapo_token_loss = torch.full_like(sample_token_importance_ratio, 2.0) + torch.testing.assert_close(sapo_token_loss, expected_sapo_token_loss) + + @pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "sapo"]) def test_training_loss_types(self, loss_type): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index e386089dbca..3445c3dc522 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -172,6 +172,12 @@ class GRPOConfig(TrainingArguments): specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`. When used with `loss_type='cispo'`, this corresponds to the ε_max param specified in the [ScaleRL paper](https://arxiv.org/pdf/2510.13786) and the recommended value is `5.0`. + sapo_temperature_neg (`float`, *optional*, defaults to `1.05`): + Temperature for tokens with non-positive advantage scores used in the `sapo` loss function. This parameter + is introduced in the [Soft Adaptive Policy Optimization paper](https://huggingface.co/papers/2511.20347). + sapo_temperature_pos (`float`, *optional*, defaults to `1.0`): + Temperature for tokens with positive advantage scores used in the `sapo` loss function. This parameter is + introduced in the [Soft Adaptive Policy Optimization paper](https://huggingface.co/papers/2511.20347). importance_sampling_level (`str`, *optional*, defaults to `"token"`): Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"` keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the @@ -211,6 +217,10 @@ class GRPOConfig(TrainingArguments): clipped weights are then multiplied with the advantages and policy model's log probs. Individual token losses are aggregated by normalizing with the number of active tokens in the global accumulated batch. This method was introduced in the [MiniMax-M1 paper](https://huggingface.co/papers/2506.13585). + - `"sapo"`: Soft Adaptive Policy Optimization loss, as introduced in the [Soft Adaptive Policy Optimization + paper](https://huggingface.co/papers/2506.13585). Replaces hard clipping with a smooth, + temperature-controlled gate that adaptively attenuates off-policy updates while preserving useful + learning signals. mask_truncated_completions (`bool`, *optional*, defaults to `False`): When enabled, truncated completions are excluded from the loss calculation, preventing them from being incorrectly penalized and introducing noise during training. According to the @@ -583,6 +593,22 @@ class GRPOConfig(TrainingArguments): "[ScaleRL paper]https://huggingface.co/papers/2510.13786) and the recommended value is `5.0`." }, ) + sapo_temperature_neg: float = field( + default=1.05, + metadata={ + "help": "Temperature for tokens with non-positive advantage scores used in the `sapo` loss function. " + "This parameter is introduced in the [Soft Adaptive Policy Optimization " + "paper](https://huggingface.co/papers/2511.20347)." + }, + ) + sapo_temperature_pos: float = field( + default=1.0, + metadata={ + "help": "Temperature for tokens with positive advantage scores used in the `sapo` loss function. " + "This parameter is introduced in the [Soft Adaptive Policy Optimization " + "paper](https://huggingface.co/papers/2511.20347)." + }, + ) importance_sampling_level: str = field( default="token", metadata={ @@ -634,6 +660,10 @@ class GRPOConfig(TrainingArguments): "Individual token losses are aggregated by normalizing with the number of active tokens in " "the global accumulated batch. This method was introduced in the " "[MiniMax-M1 paper](https://huggingface.co/papers/2506.13585)." + "'sapo': Soft Adaptive Policy Optimization loss, as introduced in the " + "[Soft Adaptive Policy Optimization paper](https://huggingface.co/papers/2506.13585). " + "Replaces hard clipping with a smooth, temperature-controlled gate that adaptively attenuates " + "off-policy updates while preserving useful learning signals." }, ) mask_truncated_completions: bool = field( diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 995bfdc17ab..7e5c0d8c6bb 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -425,6 +425,11 @@ def __init__( "Iterable datasets are not yet supported in GRPOTrainer. Please use a standard dataset instead." ) + if args.loss_type == "sapo" and (args.sapo_temperature_neg is None or args.sapo_temperature_pos is None): + raise ValueError( + "When using `sapo` loss, both `sapo_temperature_neg` and `sapo_temperature_pos` must be set." + ) + # Multi-step self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper self.epsilon_low = args.epsilon @@ -1831,6 +1836,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N else: return self._compute_loss(model, inputs) + @staticmethod + def get_sapo_token_loss(unclipped_token_loss: torch.Tensor, temperature: float) -> torch.Tensor: + sigmoid_input = temperature * (unclipped_token_loss - 1) + sigmoid_smoothed_loss = torch.nn.functional.sigmoid(sigmoid_input) + sapo_token_loss = sigmoid_smoothed_loss * 4 / temperature + return sapo_token_loss + def _compute_loss(self, model, inputs): # Compute the per-token log probabilities for the model prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] @@ -1908,6 +1920,16 @@ def _compute_loss(self, model, inputs): per_token_loss1 = coef_1 * advantages per_token_loss2 = coef_2 * advantages per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + elif self.loss_type == "sapo": + per_token_loss = torch.empty_like(coef_1) + positive_advantages_mask = advantages.repeat([1, coef_1.shape[1]]) > 0 + per_token_loss[positive_advantages_mask] = self.get_sapo_token_loss( + coef_1[positive_advantages_mask], self.args.sapo_temperature_pos + ) + per_token_loss[~positive_advantages_mask] = self.get_sapo_token_loss( + coef_1[~positive_advantages_mask], self.args.sapo_temperature_neg + ) + per_token_loss = -per_token_loss * advantages else: raise ValueError(f"Unknown loss type: {self.loss_type}") @@ -1920,7 +1942,7 @@ def _compute_loss(self, model, inputs): if self.beta != 0.0: per_token_loss = per_token_loss + self.beta * per_token_kl - if self.loss_type == "grpo": + if self.loss_type in ["grpo", "sapo"]: loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() loss = loss / self.current_gradient_accumulation_steps elif self.loss_type == "bnpo":