Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 60 additions & 14 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,14 +241,40 @@ class GRPOConfig(TrainingArguments):
</Deprecated>

vllm_importance_sampling_correction (`bool`, *optional*, defaults to `True`):
Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and recomputed
logprobs. [Your Efficient RL Framework Secretly Brings You Off-Policy RL
Training](https://fengyao.notion.site/off-policy-rl) highlights that using a separate generation framework
(such as vLLM) can introduce off-policy effects due to subtle implementation differences between generation
and training backends. TIS is proposed as a remedy for this issue.
Whether to apply Importance Sampling (IS) to correct for the mismatch between vLLM
completion logprobs and recomputed training logprobs. If set to `False`, no IS is applied
regardless of `vllm_importance_sampling_mode`. When `True`, the selected mode determines
how the IS ratios are computed and constrained.

vllm_importance_sampling_mode (`str`, *optional*, defaults to `"sequence_mask"`):
Specifies how Importance Sampling is performed when `vllm_importance_sampling_correction=True`.

The mode is defined along two orthogonal dimensions:
* Constraint: how to handle importance ratios above `vllm_importance_sampling_cap` (C):
- truncation: clip ratios from above, ρ ← min(ρ, C), as in
[Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl)
- masking: set ratios above C to zero, so those contributions do not affect the gradient, as in
[When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda)

* Granularity: the level at which ratios are computed:
- token: per-token ratios ρ_t
- sequence: a single ratio ρ_seq per sequence, applied to all tokens

Supported options are:
- `"token_truncate"`:
Token-level truncated IS. Per-token ratios are clipped from above at C.
- `"token_mask"`:
Token-level masked IS. Per-token ratios above C are set to zero.
- `"sequence_truncate"`:
Sequence-level truncated IS. A single sequence ratio is clipped from above at C and
applied to all tokens in the sequence.
- `"sequence_mask"`:
Sequence-level masked IS (default). Sequences with ratios above C are masked out.

Comment thread
LeonEricsson marked this conversation as resolved.
Outdated
vllm_importance_sampling_cap (`float`, *optional*, defaults to `2.0`):
Comment thread
LeonEricsson marked this conversation as resolved.
Outdated
Truncation parameter C for Truncated Importance Sampling (TIS). This sets an upper bound on the importance
sampling ratio, improving training stability.
Importance sampling cap C used by `vllm_importance_sampling_mode`. For `*_truncate` modes,
importance ratios are clipped from above at C. For `*_mask` modes, ratios larger than C
are set to zero.

> Parameters that control the logging

Expand Down Expand Up @@ -665,18 +691,38 @@ class GRPOConfig(TrainingArguments):
vllm_importance_sampling_correction: bool = field(
default=True,
metadata={
"help": "Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and "
"recomputed logprobs. Your Efficient RL Framework Secretly Brings You Off-Policy RL "
"Training highlights that using a separate generation framework (such as vLLM) can introduce off-policy "
"effects due to subtle implementation differences between generation and training backends. TIS is "
"proposed as a remedy for this issue."
"help": (
Comment thread
LeonEricsson marked this conversation as resolved.
Outdated
"Whether to apply Importance Sampling (IS) to correct for the mismatch between vLLM "
"completion logprobs and recomputed training logprobs. If set to `False`, no IS is applied "
"regardless of `vllm_importance_sampling_mode`. When `True`, the selected mode determines how "
"IS ratios are computed and constrained."
)
},
)

vllm_importance_sampling_mode: str = field(
default="sequence_mask",
metadata={
"help": (
"Specifies how Importance Sampling (IS) is performed when "
"vllm_importance_sampling_correction=True. Modes are defined along two orthogonal "
"dimensions: (1) constraint, which determines how to handle ratios above "
"vllm_importance_sampling_cap (C)—either truncation (clip from above, ρ ← min(ρ, C)) or "
"masking (set ratios above C to zero); and (2) granularity, which determines whether "
"ratios are computed per token or as a single sequence-level ratio applied to all tokens. "
"Supported options are: 'token_truncate', 'token_mask', 'sequence_truncate', and "
"'sequence_mask'."
)
},
)

vllm_importance_sampling_cap: float = field(
default=2.0,
metadata={
"help": "Truncation parameter C for Truncated Importance Sampling (TIS). This sets an upper bound on the "
"importance sampling ratio, improving training stability."
"help": (
"Importance sampling cap C used by `vllm_importance_sampling_mode`. For '*_truncate' modes, "
"ratios are clipped from above at C. For '*_mask' modes, ratios larger than C are set to zero."
)
},
)

Expand Down
34 changes: 27 additions & 7 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ def __init__(
self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode
self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode
self.vllm_importance_sampling_correction = args.vllm_importance_sampling_correction
self.vllm_importance_sampling_mode = args.vllm_importance_sampling_mode
self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap
self.use_liger_kernel = args.use_liger_kernel
self.loss_type = args.loss_type
Expand Down Expand Up @@ -1577,10 +1578,29 @@ def _generate_and_score_completions(

# Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch
if self.use_vllm and self.vllm_importance_sampling_correction:
importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps)
importance_sampling_ratio = torch.clamp(
importance_sampling_ratio, max=self.vllm_importance_sampling_cap
)
token_logps_diff = old_per_token_logps - sampling_per_token_logps
token_logps_diff *= completion_mask

if self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"]:
sequence_logps_diff = token_logps_diff.sum(dim=-1, keepdim=True)

vllm_importance_sampling_ratio = torch.exp(sequence_logps_diff)

# From here, vllm_importance_sampling_ratio's shape depends on
# vllm_importance_sampling_mode: "token_*" level: (B, T); "sequence_*" level: (B, 1)

if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]:
vllm_importance_sampling_ratio = torch.clamp(
vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap
)
elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]:
vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill(
vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, 0.0
Comment thread
LeonEricsson marked this conversation as resolved.
Outdated
)
else:
raise ValueError(
f"Unknown vllm importance sampling level: {self.vllm_importance_sampling_mode}. Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'."
Comment thread
LeonEricsson marked this conversation as resolved.
Outdated
)

