Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions docs/advance/rollout_is_migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ actor_rollout_ref:

The new implementation:
- ✅ Three aggregation levels: token, sequence, geometric
- ✅ Two bounding modes: truncate, clip
- ✅ Two bounding modes: truncate, mask
- ✅ Dual threshold support (upper/lower)
- ✅ Veto mechanism for catastrophic outliers
- ✅ 30+ comprehensive metrics
Expand Down Expand Up @@ -150,7 +150,7 @@ Aggregation level for IS weights:
### `algorithm.rollout_is_mode` (str)
Bounding mode:
- `"truncate"`: Cap weights at upper threshold only
- `"clip"`: Zero out weights outside [lower, upper]
- `"mask"`: Zero out weights outside [lower, upper]

### `algorithm.rollout_is_veto_threshold` (float)
Per-token veto threshold. If any token ratio < this, entire sequence is rejected.
Expand Down Expand Up @@ -199,7 +199,7 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear
- **`rollout_is_min`**: Minimum IS weight observed
- Shows the most underweighted token/sequence

- **`rollout_is_max`**: Maximum IS weight observed (before clipping)
- **`rollout_is_max`**: Maximum IS weight observed (before truncation/masking)
- Shows the most overweighted token/sequence
- Compare with `rollout_is_threshold` to see truncation impact

Expand Down Expand Up @@ -235,11 +235,11 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear
#### **Threshold Exceedance Metrics**

- **`rollout_is_ratio_fraction_high`**: Fraction of weights exceeding upper threshold
- Shows how often truncation/clipping occurs on high end
- Shows how often truncation/masking occurs on high end
- **Ideal value**: < 0.1 (most weights within bounds)

- **`rollout_is_ratio_fraction_low`**: Fraction of weights below lower threshold
- Shows how often clipping occurs on low end (clip mode only)
- Shows how often masking occurs on low end (mask mode only)
- **Ideal value**: < 0.1

#### **Sequence-Level Metrics** (for sequence/geometric modes)
Expand All @@ -261,14 +261,14 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear

- **`rollout_is_seq_fraction_low`**: Fraction of sequences below lower threshold

#### **Clipping Metrics** (clip mode only)
#### **Masking Metrics** (mask mode only)

- **`rollout_is_clipped_fraction`**: Fraction of tokens clipped (set to zero)
- **`rollout_is_masked_fraction`**: Fraction of tokens masked (set to zero)
- **Ideal value**: < 0.1
- **Warning**: > 0.3 means losing too much data

- **`rollout_is_seq_clipped_fraction`**: Fraction of sequences with at least one clipped token
- Shows sequence-level impact of clipping
- **`rollout_is_seq_masked_fraction`**: Fraction of sequences with at least one masked token
- Shows sequence-level impact of masking

#### **Distribution Mismatch Metrics** (Training vs Rollout Policy)

Expand Down Expand Up @@ -456,14 +456,14 @@ algorithm:
rollout_is_mode: truncate
```

### Example 3: Geometric Mean with Clip
### Example 3: Geometric Mean with Mask
```yaml
algorithm:
rollout_is_threshold: 1.0002
rollout_is: true
rollout_is_threshold_lower: 0.9998
rollout_is_level: geometric
rollout_is_mode: clip
rollout_is_mode: mask
```

### Example 4: Asymmetric Thresholds
Expand All @@ -473,7 +473,7 @@ algorithm:
rollout_is: true
rollout_is_threshold_lower: 0.8
rollout_is_level: token
rollout_is_mode: clip
rollout_is_mode: mask
```

## Troubleshooting
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ Actor/Rollout/Reference Policy
rollout_is_threshold: null # Upper threshold for IS weights (null to disable)
rollout_is_threshold_lower: null # Lower threshold (null = auto 1/upper)
rollout_is_level: token # Aggregation: token/sequence/geometric
rollout_is_mode: truncate # Bounding: truncate/clip
rollout_is_mode: truncate # Bounding: truncate/mask
rollout_is_veto_threshold: 1e-4 # Catastrophic outlier threshold
use_torch_compile: True # False to disable torch compile
kl_loss_coef: 0.001 # for grpo
Expand Down Expand Up @@ -527,7 +527,7 @@ Algorithm
- ``rollout_is_threshold``: Upper threshold for IS weights. Set to ``null`` to disable IS completely.
- ``rollout_is_threshold_lower``: Lower threshold for IS weights. If ``null``, defaults to reciprocal of upper (1/upper).
- ``rollout_is_level``: Aggregation level: ``token`` (biased), ``sequence`` (unbiased), or ``geometric`` (experimental).
- ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``clip`` (zero outside bounds).
- ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``mask`` (zero outside bounds).
- ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is 1e-4.
Note: Rollout IS requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``.

