Merged
263 changes: 230 additions & 33 deletions docs/advance/rollout_is.md

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions docs/examples/config.rst
@@ -133,7 +133,7 @@ Actor/Rollout/Reference Policy
rollout_is_threshold_lower: null # Lower threshold (null = auto 1/upper)
rollout_is_level: token # Aggregation: token/sequence/geometric
rollout_is_mode: truncate # Bounding: truncate/mask
- rollout_is_veto_threshold: 1e-4 # Catastrophic outlier threshold
+ rollout_is_veto_threshold: null # Catastrophic outlier threshold (null to disable)
use_torch_compile: True # False to disable torch compile
kl_loss_coef: 0.001 # for grpo
kl_loss_type: low_var_kl # for grpo
@@ -519,7 +519,7 @@ Algorithm
rollout_is_threshold_lower: null
rollout_is_level: token
rollout_is_mode: truncate
- rollout_is_veto_threshold: 1e-4
+ rollout_is_veto_threshold: null # Disabled by default

- ``gamma``: discount factor
- ``lam``: Trade-off between bias and variance in the GAE estimator
@@ -537,7 +537,7 @@ Algorithm
- ``rollout_is_threshold_lower``: Lower threshold for IS weights. If ``null``, defaults to reciprocal of upper (1/upper).
- ``rollout_is_level``: Aggregation level: ``token`` (biased), ``sequence`` (unbiased), or ``geometric`` (experimental).
- ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``mask`` (zero outside bounds).
- - ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is 1e-4.
+ - ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is null (disabled).
Note: Rollout IS requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``.
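The three aggregation levels documented above can be illustrated with a simplified plain-Python sketch (not the verl implementation; `old_logps` and `rollout_logps` are hypothetical per-token log-probability lists for one response):

```python
import math

def rollout_is_weights(old_logps, rollout_logps, level="token"):
    """Sketch of the aggregation levels. Per-token ratio: r_t = exp(old_t - rollout_t)."""
    n = len(old_logps)
    log_diff = sum(old_logps) - sum(rollout_logps)
    if level == "token":
        # One weight per token (biased, lowest variance)
        return [math.exp(o - r) for o, r in zip(old_logps, rollout_logps)]
    if level == "sequence":
        # Product of per-token ratios, shared by every token (unbiased, high variance)
        return [math.exp(log_diff)] * n
    if level == "geometric":
        # Geometric mean of per-token ratios (experimental middle ground)
        return [math.exp(log_diff / n)] * n
    raise ValueError(f"unknown level: {level}")
```

Note that the sequence-level weight equals the product of the token-level ratios, which is why it is unbiased but can explode for long responses; the thresholds and veto bound it.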

Trainer
12 changes: 6 additions & 6 deletions examples/rollout_importance_sampling/README.md
@@ -61,7 +61,7 @@ bash examples/rollout_importance_sampling/run_with_rollout_is.sh
- `rollout_is_threshold`: Upper threshold for IS weights (null = disabled, float = enabled). **Main on/off switch.**
- `rollout_is`: Whether to apply weights to loss (true) or just compute metrics (false). Default: false.
- `rollout_is_threshold_lower`: Lower threshold (null = auto 1/upper)
- - `rollout_is_veto_threshold`: Catastrophic outlier threshold (default: 1e-4)
+ - `rollout_is_veto_threshold`: Catastrophic outlier threshold (default: null, disabled)

## Configuration Examples

@@ -73,7 +73,7 @@ algorithm:
rollout_is: true # Apply to loss
rollout_is_level: token
rollout_is_mode: truncate
- rollout_is_veto_threshold: 1e-4
+ rollout_is_veto_threshold: null # Disabled by default
```

### Example 2: Metrics Only (No Weight Application)
@@ -95,7 +95,7 @@ algorithm:
rollout_is_threshold_lower: 0.9998
rollout_is_level: geometric
rollout_is_mode: mask
- rollout_is_veto_threshold: 1e-4
+ rollout_is_veto_threshold: 1e-4 # Enable veto for this example
```

### Example 4: Sequence-level with Truncate
@@ -107,7 +107,7 @@ algorithm:
rollout_is_threshold_lower: null # Auto-reciprocal: 0.2
rollout_is_level: sequence
rollout_is_mode: truncate
- rollout_is_veto_threshold: 1e-4
+ rollout_is_veto_threshold: 1e-4 # Enable veto for this example
```

### Example 5: Asymmetric Thresholds
@@ -226,8 +226,8 @@ The veto mechanism zeros out entire sequences containing catastrophic outliers:

- If any token has ratio < `rollout_is_veto_threshold`, the entire sequence is rejected
- This prevents extreme outliers from dominating training
- - Default threshold: 1e-4 (ratio 10,000x off)
- - Set to `null` to disable: `rollout_is_veto_threshold: null`
+ - Default: `null` (disabled by default)
+ - Set to `1e-4` to enable (catches ratios 10,000x off)
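The veto behavior described above can be sketched in plain Python (a hypothetical illustration, not verl's implementation; note that, per this PR's tests, rejection is applied through the response mask rather than by zeroing the IS weights):

```python
def veto_mask(token_ratios, response_mask, veto_threshold=None):
    """Return a response mask with the whole sequence zeroed if vetoed.

    token_ratios: per-token IS ratios for one sequence.
    A single ratio below veto_threshold rejects the entire sequence.
    """
    if veto_threshold is None:  # null => veto disabled
        return list(response_mask)
    if any(r < veto_threshold for r in token_ratios):
        return [0] * len(response_mask)  # excluded from the loss via the mask
    return list(response_mask)
```

The IS weights themselves stay as safety-bounded ratios; only the mask changes, so downstream loss aggregation simply skips the vetoed tokens.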

## Examples

4 changes: 2 additions & 2 deletions examples/rollout_importance_sampling/run_with_rollout_is.sh
@@ -24,8 +24,8 @@ rollout_is_level=token
# Bounding mode: truncate (cap upper) | mask (zero outside bounds)
rollout_is_mode=truncate

- # Catastrophic outlier veto threshold
- rollout_is_veto_threshold=1e-4
+ # Catastrophic outlier veto threshold (set to null to disable, or e.g., 1e-4 to enable)
+ rollout_is_veto_threshold=null

# ==============================================================================
# Model and Data Configuration
81 changes: 72 additions & 9 deletions tests/trainer/ppo/test_rollout_is.py
@@ -49,7 +49,7 @@ def test_basic_rollout_is():

# Test token-level truncate mode
print("\n1. Testing token-level truncate mode...")
- weights_proto, metrics = compute_rollout_importance_weights(
+ weights_proto, modified_response_mask, metrics = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
@@ -71,7 +71,7 @@ def test_basic_rollout_is():

# Test sequence-level mode
print("\n2. Testing sequence-level mode...")
- weights_seq_proto, metrics_seq = compute_rollout_importance_weights(
+ weights_seq_proto, _, metrics_seq = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
@@ -92,7 +92,7 @@ def test_basic_rollout_is():

# Test geometric mean mode
print("\n3. Testing geometric mean mode...")
- weights_geo_proto, metrics_geo = compute_rollout_importance_weights(
+ weights_geo_proto, _, metrics_geo = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
@@ -116,7 +116,7 @@ def test_basic_rollout_is():
rollout_log_prob_veto[0, 2] = old_log_prob_veto[0, 2] + 15.0 # ratio ~= 3e-7
eos_mask_veto = torch.ones(2, 5, device=device)

- weights_veto_proto, metrics_veto = compute_rollout_importance_weights(
+ weights_veto_proto, modified_response_mask_veto, metrics_veto = compute_rollout_importance_weights(
old_log_prob=old_log_prob_veto,
rollout_log_prob=rollout_log_prob_veto,
response_mask=eos_mask_veto,
@@ -128,21 +128,25 @@ def test_basic_rollout_is():

weights_veto = weights_veto_proto.batch["rollout_is_weights"]
print(f" Veto fraction: {metrics_veto['mismatch/rollout_is_veto_fraction']:.4f}")
- # Check that the sequence with catastrophic token has all weights zeroed
- assert weights_veto[0].sum() == 0, "Sequence with catastrophic token should be vetoed"
- assert weights_veto[1].sum() > 0, "Normal sequence should not be vetoed"
+ # KEY FIX: Veto is applied via response_mask, not by zeroing weights
+ # Check that weights are NON-ZERO (safety-bounded ratios preserved, not zeroed)
+ assert weights_veto[0].sum() > 0, "Weights should be non-zero (not zeroed by veto)"
+ # Check that response_mask has veto applied
+ assert modified_response_mask_veto[0].sum() == 0, "Vetoed sequence should have response_mask zeroed"
+ assert modified_response_mask_veto[1].sum() > 0, "Normal sequence should have response_mask unchanged"
print(" ✓ Veto mechanism passed")

# Test disabled IS (threshold=None)
print("\n5. Testing disabled IS...")
- weights_disabled, metrics_disabled = compute_rollout_importance_weights(
+ weights_disabled, modified_response_mask_disabled, metrics_disabled = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_threshold=None,
)

assert weights_disabled is None, "Should return None when threshold is None"
+ assert torch.equal(modified_response_mask_disabled, eos_mask), "Should return original mask unchanged"
assert len(metrics_disabled) == 0, "Should return empty metrics when disabled"
print(" ✓ Disabled IS passed")

@@ -160,7 +164,7 @@ def test_metrics_completeness():
rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.2
eos_mask = torch.ones(batch_size, seq_length, device=device)

- _, metrics = compute_rollout_importance_weights(
+ _, _, metrics = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
@@ -264,6 +268,64 @@ def test_mismatch_metrics():
print(" ✓ Mismatch metrics work without rollout log probs")


def test_mask_mode():
"""Test mask mode applies rejection via response_mask, keeps true IS weights."""
print("\nTesting mask mode behavior...")

batch_size = 2
seq_length = 5
device = "cuda" if torch.cuda.is_available() else "cpu"

# Sequence 0: ratio ≈ 0.37 (below 0.5, should be rejected)
# Sequence 1: ratio ≈ 1.65 (in [0.5, 2.0], should be accepted)
old_log_prob = torch.tensor([[-2.0] * seq_length, [-2.0] * seq_length], device=device)
rollout_log_prob = torch.tensor(
[
[-1.0] * seq_length, # exp(-2.0 - (-1.0)) = exp(-1.0) ≈ 0.37
[-2.5] * seq_length, # exp(-2.0 - (-2.5)) = exp(0.5) ≈ 1.65
],
device=device,
)
response_mask = torch.ones(batch_size, seq_length, device=device)

weights_proto, modified_response_mask, metrics = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=response_mask,
rollout_is_level="token",
rollout_is_mode="mask",
rollout_is_threshold=2.0,
rollout_is_threshold_lower=0.5,
rollout_is_veto_threshold=None,
)

weights = weights_proto.batch["rollout_is_weights"]

# KEY FIX: Weights should be safety-bounded ratios (NOT zeroed)
assert torch.all(weights[0, :] > 0), "Weights should remain as safety-bounded ratios (not zeroed)"
assert torch.allclose(weights[0, 0], torch.tensor(0.368, device=device), atol=0.01), (
"First seq ratio should be ≈0.37"
)
assert torch.allclose(weights[1, 0], torch.tensor(1.649, device=device), atol=0.01), (
"Second seq ratio should be ≈1.65"
)

# Rejection should be applied via response_mask
assert torch.all(modified_response_mask[0, :] == 0), "First sequence should be rejected via mask"
assert torch.all(modified_response_mask[1, :] == 1), "Second sequence should be accepted"

# Verify mask metrics exist
assert "mismatch/rollout_is_masked_fraction" in metrics
assert abs(metrics["mismatch/rollout_is_masked_fraction"] - 0.5) < 0.01, "Should reject 50% of tokens"

print(f" First seq IS weight: {weights[0, 0]:.4f} (expected ≈0.37)")
print(f" Second seq IS weight: {weights[1, 0]:.4f} (expected ≈1.65)")
print(f" First seq mask: {modified_response_mask[0, 0]:.0f} (expected 0 - rejected)")
print(f" Second seq mask: {modified_response_mask[1, 0]:.0f} (expected 1 - accepted)")
print(f" Masked fraction: {metrics['mismatch/rollout_is_masked_fraction']:.2f}")
print(" ✓ Mask mode correctly separates IS weights from rejection")


if __name__ == "__main__":
print("=" * 60)
print("Rollout Importance Sampling Test Suite")
@@ -273,6 +335,7 @@ def test_mismatch_metrics():
test_basic_rollout_is()
test_metrics_completeness()
test_mismatch_metrics()
+ test_mask_mode()
print("\n" + "=" * 60)
print("ALL TESTS PASSED ✓")
print("=" * 60)
12 changes: 6 additions & 6 deletions tests/trainer/ppo/test_rollout_is_integration.py
@@ -61,7 +61,7 @@ def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is):
This test simulates that workflow.
"""
# First compute IS weights (as trainer would do centrally)
- rollout_is_weights_proto, _ = compute_rollout_importance_weights(
+ rollout_is_weights_proto, _, _ = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
@@ -92,7 +92,7 @@ def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is):

def test_rollout_is_weights_computation(self, sample_data):
"""Test rollout IS weights and metrics computation."""
- weights_proto, metrics = compute_rollout_importance_weights(
+ weights_proto, _, metrics = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
@@ -120,7 +120,7 @@ def test_all_aggregation_levels(self, sample_data):
levels = ["token", "sequence", "geometric"]

for level in levels:
- _, metrics = compute_rollout_importance_weights(
+ _, _, metrics = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
@@ -136,7 +136,7 @@ def test_both_bounding_modes(self, sample_data):
modes = ["truncate", "mask"]

for mode in modes:
- _, metrics = compute_rollout_importance_weights(
+ _, _, metrics = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
@@ -175,7 +175,7 @@ def test_veto_mechanism(self):

response_mask = torch.ones(batch_size, seq_length, device=device)

- _, metrics = compute_rollout_importance_weights(
+ _, _, metrics = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=response_mask,
@@ -196,7 +196,7 @@ def test_metrics_only_mode(self, sample_data, config_with_rollout_is):
but rollout_is=False (disables weight application to policy loss).
"""
# Compute IS weights (as trainer would do)
- rollout_is_weights_proto, is_metrics = compute_rollout_importance_weights(
+ rollout_is_weights_proto, _, is_metrics = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
2 changes: 1 addition & 1 deletion verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -490,7 +490,7 @@ algorithm:
rollout_is_threshold_lower: null
rollout_is_level: token
rollout_is_mode: truncate
- rollout_is_veto_threshold: 0.0001
+ rollout_is_veto_threshold: null
rollout_is: false
trainer:
balance_batch: true
2 changes: 1 addition & 1 deletion verl/trainer/config/_generated_ppo_trainer.yaml
@@ -476,7 +476,7 @@ algorithm:
rollout_is_threshold_lower: null
rollout_is_level: token
rollout_is_mode: truncate
- rollout_is_veto_threshold: 0.0001
+ rollout_is_veto_threshold: null
rollout_is: false
trainer:
balance_batch: true
4 changes: 2 additions & 2 deletions verl/trainer/config/algorithm.py
@@ -78,7 +78,7 @@ class AlgoConfig(BaseConfig):
rollout_is_threshold_lower (Optional[float]): Lower threshold for IS weights. If None, defaults to 1/upper.
rollout_is_level (str): Aggregation level: "token", "sequence", or "geometric".
rollout_is_mode (str): Bounding mode: "truncate" (cap upper only) or "mask" (zero outside bounds).
- rollout_is_veto_threshold (float): Per-token veto threshold for catastrophic outliers.
+ rollout_is_veto_threshold (float or None): Per-token veto threshold for catastrophic outliers. None to disable.
rollout_is (bool): Whether to apply IS weights to policy loss. True = apply weights,
False = compute metrics only (useful for monitoring before enabling correction). Default: False.
"""
@@ -99,7 +99,7 @@ class AlgoConfig(BaseConfig):
rollout_is_threshold_lower: Optional[float] = None
rollout_is_level: str = "token"
rollout_is_mode: str = "truncate"
- rollout_is_veto_threshold: Optional[float] = 1e-4
+ rollout_is_veto_threshold: Optional[float] = None
# Controls whether to apply IS weights to policy loss (only if rollout_is_threshold is set)
# True = apply weights to loss, False = compute metrics only (no weight application)
rollout_is: bool = False
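With the new default, code constructing the config gets a disabled veto unless one is passed explicitly. A hypothetical minimal stand-in for `AlgoConfig` (not the verl class, which has many more fields) that mirrors the changed defaults:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class AlgoConfigSketch:
    """Hypothetical subset of AlgoConfig, mirroring this PR's defaults."""
    rollout_is_threshold: Optional[float] = None  # main on/off switch for rollout IS
    rollout_is_threshold_lower: Optional[float] = None  # None => auto 1/upper
    rollout_is_level: str = "token"
    rollout_is_mode: str = "truncate"
    rollout_is_veto_threshold: Optional[float] = None  # None disables the veto
    rollout_is: bool = False  # metrics-only unless set to True

# Enabling IS correction no longer implicitly enables the veto:
cfg = AlgoConfigSketch(rollout_is_threshold=2.0, rollout_is=True)
assert cfg.rollout_is_veto_threshold is None  # veto stays disabled by default
```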
4 changes: 2 additions & 2 deletions verl/trainer/config/ppo_megatron_trainer.yaml
@@ -87,8 +87,8 @@ algorithm:
# Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds)
rollout_is_mode: truncate

- # Per-token veto threshold for catastrophic outliers
- rollout_is_veto_threshold: 1e-4
+ # Per-token veto threshold for catastrophic outliers (null to disable)
+ rollout_is_veto_threshold: null

# Whether to apply IS weights to policy loss
# true = apply weights to loss, false = compute metrics only (no weight application)
4 changes: 2 additions & 2 deletions verl/trainer/config/ppo_trainer.yaml
@@ -127,8 +127,8 @@ algorithm:
# Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds)
rollout_is_mode: truncate

- # Per-token veto threshold for catastrophic outliers
- rollout_is_veto_threshold: 1e-4
+ # Per-token veto threshold for catastrophic outliers (null to disable)
+ rollout_is_veto_threshold: null

# Whether to apply IS weights to policy loss
# true = apply weights to loss, false = compute metrics only (no weight application)
9 changes: 6 additions & 3 deletions verl/trainer/ppo/core_algos.py
@@ -788,10 +788,13 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str
loss = verl_F.masked_mean(loss_mat, loss_mask)
elif loss_agg_mode == "seq-mean-token-sum":
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum
- loss = torch.mean(seq_losses) # seq-mean
+ seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float() # exclude fully masked sequences
+ loss = verl_F.masked_mean(seq_losses, seq_mask) # seq-mean
elif loss_agg_mode == "seq-mean-token-mean":
- seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1) # token-mean
- loss = torch.mean(seq_losses) # seq-mean
+ seq_mask = torch.sum(loss_mask, dim=-1) # per-sequence token count
+ seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8) # token-mean
+ seq_mask = (seq_mask > 0).float() # exclude fully masked sequences
+ loss = verl_F.masked_mean(seq_losses, seq_mask) # seq-mean
elif loss_agg_mode == "seq-mean-token-sum-norm":
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
loss = torch.sum(seq_losses) / loss_mask.shape[-1] # The divisor
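The intent of the `agg_loss` change above — skipping fully masked sequences so a vetoed rollout neither divides by zero nor drags the mean toward zero — can be sketched in plain Python (a simplified illustration, not verl's tensor implementation):

```python
def seq_mean_token_mean(loss_mat, loss_mask):
    """Per-sequence token mean, then mean over sequences with at least one live token."""
    seq_losses = []
    for row, mask in zip(loss_mat, loss_mask):
        n_tokens = sum(mask)
        if n_tokens == 0:
            continue  # fully masked (e.g. vetoed) sequence: skip, don't divide by zero
        seq_losses.append(sum(l * m for l, m in zip(row, mask)) / n_tokens)
    return sum(seq_losses) / len(seq_losses) if seq_losses else 0.0
```

Before this fix, a sequence whose mask the veto zeroed out would contribute a 0/0 token mean, turning the batch loss into NaN; excluding such sequences from the sequence-mean keeps the gradient well-defined.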