# Compute the per-token log probabilities for the reference model
if self.beta != 0.0:
Expand Down Expand Up @@ -1702,7 +1722,7 @@ def _generate_and_score_completions(
self.accelerator.gather(max_delta).max().item()
)

flat_is_ratio = importance_sampling_ratio[completion_mask.bool()]
flat_is_ratio = token_logps_diff[completion_mask.bool()]
Comment thread
LeonEricsson marked this conversation as resolved.
Outdated
min_importance_sampling_ratio = (
torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
)
Expand Down Expand Up @@ -1733,7 +1753,7 @@ def _generate_and_score_completions(
if old_per_token_logps is not None:
output["old_per_token_logps"] = old_per_token_logps
if self.use_vllm and self.vllm_importance_sampling_correction:
output["importance_sampling_ratio"] = importance_sampling_ratio
output["vllm_importance_sampling_ratio"] = vllm_importance_sampling_ratio
Comment thread
LeonEricsson marked this conversation as resolved.
Outdated
if ref_per_token_logps is not None:
output["ref_per_token_logps"] = ref_per_token_logps
if "pixel_values" in forward_kwargs:
Expand Down Expand Up @@ -1887,7 +1907,7 @@ def _compute_loss(self, model, inputs):
per_token_loss = per_token_loss * entropy_mask

if self.use_vllm and self.vllm_importance_sampling_correction:
per_token_loss = per_token_loss * inputs["importance_sampling_ratio"]
per_token_loss = per_token_loss * inputs["vllm_importance_sampling_ratio"]

if self.beta != 0.0:
per_token_loss = per_token_loss + self.beta * per_token_kl
Expand Down
Loading