Expand Down
10 changes: 5 additions & 5 deletions examples/rollout_importance_sampling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ algorithm:
rollout_is_mode: truncate
```

### Example 3: Geometric Mean with Clip
### Example 3: Geometric Mean with Mask

```yaml
algorithm:
rollout_is_threshold: 1.0002
rollout_is: true
rollout_is_threshold_lower: 0.9998
rollout_is_level: geometric
rollout_is_mode: clip
rollout_is_mode: mask
rollout_is_veto_threshold: 1e-4
```

Expand All @@ -118,7 +118,7 @@ algorithm:
rollout_is: true
rollout_is_threshold_lower: 0.8
rollout_is_level: token
rollout_is_mode: clip
rollout_is_mode: mask
```

## Monitoring Metrics
Expand Down Expand Up @@ -183,9 +183,9 @@ These metrics help diagnose the distribution mismatch between rollout and traini
2. Verify rollout_log_probs are correctly passed
3. Check for systematic bias in rollout vs training

### Issue: Too Much Data Discarded (Clip Mode)
### Issue: Too Much Data Discarded (Mask Mode)

**Symptoms**: `rollout_is_clipped_fraction` > 0.5
**Symptoms**: `rollout_is_masked_fraction` > 0.5

**Solutions**:
1. Widen thresholds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ rollout_is_threshold_lower=null
# Aggregation level: token | sequence | geometric (experimental)
rollout_is_level=token

# Bounding mode: truncate (cap upper) | clip (zero outside bounds)
# Bounding mode: truncate (cap upper) | mask (zero outside bounds)
rollout_is_mode=truncate

# Catastrophic outlier veto threshold
Expand Down
4 changes: 2 additions & 2 deletions tests/trainer/ppo/test_rollout_is.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,14 @@ def test_basic_rollout_is():
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_level="geometric",
rollout_is_mode="clip",
rollout_is_mode="mask",
rollout_is_threshold=1.5,
rollout_is_threshold_lower=0.5,
rollout_is_veto_threshold=1e-4,
)

print(f" Mean weight: {metrics_geo['mismatch/rollout_is_mean']:.4f}")
print(f" Clipped fraction: {metrics_geo['mismatch/rollout_is_clipped_fraction']:.4f}")
print(f" Masked fraction: {metrics_geo['mismatch/rollout_is_masked_fraction']:.4f}")
print(" ✓ Geometric mean mode passed")

# Test veto mechanism
Expand Down
4 changes: 2 additions & 2 deletions tests/trainer/ppo/test_rollout_is_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def test_all_aggregation_levels(self, sample_data):
assert "mismatch/rollout_is_mean" in metrics

def test_both_bounding_modes(self, sample_data):
"""Test both truncate and clip modes."""
modes = ["truncate", "clip"]
"""Test both truncate and mask modes."""
modes = ["truncate", "mask"]

