refactor(rl): clean IS with prox_forward #169

Merged
Hecate0821 merged 1 commit into main from chengxili/clean-is-corrections
Mar 5, 2026
@Hecate0821 Hecate0821 commented Mar 5, 2026

Summary

  • Replace legacy TIS with AReaL-style decoupled importance sampling (IS): always pre-compute proximal logprobs via policy.forward("cross_entropy"), then apply a PPO ratio and a behavioral weight
  • Unify the PPO-clipped pattern across all 4 losses (GRPO switches from REINFORCE to PPO)
  • Remove all TIS dead code (ISConfig, make_tis_weights_fn, tis_enabled, tis_weights_fn)

Architecture

flowchart TD
    A["ref_forward → ref_logprobs"] --> B["prox_forward\npolicy.forward → prox_logprobs"]
    B --> C["split data into N minibatches"]
    C --> D["for each minibatch:\nforward_backward_custom\nratio = exp(pi_theta - prox)\nbehave = clamp(exp(prox - inf), cap)"]
    D --> E["optim_step (once)"]
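
The per-token computation in the minibatch node above can be sketched in plain Python. This is a minimal illustration, not the code merged in this PR: the function name, `clip_eps`, and `behave_cap` (with their default values) are hypothetical, `inf` is assumed to mean the logprob reported by the inference engine, and `clamp(..., cap)` is assumed to be an upper cap only.

```python
import math


def decoupled_ppo_token_loss(logp_theta, logp_prox, logp_inf, advantage,
                             clip_eps=0.2, behave_cap=2.0):
    """Per-token decoupled PPO loss (a value to minimize).

    logp_theta: logprob under the policy being updated
    logp_prox:  logprob from the prox_forward pass, treated as a constant
    logp_inf:   logprob reported at sampling time (assumed meaning of "inf")
    clip_eps, behave_cap: hypothetical hyperparameters for this sketch
    """
    # PPO ratio measured against the proximal policy.
    ratio = math.exp(logp_theta - logp_prox)
    clipped = max(1.0 - clip_eps, min(ratio, 1.0 + clip_eps))
    # Pessimistic (clipped) PPO surrogate.
    surrogate = min(ratio * advantage, clipped * advantage)
    # Behavioral weight, capped from above (assumption: no lower bound).
    behave = min(math.exp(logp_prox - logp_inf), behave_cap)
    return -behave * surrogate
```

When all three logprobs agree, both the ratio and the behavioral weight equal 1 and the loss reduces to the negated advantage.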

Design references

  • AReaL: decoupled PPO with proximal policy recompute
  • VERL: compute_log_prob forward pass for old_log_probs before multi-epoch PPO

Test plan

  • 68 unit tests pass (all TIS tests replaced with decoupled IS tests)
  • No remaining TIS references in codebase
  • E2E smoke test with real deployment (deferred)

Made with Cursor

@Hecate0821 Hecate0821 force-pushed the chengxili/clean-is-corrections branch 8 times, most recently from 06d896f to 10d16b2 on March 5, 2026 02:26
Replace the legacy TIS system with two clean, always-on IS corrections:

1. PPO IS ratio: exp(pi_theta - pi_prox) with PPO-style clipping.
   Proximal logprobs pre-computed via policy.forward("cross_entropy").
2. TIS (train-inference IS) weight: exp(pi_prox - pi_old) with capping.
   Corrects for FP8/quantization gap between training and inference.
   Supports token-level and sequence-level aggregation (geometric mean).
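
As a rough illustration of the two aggregation modes for the weight in item 2, the sketch below computes capped weights for one sequence. It assumes "geometric mean" means exp of the mean per-token log-ratio and that capping is an upper bound only; the function name, `cap`, and `level` are hypothetical and not taken from this change.

```python
import math


def behave_weights(logp_prox, logp_old, cap=2.0, level="token"):
    """Return one capped weight per token for a single sequence.

    logp_prox / logp_old: per-token logprobs (lists of equal length).
    cap and level are hypothetical parameters for this sketch.
    """
    log_ratios = [p - q for p, q in zip(logp_prox, logp_old)]
    if not log_ratios:
        return []
    if level == "token":
        # Each token gets its own capped ratio.
        return [min(math.exp(lr), cap) for lr in log_ratios]
    if level == "sequence":
        # Geometric mean of the token ratios = exp(mean of the log-ratios);
        # the same capped weight is then shared by every token.
        mean_log = sum(log_ratios) / len(log_ratios)
        weight = min(math.exp(mean_log), cap)
        return [weight] * len(log_ratios)
    raise ValueError(f"unknown level: {level!r}")
```

Sequence-level aggregation smooths out single-token outliers, while token-level weighting corrects each position independently.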

Add ppo_n_minibatches config: N forward_backward calls per optim_step.
With N=1 (default): ratio=1, only TIS weight matters.
With N>1: proper multi-minibatch PPO like AReaL/VERL.
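
The control flow implied by ppo_n_minibatches can be sketched as follows. This is only an outline under stated assumptions: `prox_forward`, `forward_backward`, and `optim_step` are hypothetical callables standing in for the real trainer calls, `forward_backward` is assumed to accumulate gradients, and the contiguous split is one of several reasonable choices.

```python
def split_minibatches(items, n):
    """Split items into n contiguous chunks whose sizes differ by at most one."""
    if n < 1:
        raise ValueError("n must be >= 1")
    size, rem = divmod(len(items), n)
    chunks, start = [], 0
    for i in range(n):
        end = start + size + (1 if i < rem else 0)
        chunks.append(items[start:end])
        start = end
    return chunks


def train_step(samples, n_minibatches, prox_forward, forward_backward, optim_step):
    """One optimizer step; all three callables are hypothetical stand-ins."""
    # Proximal logprobs are computed once, before any gradient work.
    prox_logprobs = prox_forward(samples)
    paired = list(zip(samples, prox_logprobs))
    for minibatch in split_minibatches(paired, n_minibatches):
        if minibatch:  # skip empty chunks when n exceeds the sample count
            forward_backward(minibatch)  # assumed to accumulate gradients
    optim_step()  # applied once, after all minibatches
```

With `n_minibatches=1` the whole batch goes through a single forward_backward call before the optimizer step.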

Removed: old ISConfig (clip_high/clip_low), make_tis_weights_fn,
TISWeightsFn, tis_enabled, tis_weights_fn parameter from all losses.

Informed by AReaL (arXiv:2505.24298) decoupled PPO objective and
VERL rollout IS framework (github.com/verl-project/verl/pull/3694).

Made-with: Cursor
@Hecate0821 Hecate0821 force-pushed the chengxili/clean-is-corrections branch from 10d16b2 to 2678c5b on March 5, 2026 02:32
@Hecate0821 Hecate0821 changed the title from "refactor(rl): clean IS with multi-minibatch PPO and prox_forward" to "refactor(rl): clean IS with prox_forward" on Mar 5, 2026
@Hecate0821 Hecate0821 merged commit 227d25a into main Mar 5, 2026
@Hecate0821 Hecate0821 deleted the chengxili/clean-is-corrections branch March 5, 2026 21:14