for mode in modes:
_, metrics = compute_rollout_importance_weights(
Expand Down
2 changes: 1 addition & 1 deletion verl/trainer/config/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class AlgoConfig(BaseConfig):
float value = enabled (compute weights and metrics). This is the main on/off switch.
rollout_is_threshold_lower (Optional[float]): Lower threshold for IS weights. If None, defaults to 1/upper.
rollout_is_level (str): Aggregation level: "token", "sequence", or "geometric".
rollout_is_mode (str): Bounding mode: "truncate" (cap upper only) or "clip" (zero outside bounds).
rollout_is_mode (str): Bounding mode: "truncate" (cap upper only) or "mask" (zero outside bounds).
rollout_is_veto_threshold (float): Per-token veto threshold for catastrophic outliers.
rollout_is (bool): Whether to apply IS weights to policy loss. True = apply weights,
False = compute metrics only (useful for monitoring before enabling correction). Default: False.
Expand Down
2 changes: 1 addition & 1 deletion verl/trainer/config/ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ algorithm:
# Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
rollout_is_level: token

# Bounding mode: "truncate" (cap upper only), "clip" (zero outside bounds)
# Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds)
rollout_is_mode: truncate

# Per-token veto threshold for catastrophic outliers
Expand Down
2 changes: 1 addition & 1 deletion verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ algorithm:
# Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
rollout_is_level: token

# Bounding mode: "truncate" (cap upper only), "clip" (zero outside bounds)
# Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds)
rollout_is_mode: truncate

# Per-token veto threshold for catastrophic outliers
Expand Down
36 changes: 18 additions & 18 deletions verl/trainer/ppo/mismatch_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

Key Features:
1. Three aggregation levels: token, sequence, geometric
2. Two handling modes: truncate (TIS), clip (CIS)
2. Two handling modes: truncate (TIS), mask (MIS)
3. Per-token veto mechanism for catastrophic outliers
4. Memory-efficient computation to prevent CUDA OOM
5. Comprehensive metrics tracking
Expand Down Expand Up @@ -77,9 +77,9 @@ def compute_rollout_importance_weights(
- "geometric": Geometric mean of ratios (experimental)
rollout_is_mode: How to handle weights exceeding threshold:
- "truncate": Cap weights at upper_threshold only (TIS)
- "clip": Zero out weights outside [lower_threshold, upper_threshold] (CIS)
- "mask": Zero out weights outside [lower_threshold, upper_threshold] (MIS)
rollout_is_threshold: Upper threshold for IS weights
rollout_is_threshold_lower: Lower threshold for IS weights (clip mode only; if None, defaults to 1/upper)
rollout_is_threshold_lower: Lower threshold for IS weights (mask mode only; if None, defaults to 1/upper)
rollout_is_veto_threshold: Per-token veto threshold. If any token ratio < this, zero entire sequence.
If None, veto mechanism is disabled.

Expand Down Expand Up @@ -179,32 +179,32 @@ def compute_rollout_importance_weights(
SAFETY_BOUND=SAFETY_BOUND,
)

# Step 3: Apply truncation or clipping based on mode
# Step 3: Apply truncation or masking based on mode
if rollout_is_mode == "truncate":
# Truncated IS (TIS): only cap upper bound to prevent overweighting
rollout_is_weights = rollout_is_weights.clamp(max=upper_threshold)

elif rollout_is_mode == "clip":
# Clipped IS (CIS): zero out weights outside [lower_threshold, upper_threshold]
clip_mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold)
clip_mask = clip_mask.float()
elif rollout_is_mode == "mask":
# Masked IS (MIS): zero out weights outside [lower_threshold, upper_threshold]
mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold)
mask = mask.float()

# Track CIS-specific metrics
metrics["rollout_is_clipped_fraction"] = verl_F.masked_mean(1 - clip_mask, response_mask)
# Track MIS-specific metrics
metrics["rollout_is_masked_fraction"] = verl_F.masked_mean(1 - mask, response_mask)

# Sequence-level clipping fraction
# Sequence-level masking fraction
if rollout_is_level in ["sequence", "geometric"]:
# All tokens in a sequence have the same weight, so reuse clip_mask
metrics["rollout_is_seq_clipped_fraction"] = (1 - clip_mask[:, 0]).mean()
# All tokens in a sequence have the same weight, so reuse mask
metrics["rollout_is_seq_masked_fraction"] = (1 - mask[:, 0]).mean()
else:
# Check if any token in each sequence is clipped
seq_has_clipped = verl_F.masked_sum(1 - clip_mask, response_mask, axis=-1) > 0
metrics["rollout_is_seq_clipped_fraction"] = seq_has_clipped.float().mean()
# Check if any token in each sequence is masked
seq_has_masked = verl_F.masked_sum(1 - mask, response_mask, axis=-1) > 0
metrics["rollout_is_seq_masked_fraction"] = seq_has_masked.float().mean()

rollout_is_weights = rollout_is_weights * clip_mask
rollout_is_weights = rollout_is_weights * mask

else:
raise ValueError(f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'clip'.")
raise ValueError(f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'mask'.")

# Apply veto mask AFTER all thresholding
# This zeros out entire sequences that have any catastrophic token
Expand Down
Loading