From d5731a82f9363125690695893b75436fa0839222 Mon Sep 17 00:00:00 2001 From: Tugrul Konuk Date: Mon, 28 Jul 2025 17:08:36 -0500 Subject: [PATCH 01/19] Added gspo Signed-off-by: Tugrul Konuk --- .../gspo_math_qwen30ba3b_megatron.yaml | 89 +++++++++++++++++++ nemo_rl/algorithms/loss_functions.py | 54 ++++++++--- 2 files changed, 131 insertions(+), 12 deletions(-) create mode 100644 examples/configs/gspo_math_qwen30ba3b_megatron.yaml diff --git a/examples/configs/gspo_math_qwen30ba3b_megatron.yaml b/examples/configs/gspo_math_qwen30ba3b_megatron.yaml new file mode 100644 index 0000000000..ecb824a298 --- /dev/null +++ b/examples/configs/gspo_math_qwen30ba3b_megatron.yaml @@ -0,0 +1,89 @@ +# GRPO Algorithm Configuration +defaults: "grpo_math_1B_megatron.yaml" + +grpo: + num_prompts_per_step: 64 + num_generations_per_prompt: 32 + +loss_fn: + reference_policy_kl_penalty: 0.01 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 + # (default off) loss formulation improvements (docs/guides/grpo.md#loss) + use_on_policy_kl_approximation: false + use_importance_sampling_correction: false + token_level_loss: true + ratio_clip_c: null + sequence_level_importance_sampling: true + +policy: + model_name: "Qwen/Qwen3-30B-A3B" + tokenizer: + name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default + train_global_batch_size: 512 + train_micro_batch_size: 1 + generation_batch_size: 32 # Only used when generating using HF backend + logprob_batch_size: 4 + max_total_sequence_length: 4096 + precision: "bfloat16" + + dtensor_cfg: + enabled: false + + sequence_packing: + enabled: False + + optimizer: null # remove default FSDP optimizer + + scheduler: null # remove default FSDP scheduler + + megatron_cfg: + enabled: true + empty_unused_memory_level: 1 + converter_type: "LlamaForCausalLM" + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + context_parallel_size: 1 + expert_tensor_parallel_size: 1 + expert_model_parallel_size: 8 + sequence_parallel: True + pipeline_dtype: ${policy.precision} + + optimizer: + optimizer: "adam" + lr: 3.0e-7 + min_lr: 3.0e-8 + weight_decay: 0.01 + bf16: true + fp16: false + params_dtype: "float32" + + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: "constant" + lr_decay_style: "constant" + lr_decay_iters: null + lr_warmup_iters: 13 + lr_warmup_init: 3.0e-8 + + env_vars: + PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False" + + generation: + backend: "vllm" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + vllm_cfg: + tensor_parallel_size: 4 + gpu_memory_utilization: 0.7 + enforce_eager: false + max_model_len: ${policy.max_total_sequence_length} + +cluster: + gpus_per_node: 8 + num_nodes: 8 diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 923e836554..0b2068b900 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -37,6 +37,10 @@ class ClippedPGLossConfig(TypedDict): use_on_policy_kl_approximation: bool use_importance_sampling_correction: bool token_level_loss: bool + # If True, apply the off-policy importance-sampling correction at the + # sequence level (one weight per generated sample). If False (default) + # correction is applied at the token level as in the original GRPO paper. + sequence_level_importance_sampling: bool class ClippedPGLossDataDict(TypedDict): @@ -101,6 +105,11 @@ def __init__(self, cfg: ClippedPGLossConfig): self.use_importance_sampling_correction = cfg[ "use_importance_sampling_correction" ] + # Whether to compute importance weights per-sequence instead of per-token. + self.sequence_level_importance_sampling = cfg.get( + "sequence_level_importance_sampling", + False, + ) self.loss_type = ( LossType.TOKEN_LEVEL if cfg["token_level_loss"] else LossType.SEQUENCE_LEVEL @@ -229,13 +238,27 @@ def __call__( advantages < 0, torch.min(clip_loss, loss3), clip_loss ) - # See: docs/guides/grpo.md#importance-sampling-correction - actor_importance_weights = torch.exp(prev_logprobs - generation_logprobs) - actor_importance_weights = torch.nan_to_num( - actor_importance_weights, nan=0.0, posinf=0.0, neginf=0.0 - ) + # ------------------------------------------------------------- + # Off-policy (actor) importance-sampling correction + # ------------------------------------------------------------- + if self.sequence_level_importance_sampling: + # importance weight w_i = exp(Σ_t (log π_actor − log π_behaviour)) + seq_lp_diff = ((prev_logprobs - generation_logprobs) * mask).sum(dim=-1) + actor_importance_weights = torch.exp(seq_lp_diff).detach() + actor_importance_weights = torch.nan_to_num( + actor_importance_weights, nan=0.0, posinf=0.0, neginf=0.0 + ) + # Broadcast to token dimension so we can reuse existing reduction + actor_importance_weights_expanded = actor_importance_weights.unsqueeze(-1) + else: + # Token-level correction (original GRPO implementation) + actor_importance_weights_expanded = torch.exp(prev_logprobs - generation_logprobs) + actor_importance_weights_expanded = torch.nan_to_num( + actor_importance_weights_expanded, nan=0.0, posinf=0.0, neginf=0.0 + ) + if self.use_importance_sampling_correction: - importance_weights_to_use = actor_importance_weights + importance_weights_to_use = actor_importance_weights_expanded else: importance_weights_to_use = torch.ones_like(prev_logprobs) @@ -256,12 +279,19 @@ def __call__( global_normalization_factor=global_valid_seqs, ) - # See: docs/guides/grpo.md#sampling-importance-ratio - sample_importance_ratio = masked_mean( - actor_importance_weights, - mask, - global_normalization_factor=global_valid_toks, - ) + # Metric: sampling importance ratio (mean over samples) + if self.sequence_level_importance_sampling: + sample_importance_ratio = masked_mean( + actor_importance_weights, + sample_mask, + global_normalization_factor=global_valid_seqs, + ) + else: + sample_importance_ratio = masked_mean( + actor_importance_weights_expanded, + mask, + global_normalization_factor=global_valid_toks, + ) # Approximating entropy as E_{s ~ \pi_{gen}(s)}[-(\pi_{curr}/\pi_{gen})log(\pi_{curr}(s))] # See more details and other metrics in docs/guides/grpo.md#metrics From e6ba491bbfb2e24ebdf87beaaa100831b8b03f64 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Thu, 31 Jul 2025 20:59:01 -0700 Subject: [PATCH 02/19] Port GSPO ratio clip from dev branch. Signed-off-by: Peter Jin --- nemo_rl/algorithms/loss_functions.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 0b2068b900..5ed36457db 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -64,6 +64,7 @@ class ClippedPGLossFn(LossFunction): - PPO (Clipped) - https://arxiv.org/abs/1707.06347 - GRPO - https://arxiv.org/abs/2402.03300 - REINFORCE/RLOO (set disable_ppo_ratio = True and ignores ratio_clip_min/ratio_clip_max) - https://arxiv.org/abs/2402.14740 + - GSPO (set sequence_level_importance_sampling = True and token_level_loss = False) - https://arxiv.org/abs/2507.18071 Formula: L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref) @@ -110,10 +111,14 @@ def __init__(self, cfg: ClippedPGLossConfig): "sequence_level_importance_sampling", False, ) - self.loss_type = ( LossType.TOKEN_LEVEL if cfg["token_level_loss"] else LossType.SEQUENCE_LEVEL ) + # FIXME(pjin): sequence-level importance sampling must be exclusive to token-level loss. + if False and self.sequence_level_importance_sampling: + assert self.loss_type == LossType.SEQUENCE_LEVEL, ( + "sequence-level importance sampling (e.g. GSPO) is mutually exclusive with token-level loss" + ) def __call__( self, @@ -214,7 +219,19 @@ def __call__( # Calculate clipped loss function if ppo ratio is enabled. if not self.disable_ppo_ratio: - ratios = (curr_logprobs - prev_logprobs).exp() + log_ratios = curr_logprobs - prev_logprobs + if self.sequence_level_importance_sampling: + # NB(pjin): this masked mean is only "sequence level" when micro batch size = 1. + assert micro_batch_size == 1 + seq_log_ratio = masked_mean( + log_ratios, + mask, + global_normalization_factor=global_valid_toks, + ) + seq_ratio = seq_log_ratio.exp() + ratios = seq_ratio.repeat(advantages.shape[0], advantages.shape[1]) + else: + ratios = log_ratios.exp() ratios_clamped = ratios.clamp( 1.0 - self.ratio_clip_min, 1.0 + self.ratio_clip_max ) @@ -256,9 +273,9 @@ def __call__( actor_importance_weights_expanded = torch.nan_to_num( actor_importance_weights_expanded, nan=0.0, posinf=0.0, neginf=0.0 ) - + actor_importance_weights = actor_importance_weights_expanded if self.use_importance_sampling_correction: - importance_weights_to_use = actor_importance_weights_expanded + importance_weights_to_use = actor_importance_weights else: importance_weights_to_use = torch.ones_like(prev_logprobs) From ce337d829b9ced2528f696e30cf3f553b653ef44 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Thu, 31 Jul 2025 21:11:24 -0700 Subject: [PATCH 03/19] Micro batch size. Signed-off-by: Peter Jin --- nemo_rl/algorithms/loss_functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 5ed36457db..b9b0e7f585 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -222,6 +222,7 @@ def __call__( log_ratios = curr_logprobs - prev_logprobs if self.sequence_level_importance_sampling: # NB(pjin): this masked mean is only "sequence level" when micro batch size = 1. + micro_batch_size = int(advantages.shape[0]) assert micro_batch_size == 1 seq_log_ratio = masked_mean( log_ratios, @@ -229,7 +230,7 @@ def __call__( global_normalization_factor=global_valid_toks, ) seq_ratio = seq_log_ratio.exp() - ratios = seq_ratio.repeat(advantages.shape[0], advantages.shape[1]) + ratios = seq_ratio.repeat(micro_batch_size, advantages.shape[1]) else: ratios = log_ratios.exp() ratios_clamped = ratios.clamp( From 76c842d6e6c99deef3d1ee5c54e5c9af06fe2fba Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 6 Aug 2025 14:16:28 -0700 Subject: [PATCH 04/19] Example GSPO config. Signed-off-by: Peter Jin --- .../gspo_math_qwen30ba3b_megatron.yaml | 83 +------------------ 1 file changed, 4 insertions(+), 79 deletions(-) diff --git a/examples/configs/gspo_math_qwen30ba3b_megatron.yaml b/examples/configs/gspo_math_qwen30ba3b_megatron.yaml index ecb824a298..8c90fdd301 100644 --- a/examples/configs/gspo_math_qwen30ba3b_megatron.yaml +++ b/examples/configs/gspo_math_qwen30ba3b_megatron.yaml @@ -1,89 +1,14 @@ # GRPO Algorithm Configuration -defaults: "grpo_math_1B_megatron.yaml" - -grpo: - num_prompts_per_step: 64 - num_generations_per_prompt: 32 +defaults: "grpo_math_qwen30ba3b_megatron.yaml" loss_fn: reference_policy_kl_penalty: 0.01 ratio_clip_min: 0.2 ratio_clip_max: 0.2 + ratio_clip_c: null # (default off) loss formulation improvements (docs/guides/grpo.md#loss) use_on_policy_kl_approximation: false use_importance_sampling_correction: false - token_level_loss: true - ratio_clip_c: null sequence_level_importance_sampling: true - -policy: - model_name: "Qwen/Qwen3-30B-A3B" - tokenizer: - name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default - train_global_batch_size: 512 - train_micro_batch_size: 1 - generation_batch_size: 32 # Only used when generating using HF backend - logprob_batch_size: 4 - max_total_sequence_length: 4096 - precision: "bfloat16" - - dtensor_cfg: - enabled: false - - sequence_packing: - enabled: False - - optimizer: null # remove default FSDP optimizer - - scheduler: null # remove default FSDP scheduler - - megatron_cfg: - enabled: true - empty_unused_memory_level: 1 - converter_type: "LlamaForCausalLM" - tensor_model_parallel_size: 2 - pipeline_model_parallel_size: 1 - context_parallel_size: 1 - expert_tensor_parallel_size: 1 - expert_model_parallel_size: 8 - sequence_parallel: True - pipeline_dtype: ${policy.precision} - - optimizer: - optimizer: "adam" - lr: 3.0e-7 - min_lr: 3.0e-8 - weight_decay: 0.01 - bf16: true - fp16: false - params_dtype: "float32" - - scheduler: - start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} - end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} - weight_decay_incr_style: "constant" - lr_decay_style: "constant" - lr_decay_iters: null - lr_warmup_iters: 13 - lr_warmup_init: 3.0e-8 - - env_vars: - PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False" - - generation: - backend: "vllm" - max_new_tokens: ${policy.max_total_sequence_length} - temperature: 1.0 - top_p: 1.0 - top_k: null - stop_token_ids: null - stop_strings: null - vllm_cfg: - tensor_parallel_size: 4 - gpu_memory_utilization: 0.7 - enforce_eager: false - max_model_len: ${policy.max_total_sequence_length} - -cluster: - gpus_per_node: 8 - num_nodes: 8 + # FIXME(pjin): should require token_level_loss = False. + token_level_loss: true From 8b65703dbd4cd8dc9b5cc8f8281ace6d2be399a3 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Thu, 7 Aug 2025 13:13:12 -0700 Subject: [PATCH 05/19] Cleaner impl. Signed-off-by: Peter Jin --- .../gspo_math_qwen30ba3b_megatron.yaml | 9 ++++----- nemo_rl/algorithms/loss_functions.py | 20 +++++++++---------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/examples/configs/gspo_math_qwen30ba3b_megatron.yaml b/examples/configs/gspo_math_qwen30ba3b_megatron.yaml index 8c90fdd301..86c062e075 100644 --- a/examples/configs/gspo_math_qwen30ba3b_megatron.yaml +++ b/examples/configs/gspo_math_qwen30ba3b_megatron.yaml @@ -7,8 +7,7 @@ loss_fn: ratio_clip_max: 0.2 ratio_clip_c: null # (default off) loss formulation improvements (docs/guides/grpo.md#loss) - use_on_policy_kl_approximation: false - use_importance_sampling_correction: false - sequence_level_importance_sampling: true - # FIXME(pjin): should require token_level_loss = False. - token_level_loss: true + use_on_policy_kl_approximation: False + use_importance_sampling_correction: True + sequence_level_importance_sampling: True + token_level_loss: False diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index a8fda1bf2c..b31df3eb6d 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -114,8 +114,7 @@ def __init__(self, cfg: ClippedPGLossConfig): self.loss_type = ( LossType.TOKEN_LEVEL if cfg["token_level_loss"] else LossType.SEQUENCE_LEVEL ) - # FIXME(pjin): sequence-level importance sampling must be exclusive to token-level loss. - if False and self.sequence_level_importance_sampling: + if self.sequence_level_importance_sampling: assert self.loss_type == LossType.SEQUENCE_LEVEL, ( "sequence-level importance sampling (e.g. GSPO) is mutually exclusive with token-level loss" ) @@ -221,16 +220,13 @@ def __call__( if not self.disable_ppo_ratio: log_ratios = curr_logprobs - prev_logprobs if self.sequence_level_importance_sampling: - # NB(pjin): this masked mean is only "sequence level" when micro batch size = 1. - micro_batch_size = int(advantages.shape[0]) - assert micro_batch_size == 1 - seq_log_ratio = masked_mean( + seq_log_ratio_mean = masked_mean( log_ratios, - mask, - global_normalization_factor=global_valid_toks, + token_mask, + dim=-1, ) - seq_ratio = seq_log_ratio.exp() - ratios = seq_ratio.repeat(micro_batch_size, advantages.shape[1]) + seq_ratio = seq_log_ratio_mean.exp() + ratios = seq_ratio.repeat(1, advantages.shape[1]) else: ratios = log_ratios.exp() ratios_clamped = ratios.clamp( @@ -270,7 +266,9 @@ def __call__( actor_importance_weights_expanded = actor_importance_weights.unsqueeze(-1) else: # Token-level correction (original GRPO implementation) - actor_importance_weights_expanded = torch.exp(prev_logprobs - generation_logprobs) + actor_importance_weights_expanded = torch.exp( + prev_logprobs - generation_logprobs + ) actor_importance_weights_expanded = torch.nan_to_num( actor_importance_weights_expanded, nan=0.0, posinf=0.0, neginf=0.0 ) From 7938879a9cd8d58ee8db05dfe0354995cb5038dc Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Thu, 7 Aug 2025 13:15:57 -0700 Subject: [PATCH 06/19] GSPO version of deepscaler example (WIP). Signed-off-by: Peter Jin --- examples/configs/gspo-deepscaler-1.5b-8K.yaml | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 examples/configs/gspo-deepscaler-1.5b-8K.yaml diff --git a/examples/configs/gspo-deepscaler-1.5b-8K.yaml b/examples/configs/gspo-deepscaler-1.5b-8K.yaml new file mode 100644 index 0000000000..c169f229da --- /dev/null +++ b/examples/configs/gspo-deepscaler-1.5b-8K.yaml @@ -0,0 +1,7 @@ +# GRPO Algorithm Configuration +defaults: "grpo-deepscaler-1.5b-8K.yaml" + +loss_fn: + use_importance_sampling_correction: True + sequence_level_importance_sampling: True + token_level_loss: False From 9782097ade366f9a4dedba323c1bd0bd02a88eb9 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Thu, 7 Aug 2025 14:13:12 -0700 Subject: [PATCH 07/19] Alternative 2-node TP=2 deepscaler GSPO config (WIP). Signed-off-by: Peter Jin --- ...-1.5b-8K.yaml => gspo-deepscaler-1.5b-8K-tp2.yaml} | 11 +++++++++++ 1 file changed, 11 insertions(+) rename examples/configs/{gspo-deepscaler-1.5b-8K.yaml => gspo-deepscaler-1.5b-8K-tp2.yaml} (50%) diff --git a/examples/configs/gspo-deepscaler-1.5b-8K.yaml b/examples/configs/gspo-deepscaler-1.5b-8K-tp2.yaml similarity index 50% rename from examples/configs/gspo-deepscaler-1.5b-8K.yaml rename to examples/configs/gspo-deepscaler-1.5b-8K-tp2.yaml index c169f229da..8ce2ceee78 100644 --- a/examples/configs/gspo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/gspo-deepscaler-1.5b-8K-tp2.yaml @@ -5,3 +5,14 @@ loss_fn: use_importance_sampling_correction: True sequence_level_importance_sampling: True token_level_loss: False + +policy: + activation_checkpointing_enabled: True + + dtensor_cfg: + tensor_parallel_size: 2 + activation_checkpointing: True + + generation: + vllm_cfg: + tensor_parallel_size: 2 From 95867ed2ec6353615fc6d2df8f25da664454ab89 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Thu, 7 Aug 2025 14:17:45 -0700 Subject: [PATCH 08/19] Remove Qwen3 GSPO config (will take too long to test). Signed-off-by: Peter Jin --- examples/configs/gspo_math_qwen30ba3b_megatron.yaml | 13 ------------- 1 file changed, 13 deletions(-) delete mode 100644 examples/configs/gspo_math_qwen30ba3b_megatron.yaml diff --git a/examples/configs/gspo_math_qwen30ba3b_megatron.yaml b/examples/configs/gspo_math_qwen30ba3b_megatron.yaml deleted file mode 100644 index 86c062e075..0000000000 --- a/examples/configs/gspo_math_qwen30ba3b_megatron.yaml +++ /dev/null @@ -1,13 +0,0 @@ -# GRPO Algorithm Configuration -defaults: "grpo_math_qwen30ba3b_megatron.yaml" - -loss_fn: - reference_policy_kl_penalty: 0.01 - ratio_clip_min: 0.2 - ratio_clip_max: 0.2 - ratio_clip_c: null - # (default off) loss formulation improvements (docs/guides/grpo.md#loss) - use_on_policy_kl_approximation: False - use_importance_sampling_correction: True - sequence_level_importance_sampling: True - token_level_loss: False From c536aff641e3a101b6a09d9db60a844904047aba Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Fri, 8 Aug 2025 14:33:39 -0700 Subject: [PATCH 09/19] Rename sequence_level_importance_sampling => sequence_level_importance_ratios. Signed-off-by: Peter Jin --- .../configs/gspo-deepscaler-1.5b-8K-tp2.yaml | 2 +- nemo_rl/algorithms/loss_functions.py | 23 ++++++++++--------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/examples/configs/gspo-deepscaler-1.5b-8K-tp2.yaml b/examples/configs/gspo-deepscaler-1.5b-8K-tp2.yaml index 8ce2ceee78..a73cc3a5b1 100644 --- a/examples/configs/gspo-deepscaler-1.5b-8K-tp2.yaml +++ b/examples/configs/gspo-deepscaler-1.5b-8K-tp2.yaml @@ -3,7 +3,7 @@ defaults: "grpo-deepscaler-1.5b-8K.yaml" loss_fn: use_importance_sampling_correction: True - sequence_level_importance_sampling: True + sequence_level_importance_ratios: True token_level_loss: False policy: diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 4770aa1085..6560e1b274 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, TypedDict, TypeVar +from typing import Any, NotRequired, Optional, TypedDict, TypeVar import torch @@ -38,9 +38,10 @@ class ClippedPGLossConfig(TypedDict): use_importance_sampling_correction: bool token_level_loss: bool # If True, apply the off-policy importance-sampling correction at the - # sequence level (one weight per generated sample). If False (default) - # correction is applied at the token level as in the original GRPO paper. - sequence_level_importance_sampling: bool + # sequence level (one weight per generated sample), as in GSPO. + # If False (default), correction is applied at the token level as in the + # original GRPO paper. + sequence_level_importance_ratios: NotRequired[bool] class ClippedPGLossDataDict(TypedDict): @@ -64,7 +65,7 @@ class ClippedPGLossFn(LossFunction): - PPO (Clipped) - https://arxiv.org/abs/1707.06347 - GRPO - https://arxiv.org/abs/2402.03300 - REINFORCE/RLOO (set disable_ppo_ratio = True and ignores ratio_clip_min/ratio_clip_max) - https://arxiv.org/abs/2402.14740 - - GSPO (set sequence_level_importance_sampling = True and token_level_loss = False) - https://arxiv.org/abs/2507.18071 + - GSPO (set sequence_level_importance_ratios = True and token_level_loss = False) - https://arxiv.org/abs/2507.18071 Formula: L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref) @@ -107,14 +108,14 @@ def __init__(self, cfg: ClippedPGLossConfig): "use_importance_sampling_correction" ] # Whether to compute importance weights per-sequence instead of per-token. - self.sequence_level_importance_sampling = cfg.get( - "sequence_level_importance_sampling", + self.sequence_level_importance_ratios = cfg.get( + "sequence_level_importance_ratios", False, ) self.loss_type = ( LossType.TOKEN_LEVEL if cfg["token_level_loss"] else LossType.SEQUENCE_LEVEL ) - if self.sequence_level_importance_sampling: + if self.sequence_level_importance_ratios: assert self.loss_type == LossType.SEQUENCE_LEVEL, ( "sequence-level importance sampling (e.g. GSPO) is mutually exclusive with token-level loss" ) @@ -219,7 +220,7 @@ def __call__( # Calculate clipped loss function if ppo ratio is enabled. if not self.disable_ppo_ratio: log_ratios = curr_logprobs - prev_logprobs - if self.sequence_level_importance_sampling: + if self.sequence_level_importance_ratios: seq_log_ratio_mean = masked_mean( log_ratios, token_mask, @@ -255,7 +256,7 @@ def __call__( # ------------------------------------------------------------- # Off-policy (actor) importance-sampling correction # ------------------------------------------------------------- - if self.sequence_level_importance_sampling: + if self.sequence_level_importance_ratios: # importance weight w_i = exp(Σ_t (log π_actor − log π_behaviour)) seq_lp_diff = ((prev_logprobs - generation_logprobs) * mask).sum(dim=-1) actor_importance_weights = torch.exp(seq_lp_diff).detach() @@ -296,7 +297,7 @@ def __call__( ) # Metric: sampling importance ratio (mean over samples) - if self.sequence_level_importance_sampling: + if self.sequence_level_importance_ratios: sample_importance_ratio = masked_mean( actor_importance_weights, sample_mask, From 5e384dceefe3892f295b35375a888269ad8eeee3 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Sat, 9 Aug 2025 15:17:40 -0700 Subject: [PATCH 10/19] Using the new deepscaler config. Signed-off-by: Peter Jin --- .../configs/gspo-deepscaler-1.5b-8K-tp2.yaml | 18 ------------------ examples/configs/gspo-deepscaler-1.5b-8K.yaml | 7 +++++++ 2 files changed, 7 insertions(+), 18 deletions(-) delete mode 100644 examples/configs/gspo-deepscaler-1.5b-8K-tp2.yaml create mode 100644 examples/configs/gspo-deepscaler-1.5b-8K.yaml diff --git a/examples/configs/gspo-deepscaler-1.5b-8K-tp2.yaml b/examples/configs/gspo-deepscaler-1.5b-8K-tp2.yaml deleted file mode 100644 index a73cc3a5b1..0000000000 --- a/examples/configs/gspo-deepscaler-1.5b-8K-tp2.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# GRPO Algorithm Configuration -defaults: "grpo-deepscaler-1.5b-8K.yaml" - -loss_fn: - use_importance_sampling_correction: True - sequence_level_importance_ratios: True - token_level_loss: False - -policy: - activation_checkpointing_enabled: True - - dtensor_cfg: - tensor_parallel_size: 2 - activation_checkpointing: True - - generation: - vllm_cfg: - tensor_parallel_size: 2 diff --git a/examples/configs/gspo-deepscaler-1.5b-8K.yaml b/examples/configs/gspo-deepscaler-1.5b-8K.yaml new file mode 100644 index 0000000000..dbe3d64a20 --- /dev/null +++ b/examples/configs/gspo-deepscaler-1.5b-8K.yaml @@ -0,0 +1,7 @@ +# GRPO Algorithm Configuration +defaults: "recipes/llm/grpo-deepscaler-1.5b-8K.yaml" + +loss_fn: + use_importance_sampling_correction: false + sequence_level_importance_ratios: true + token_level_loss: false From 8882bfe692416fc43e7edfaa6e2da382ec5821f1 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Sat, 9 Aug 2025 15:21:51 -0700 Subject: [PATCH 11/19] Restore comment. Signed-off-by: Peter Jin --- nemo_rl/algorithms/loss_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 6560e1b274..c6e3afd28b 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -256,6 +256,7 @@ def __call__( # ------------------------------------------------------------- # Off-policy (actor) importance-sampling correction # ------------------------------------------------------------- + # See: docs/guides/grpo.md#importance-sampling-correction if self.sequence_level_importance_ratios: # importance weight w_i = exp(Σ_t (log π_actor − log π_behaviour)) seq_lp_diff = ((prev_logprobs - generation_logprobs) * mask).sum(dim=-1) From b795e65c1c563b70b8216aa7befa1fc3d43a2693 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Sat, 9 Aug 2025 15:23:02 -0700 Subject: [PATCH 12/19] Restore comment. Signed-off-by: Peter Jin --- nemo_rl/algorithms/loss_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index c6e3afd28b..a4c80769bd 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -298,6 +298,7 @@ def __call__( ) # Metric: sampling importance ratio (mean over samples) + # See: docs/guides/grpo.md#sampling-importance-ratio if self.sequence_level_importance_ratios: sample_importance_ratio = masked_mean( actor_importance_weights, From 931b744c92be164fe62215ea6ab8a04ccb4aac81 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Tue, 12 Aug 2025 14:19:59 -0700 Subject: [PATCH 13/19] Minor fix. Signed-off-by: Peter Jin --- nemo_rl/algorithms/loss_functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index a4c80769bd..b11845b786 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -275,6 +275,7 @@ def __call__( actor_importance_weights_expanded, nan=0.0, posinf=0.0, neginf=0.0 ) actor_importance_weights = actor_importance_weights_expanded + del actor_importance_weights_expanded if self.use_importance_sampling_correction: importance_weights_to_use = actor_importance_weights else: @@ -307,7 +308,7 @@ def __call__( ) else: sample_importance_ratio = masked_mean( - actor_importance_weights_expanded, + actor_importance_weights, mask, global_normalization_factor=global_valid_toks, ) From b90ea627b1e9991a9b33b97eeec0f4d0709a0107 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 13 Aug 2025 16:59:37 -0700 Subject: [PATCH 14/19] Remove inaccurate comment. Signed-off-by: Peter Jin --- nemo_rl/algorithms/loss_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index b11845b786..a04e92dfe7 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -267,7 +267,7 @@ def __call__( # Broadcast to token dimension so we can reuse existing reduction actor_importance_weights_expanded = actor_importance_weights.unsqueeze(-1) else: - # Token-level correction (original GRPO implementation) + # Token-level correction actor_importance_weights_expanded = torch.exp( prev_logprobs - generation_logprobs ) From dfbd89da6e238b33fda465297f03e20c14f4f696 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Thu, 14 Aug 2025 14:19:38 -0700 Subject: [PATCH 15/19] GSPO config test. Signed-off-by: Peter Jin --- examples/configs/gspo-deepscaler-1.5b-8K.yaml | 7 - .../recipes/llm/gspo-deepscaler-1.5b-8K.yaml | 146 ++++++++++++++++++ .../llm/gspo-deepscaler-1.5b-8K.sh | 67 ++++++++ tests/test_suites/nightly.txt | 3 + 4 files changed, 216 insertions(+), 7 deletions(-) delete mode 100644 examples/configs/gspo-deepscaler-1.5b-8K.yaml create mode 100644 examples/configs/recipes/llm/gspo-deepscaler-1.5b-8K.yaml create mode 100755 tests/test_suites/llm/gspo-deepscaler-1.5b-8K.sh diff --git a/examples/configs/gspo-deepscaler-1.5b-8K.yaml b/examples/configs/gspo-deepscaler-1.5b-8K.yaml deleted file mode 100644 index dbe3d64a20..0000000000 --- a/examples/configs/gspo-deepscaler-1.5b-8K.yaml +++ /dev/null @@ -1,7 +0,0 @@ -# GRPO Algorithm Configuration -defaults: "recipes/llm/grpo-deepscaler-1.5b-8K.yaml" - -loss_fn: - use_importance_sampling_correction: false - sequence_level_importance_ratios: true - token_level_loss: false diff --git a/examples/configs/recipes/llm/gspo-deepscaler-1.5b-8K.yaml b/examples/configs/recipes/llm/gspo-deepscaler-1.5b-8K.yaml new file mode 100644 index 0000000000..35bd140c1a --- /dev/null +++ b/examples/configs/recipes/llm/gspo-deepscaler-1.5b-8K.yaml @@ -0,0 +1,146 @@ +# GRPO Algorithm Configuration +grpo: + num_prompts_per_step: 128 + num_generations_per_prompt: 8 + max_rollout_turns: 1 # for multi-turn rollouts. Math Environments just have 1 turn (answering the question) + max_num_steps: 1000000 + normalize_rewards: true + use_leave_one_out_baseline: true + val_period: 10 + val_at_start: false + max_val_samples: 480 + val_batch_size: 32 + seed: 42 + +loss_fn: + reference_policy_kl_penalty: 0.0 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 + ratio_clip_c: null + # (default off) loss formulation improvements (docs/guides/grpo.md#loss) + use_on_policy_kl_approximation: false + use_importance_sampling_correction: false + sequence_level_importance_ratios: true + token_level_loss: false + +checkpointing: + enabled: true + checkpoint_dir: "results/grpo" + metric_name: "val_reward" + higher_is_better: true + keep_top_k: 10 + save_period: 10 + checkpoint_must_save_by: null + +policy: + model_name: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + tokenizer: + name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default + train_global_batch_size: 64 + train_micro_batch_size: 1 + generation_batch_size: 32 # Only used when generating using HF backend + logprob_batch_size: 4 + max_total_sequence_length: 8192 + precision: "bfloat16" + + dtensor_cfg: + enabled: true + cpu_offload: true + sequence_parallel: true + activation_checkpointing: true + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + + dynamic_batching: + enabled: False + + sequence_packing: + enabled: False + + # makes the training sequence length divisible by the tensor parallel size + # this is useful for sequence parallel training + make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} + max_grad_norm: 1.0 + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 2.0e-6 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + # when using Dtensor, we need to set foreach + # and fused to False + foreach: False + fused: False + + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.1 + end_factor: 1.0 + total_iters: 50 + - name: "torch.optim.lr_scheduler.ConstantLR" + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: [50] + + generation: + backend: "vllm" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: false + precision: ${policy.precision} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: ${policy.max_total_sequence_length} + enforce_eager: True + colocated: + # true: generation shares training GPUs + # false: uses dedicated generation resources + enabled: true + # only relevant when enabled is false + resources: + gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1 + num_nodes: null # Decides number of nodes to be dedicated to generation + +data: + max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len + prompt_file: "examples/prompts/cot.txt" + system_prompt_file: null + dataset_name: "DeepScaler" + shuffle: true + +env: + math: + num_workers: 16 + +logger: + log_dir: "logs" # Base directory for all logs + num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal + wandb_enabled: false + tensorboard_enabled: false + mlflow_enabled: false + monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard + wandb: + project: "grpo-dev" + name: "grpo-dev-logger" + tensorboard: {} + mlflow: + experiment_name: "grpo-dev" + run_name: "grpo-dev-logger" + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + +cluster: + gpus_per_node: 8 + num_nodes: 1 diff --git a/tests/test_suites/llm/gspo-deepscaler-1.5b-8K.sh b/tests/test_suites/llm/gspo-deepscaler-1.5b-8K.sh new file mode 100755 index 0000000000..ce2adb1c51 --- /dev/null +++ b/tests/test_suites/llm/gspo-deepscaler-1.5b-8K.sh @@ -0,0 +1,67 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=40 +MAX_STEPS=40 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=240 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo_math.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/token_mult_prob_error"]) < 1.1' \ + "data['train/token_mult_prob_error']['$MAX_STEPS'] < 1.1" +fi + +# TODO: enable in subsequent PR to do a quick accuracy check +## Convert 8k checkpoint +#uv run examples/converters/convert_dcp_to_hf.py \ +# --config=$CKPT_DIR/step_${MAX_STEPS}/config.yaml \ +# --dcp-ckpt-path=$CKPT_DIR/step_${MAX_STEPS}/policy/weights \ +# --hf-ckpt-path=$CKPT_DIR/gspo-deepscaler-8k-${MAX_STEPS}-hf +# +## Run eval +#uv run examples/run_eval.py \ +# generation.model_name=$CKPT_DIR/gspo-deepscaler-8k-${MAX_STEPS}-hf \ +# data.prompt_file=examples/prompts/cot.txt \ +# generation.vllm_cfg.max_model_len=32768 2>&1 | tee ${RUN_LOG}.aime-8k +# +#cat ${RUN_LOG}.aime-8k | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"score": \1}/' > ${RUN_LOG}-8k-metric.json +# +#uv run tests/check_metrics.py ${RUN_LOG}-8k-metric.json \ +# 'data["score"] >= 0.25' \ +# +##uv run examples/run_eval.py \ +## generation.model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ +## data.prompt_file=examples/prompts/cot.txt \ +## generation.vllm_cfg.max_model_len=32768 2>&1 | tee ${RUN_LOG}.aime-baseline +# +##cat ${RUN_LOG}.aime-baseline | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"score": \1}/' > ${RUN_LOG}-baseline-metric.json +# +##uv run tests/check_metrics.py ${RUN_LOG}-baseline-metric.json \ +## 'data["score"] == 0.2' \ diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 07c3eb5b9c..9f2c9f49d0 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -18,6 +18,9 @@ tests/test_suites/llm/grpo-deepscaler-1.5b-16K.sh tests/test_suites/llm/grpo-deepscaler-1.5b-24K.sh tests/test_suites/llm/grpo-deepscaler-1.5b-8K.sh +# Deepscaler (GSPO) +tests/test_suites/llm/gspo-deepscaler-1.5b-8K.sh + ####### # SFT # ####### From d95dea056fab7ea2590a5bc6a161414ae8a3f020 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Tue, 19 Aug 2025 16:22:06 -0700 Subject: [PATCH 16/19] Batch size = 1 GSPO unit tests (should be equivalent to GRPO). Signed-off-by: Peter Jin --- tests/unit/algorithms/test_loss_functions.py | 191 +++++++++++++++++++ 1 file changed, 191 insertions(+) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 65ed834625..262601fa40 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -1107,3 +1107,194 @@ def test_clipped_pg_loss_entropy(): rtol=1e-3, atol=1e-5, ) + + +def test_clipped_pg_loss_gspo(): + """Tests GSPO path in ClippedPGLossFn.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + ratio_clip = 0.2 + cfg = { + "ratio_clip_min": ratio_clip, + "ratio_clip_max": ratio_clip, + "ratio_clip_c": None, + "reference_policy_kl_penalty": 0.0, # Disable KL + "disable_ppo_ratio": False, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, + "use_sequence_level_importance_ratios": True, + "token_level_loss": False, + } + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) + # Use non-zero prev_lp to allow ratios > 1 with valid curr_lp <= 0 + prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + # Target Curr logprobs (masked pos 1, 2, 3) - design for clipping + # Target ratios: 0.5 (<0.8), 1.0 (in [0.8, 1.2]), 1.5 (>1.2) + # Curr = log(Ratio) + Prev + curr_lp_masked = torch.tensor( + [[-1.69315, -1.0, -0.59453]], device=device + ) # approx log(0.5)-1, log(1)-1, log(1.5)-1 + + # Fill full tensors (only need first dim for B=1) + data["advantages"][0, 1:] = adv_masked + data["prev_logprobs"][0, 1:] = prev_lp_masked + + # --- Hand Calculation --- + log_ratios = curr_lp_masked - prev_lp_masked + ratios = torch.exp(log_ratios) # approx [0.5, 1.0, 1.5] + assert torch.allclose( + ratios, torch.tensor([[0.5, 1.0, 1.5]], device=device), rtol=1e-3 + ) + + ratios_clamped = torch.clamp( + ratios, 1.0 - ratio_clip, 1.0 + ratio_clip + ) # [0.8, 1.0, 1.2] + assert torch.allclose( + ratios_clamped, torch.tensor([[0.8, 1.0, 1.2]], device=device), rtol=1e-3 + ) + + loss1 = -adv_masked * ratios # approx -[1*0.5, -1*1.0, 2*1.5] = [-0.5, 1.0, -3.0] + assert torch.allclose( + loss1, torch.tensor([[-0.5, 1.0, -3.0]], device=device), rtol=1e-3 + ) + + loss2 = -adv_masked * ratios_clamped # -[1*0.8, -1*1.0, 2*1.2] = [-0.8, 1.0, -2.4] + assert torch.allclose( + loss2, torch.tensor([[-0.8, 1.0, -2.4]], device=device), rtol=1e-3 + ) + + max_loss = torch.maximum(loss1, loss2) # approx [-0.5, 1.0, -2.4] + assert torch.allclose( + max_loss, torch.tensor([[-0.5, 1.0, -2.4]], device=device), rtol=1e-3 + ) + + expected_loss = torch.mean( + max_loss + ) # approx (-0.5 + 1.0 - 2.4) / 3 = -1.9 / 3 = -0.6333 + assert torch.allclose( + expected_loss, torch.tensor(-0.6333, device=device), rtol=1e-3 + ) + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, seq_len, vocab_size, device + ) + + actual_loss, _ = loss_fn( + dummy_logits, + data, + global_valid_seqs=torch.sum(data["sample_mask"]), + global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + ) + torch.testing.assert_close(actual_loss, expected_loss) + + +def test_clipped_pg_loss_gspo_importance_sampling_correction(): + """Tests GSPO w/ importance sampling correction in ClippedPGLossFn.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + ratio_clip = 0.2 + cfg = { + "ratio_clip_min": ratio_clip, + "ratio_clip_max": ratio_clip, + "ratio_clip_c": None, + "reference_policy_kl_penalty": 0.0, # Disable KL + "disable_ppo_ratio": False, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": True, + "use_sequence_level_importance_ratios": True, + "token_level_loss": False, + } + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) + prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + curr_lp_masked = torch.tensor( + [[-1.69315, -1.0, -0.59453]], device=device + ) # approx log(0.5)-1, log(1)-1, log(1.5)-1 + + ref_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + + # For Importance Sampling + gen_lp_masked = torch.tensor([[-0.5, -1.5, -0.8]], device=device) + + # Fill full tensors + data["advantages"][0, 1:] = adv_masked + data["prev_logprobs"][0, 1:] = prev_lp_masked + data["generation_logprobs"][0, 1:] = gen_lp_masked + data["reference_policy_logprobs"][0, 1:] = ref_lp_masked + + # --- Hand Calculation --- + # Actor Loss Calculation + actor_importance_weights = torch.exp( + prev_lp_masked - gen_lp_masked + ) # exp([-1 - (-0.5), -1 - (-1.5), -1 - (-0.8)]) = [0.6065, 1.6487, 0.8187] + assert torch.allclose( + actor_importance_weights, + torch.tensor([[0.6065, 1.6487, 0.8187]], device=device), + rtol=1e-3, + ) + + ratios = torch.exp(curr_lp_masked - prev_lp_masked) # [0.5, 1.0, 1.5] + assert torch.allclose( + ratios, torch.tensor([[0.5, 1.0, 1.5]], device=device), rtol=1e-3 + ) + + ratios_clamped = torch.clamp( + ratios, 1.0 - ratio_clip, 1.0 + ratio_clip + ) # [0.8, 1.0, 1.2] + assert torch.allclose( + ratios_clamped, torch.tensor([[0.8, 1.0, 1.2]], device=device), rtol=1e-3 + ) + + loss1 = -adv_masked * ratios # [-0.5, 1.0, -3.0] + assert torch.allclose( + loss1, torch.tensor([[-0.5, 1.0, -3.0]], device=device), rtol=1e-3 + ) + + loss2 = -adv_masked * ratios_clamped # [-0.8, 1.0, -2.4] + assert torch.allclose( + loss2, torch.tensor([[-0.8, 1.0, -2.4]], device=device), rtol=1e-3 + ) + + max_loss = torch.maximum(loss1, loss2) # [-0.5, 1.0, -2.4] + assert torch.allclose( + max_loss, torch.tensor([[-0.5, 1.0, -2.4]], device=device), rtol=1e-3 + ) + + importance_weighted_max_loss = ( + actor_importance_weights * max_loss + ) # [0.6065*(-0.5), 1.6487*1.0, 0.8187*(-2.4)] = [-0.30325, 1.6487, -1.96488] + assert torch.allclose( + importance_weighted_max_loss, + torch.tensor([[-0.30325, 1.6487, -1.96488]], device=device), + rtol=1e-3, + ) + + expected_actor_loss = torch.mean(importance_weighted_max_loss) # -0.2065 + assert torch.allclose( + expected_actor_loss, torch.tensor(-0.2065, device=device), rtol=1e-3 + ) + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, seq_len, vocab_size, device + ) + + actual_loss, _ = loss_fn( + dummy_logits, + data, + global_valid_seqs=torch.sum(data["sample_mask"]), + global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + ) + torch.testing.assert_close(actual_loss, expected_actor_loss, atol=1e-4, rtol=1e-3) From c957966911ce23fab4ce9ce02b0978dd3e924721 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 20 Aug 2025 15:44:47 -0700 Subject: [PATCH 17/19] Clipped PG loss unit tests support for batch size > 1. Signed-off-by: Peter Jin --- tests/unit/algorithms/test_loss_functions.py | 142 +++++++++++++++---- 1 file changed, 115 insertions(+), 27 deletions(-) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 262601fa40..e1ff1efe51 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools + import pytest import torch @@ -371,41 +373,41 @@ def _setup_clipped_pg_test_data(batch_size=1, seq_len=4, vocab_size=8, device="c } ) # Return seq_len and vocab_size needed by tests - return data, seq_len, vocab_size + return data, batch_size, seq_len, vocab_size # Helper to create logits that yield specific target log probs after log_softmax -def _create_exact_logits(target_curr_lp_masked, input_ids, seq_len, vocab_size, device): +def _create_exact_logits(target_curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device): """Constructs logits such that log_softmax results in target_curr_lp_masked.""" dummy_logits = torch.full( - (1, seq_len, vocab_size), -100.0, device=device + (batch_size, seq_len, vocab_size), -100.0, device=device ) # Start very low # Loss fn uses logits[:, :-1] and gathers based on next_tokens = input_ids[:, 1:] # We need to set logits for indices i=0..S-2 of the sliced logits tensor. # These correspond to target logprobs at indices 0..S-2 of target_curr_lp_masked. num_effective_pos = target_curr_lp_masked.shape[1] - for i in range(num_effective_pos): + for k, i in itertools.product(range(batch_size), range(num_effective_pos)): logit_idx = i # Index in the sliced logits tensor (dummy_logits[:, 0:S-1, :]) data_idx = i + 1 # Index in the original input_ids to find the target token - target_token_id = input_ids[0, data_idx].item() + target_token_id = input_ids[k, data_idx].item() # Keep target_lp as a 0-dim tensor for torch ops - target_lp = target_curr_lp_masked[0, i] + target_lp = target_curr_lp_masked[k, i] # Handle target_lp = 0 case separately if torch.isclose(target_lp, torch.tensor(0.0, device=device)): - dummy_logits[0, logit_idx, target_token_id] = 100.0 # Large positive logit + dummy_logits[k, logit_idx, target_token_id] = 100.0 # Large positive logit elif target_lp < 0: # Set target token logit to 0 - dummy_logits[0, logit_idx, target_token_id] = 0.0 + dummy_logits[k, logit_idx, target_token_id] = 0.0 # Set one distractor token logit using the formula distractor_token_id = (target_token_id + 1) % vocab_size # Ensure distractor isn't same as target if vocab_size=1 (edge case) if distractor_token_id == target_token_id: distractor_token_id = (target_token_id + 2) % vocab_size distractor_logit = torch.log(torch.exp(-target_lp) - 1.0) - dummy_logits[0, logit_idx, distractor_token_id] = distractor_logit + dummy_logits[k, logit_idx, distractor_token_id] = distractor_logit else: # target_lp > 0 is not supported by this method raise ValueError( "Target log probability must be negative or zero for this construction" @@ -420,7 +422,7 @@ def test_clipped_pg_loss_ppo_clipping(): pytest.skip("No GPU available") device = "cuda" - data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) ratio_clip = 0.2 cfg = { @@ -486,7 +488,7 @@ def test_clipped_pg_loss_ppo_clipping(): input_ids = data["input_ids"] dummy_logits = _create_exact_logits( - curr_lp_masked, input_ids, seq_len, vocab_size, device + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) actual_loss, _ = loss_fn( @@ -505,7 +507,7 @@ def test_clipped_pg_loss_reinforce_mode(): pytest.skip("No GPU available") device = "cuda" - data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) cfg = { "disable_ppo_ratio": True, @@ -539,7 +541,7 @@ def test_clipped_pg_loss_reinforce_mode(): input_ids = data["input_ids"] dummy_logits = _create_exact_logits( - curr_lp_masked, input_ids, seq_len, vocab_size, device + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) actual_loss, _ = loss_fn( @@ -560,7 +562,7 @@ def test_clipped_pg_loss_kl_penalty(): pytest.skip("No GPU available") device = "cuda" - data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) # --- Test Setup --- kl_beta = 0.1 @@ -607,7 +609,7 @@ def test_clipped_pg_loss_kl_penalty(): input_ids = data["input_ids"] dummy_logits = _create_exact_logits( - curr_lp_masked, input_ids, seq_len, vocab_size, device + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) actual_loss, _ = loss_fn( @@ -632,7 +634,7 @@ def test_clipped_pg_loss_masking(): device = "cuda" # Use original loss function for masking tests, as it involves interactions # that the Testable class might obscure slightly. - data, seq_len, vocab_size = _setup_clipped_pg_test_data( + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data( batch_size=batch_size, seq_len=seq_len, device=device ) # Need some realistic-ish logits and logprobs for masking test @@ -734,7 +736,7 @@ def test_clipped_pg_loss_zero_mask(): pytest.skip("No GPU available") device = "cuda" - data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) # Need dummy logits dummy_logits = torch.randn(1, seq_len, vocab_size, device=device) @@ -772,7 +774,7 @@ def test_clipped_pg_loss_on_policy_kl_importance_sampling(): pytest.skip("No GPU available") device = "cuda" - data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) ratio_clip = 0.2 kl_beta = 0.1 @@ -905,7 +907,7 @@ def test_clipped_pg_loss_on_policy_kl_importance_sampling(): input_ids = data["input_ids"] dummy_logits = _create_exact_logits( - curr_lp_masked, input_ids, seq_len, vocab_size, device + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) actual_loss, _ = loss_fn( @@ -954,7 +956,7 @@ def test_clipped_pg_loss_dual_clip(): pytest.skip("No GPU available") device = "cuda" - data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) ratio_clip = 0.2 ratio_clip_c = 3.0 @@ -1034,7 +1036,7 @@ def test_clipped_pg_loss_dual_clip(): input_ids = data["input_ids"] dummy_logits = _create_exact_logits( - curr_lp_masked, input_ids, seq_len, vocab_size, device + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) actual_loss, _ = loss_fn( @@ -1054,7 +1056,7 @@ def test_clipped_pg_loss_entropy(): pytest.skip("No GPU available") device = "cuda" - data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) cfg = { "ratio_clip_min": 0.2, @@ -1092,7 +1094,7 @@ def test_clipped_pg_loss_entropy(): ) # torch.mean because default mask applies to these 3 terms dummy_logits = _create_exact_logits( - curr_lp_masked, data["input_ids"], seq_len, vocab_size, device + curr_lp_masked, data["input_ids"], batch_size, seq_len, vocab_size, device ) _, metrics = loss_fn( dummy_logits, @@ -1115,7 +1117,7 @@ def test_clipped_pg_loss_gspo(): pytest.skip("No GPU available") device = "cuda" - data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) ratio_clip = 0.2 cfg = { @@ -1183,7 +1185,7 @@ def test_clipped_pg_loss_gspo(): input_ids = data["input_ids"] dummy_logits = _create_exact_logits( - curr_lp_masked, input_ids, seq_len, vocab_size, device + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) actual_loss, _ = loss_fn( @@ -1195,13 +1197,99 @@ def test_clipped_pg_loss_gspo(): torch.testing.assert_close(actual_loss, expected_loss) +def test_clipped_pg_loss_gspo_bsz_2(): + """Tests GSPO path in ClippedPGLossFn.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(batch_size=2, device=device) + + ratio_clip = 0.2 + cfg = { + "ratio_clip_min": ratio_clip, + "ratio_clip_max": ratio_clip, + "ratio_clip_c": None, + "reference_policy_kl_penalty": 0.0, # Disable KL + "disable_ppo_ratio": False, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, + "use_sequence_level_importance_ratios": True, + "token_level_loss": False, + } + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[1.0, -1.0, 2.0], [1.0, -1.0, 2.0]], device=device) + # Use non-zero prev_lp to allow ratios > 1 with valid curr_lp <= 0 + prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]], device=device) + # Target Curr logprobs (masked pos 1, 2, 3) - design for clipping + # Target ratios: 0.5 (<0.8), 1.0 (in [0.8, 1.2]), 1.5 (>1.2) + # Curr = log(Ratio) + Prev + curr_lp_masked = torch.tensor( + [[-1.69315, -1.0, -0.59453], [-1.69315, -1.0, -0.59453]], device=device + ) # approx log(0.5)-1, log(1)-1, log(1.5)-1 + + # Fill full tensors (only need first dim for B=1) + data["advantages"][:, 1:] = adv_masked + data["prev_logprobs"][:, 1:] = prev_lp_masked + + # --- Hand Calculation --- + log_ratios = curr_lp_masked - prev_lp_masked + ratios = torch.exp(log_ratios) # approx [0.5, 1.0, 1.5] + assert torch.allclose( + ratios, torch.tensor([[0.5, 1.0, 1.5], [0.5, 1.0, 1.5]], device=device), rtol=1e-3 + ) + + ratios_clamped = torch.clamp( + ratios, 1.0 - ratio_clip, 1.0 + ratio_clip + ) # [0.8, 1.0, 1.2] + assert torch.allclose( + ratios_clamped, torch.tensor([[0.8, 1.0, 1.2], [0.8, 1.0, 1.2]], device=device), rtol=1e-3 + ) + + loss1 = -adv_masked * ratios # approx -[1*0.5, -1*1.0, 2*1.5] = [-0.5, 1.0, -3.0] + assert torch.allclose( + loss1, torch.tensor([[-0.5, 1.0, -3.0], [-0.5, 1.0, -3.0]], device=device), rtol=1e-3 + ) + + loss2 = -adv_masked * ratios_clamped # -[1*0.8, -1*1.0, 2*1.2] = [-0.8, 1.0, -2.4] + assert torch.allclose( + loss2, torch.tensor([[-0.8, 1.0, -2.4], [-0.8, 1.0, -2.4]], device=device), rtol=1e-3 + ) + + max_loss = torch.maximum(loss1, loss2) # approx [-0.5, 1.0, -2.4] + assert torch.allclose( + max_loss, torch.tensor([[-0.5, 1.0, -2.4], [-0.5, 1.0, -2.4]], device=device), rtol=1e-3 + ) + + expected_loss = torch.mean( + max_loss + ) # approx (-0.5 + 1.0 - 2.4) / 3 = -1.9 / 3 = -0.6333 + assert torch.allclose( + expected_loss, torch.tensor(-0.6333, device=device), rtol=1e-3 + ) + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device + ) + + actual_loss, _ = loss_fn( + dummy_logits, + data, + global_valid_seqs=torch.sum(data["sample_mask"]), + global_valid_toks=torch.sum(data["sample_mask"].unsqueeze(1) * data["token_mask"]), + ) + torch.testing.assert_close(actual_loss, expected_loss) + + def test_clipped_pg_loss_gspo_importance_sampling_correction(): """Tests GSPO w/ importance sampling correction in ClippedPGLossFn.""" if not torch.cuda.is_available(): pytest.skip("No GPU available") device = "cuda" - data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) ratio_clip = 0.2 cfg = { @@ -1288,7 +1376,7 @@ def test_clipped_pg_loss_gspo_importance_sampling_correction(): input_ids = data["input_ids"] dummy_logits = _create_exact_logits( - curr_lp_masked, input_ids, seq_len, vocab_size, device + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) actual_loss, _ = loss_fn( From a28c97416d850681b4df1adfe722c66ac6ec850d Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 20 Aug 2025 17:22:59 -0700 Subject: [PATCH 18/19] GSPO unit tests. Signed-off-by: Peter Jin --- nemo_rl/algorithms/loss_functions.py | 2 +- tests/unit/algorithms/test_loss_functions.py | 144 +++++++++++-------- 2 files changed, 84 insertions(+), 62 deletions(-) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 56ab3502b3..c58e0d719c 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -224,7 +224,7 @@ def __call__( log_ratios, token_mask, dim=-1, - ) + ).unsqueeze(-1) seq_ratio = seq_log_ratio_mean.exp() ratios = seq_ratio.repeat(1, advantages.shape[1]) else: diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index e1ff1efe51..d36324616e 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -377,7 +377,9 @@ def _setup_clipped_pg_test_data(batch_size=1, seq_len=4, vocab_size=8, device="c # Helper to create logits that yield specific target log probs after log_softmax -def _create_exact_logits(target_curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device): +def _create_exact_logits( + target_curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device +): """Constructs logits such that log_softmax results in target_curr_lp_masked.""" dummy_logits = torch.full( (batch_size, seq_len, vocab_size), -100.0, device=device @@ -1128,7 +1130,7 @@ def test_clipped_pg_loss_gspo(): "disable_ppo_ratio": False, "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False, - "use_sequence_level_importance_ratios": True, + "sequence_level_importance_ratios": True, "token_level_loss": False, } loss_fn = ClippedPGLossFn(cfg) @@ -1149,38 +1151,37 @@ def test_clipped_pg_loss_gspo(): # --- Hand Calculation --- log_ratios = curr_lp_masked - prev_lp_masked - ratios = torch.exp(log_ratios) # approx [0.5, 1.0, 1.5] + seq_log_ratios_mean = torch.mean(log_ratios, dim=-1).unsqueeze(-1) + ratios = seq_log_ratios_mean.exp().repeat(1, 3) assert torch.allclose( - ratios, torch.tensor([[0.5, 1.0, 1.5]], device=device), rtol=1e-3 + ratios, torch.tensor([[0.9086, 0.9086, 0.9086]], device=device), rtol=1e-3 ) - ratios_clamped = torch.clamp( - ratios, 1.0 - ratio_clip, 1.0 + ratio_clip - ) # [0.8, 1.0, 1.2] + ratios_clamped = torch.clamp(ratios, 1.0 - ratio_clip, 1.0 + ratio_clip) assert torch.allclose( - ratios_clamped, torch.tensor([[0.8, 1.0, 1.2]], device=device), rtol=1e-3 + ratios_clamped, + torch.tensor([[0.9086, 0.9086, 0.9086]], device=device), + rtol=1e-3, ) - loss1 = -adv_masked * ratios # approx -[1*0.5, -1*1.0, 2*1.5] = [-0.5, 1.0, -3.0] + loss1 = -adv_masked * ratios assert torch.allclose( - loss1, torch.tensor([[-0.5, 1.0, -3.0]], device=device), rtol=1e-3 + loss1, torch.tensor([[-0.9086, 0.9086, -1.8171]], device=device), rtol=1e-3 ) - loss2 = -adv_masked * ratios_clamped # -[1*0.8, -1*1.0, 2*1.2] = [-0.8, 1.0, -2.4] + loss2 = -adv_masked * ratios_clamped assert torch.allclose( - loss2, torch.tensor([[-0.8, 1.0, -2.4]], device=device), rtol=1e-3 + loss2, torch.tensor([[-0.9086, 0.9086, -1.8171]], device=device), rtol=1e-3 ) - max_loss = torch.maximum(loss1, loss2) # approx [-0.5, 1.0, -2.4] + max_loss = torch.maximum(loss1, loss2) assert torch.allclose( - max_loss, torch.tensor([[-0.5, 1.0, -2.4]], device=device), rtol=1e-3 + max_loss, torch.tensor([[-0.9086, 0.9086, -1.8171]], device=device), rtol=1e-3 ) - expected_loss = torch.mean( - max_loss - ) # approx (-0.5 + 1.0 - 2.4) / 3 = -1.9 / 3 = -0.6333 + expected_loss = torch.mean(max_loss) assert torch.allclose( - expected_loss, torch.tensor(-0.6333, device=device), rtol=1e-3 + expected_loss, torch.tensor(-0.6057, device=device), rtol=1e-3 ) input_ids = data["input_ids"] @@ -1197,13 +1198,15 @@ def test_clipped_pg_loss_gspo(): torch.testing.assert_close(actual_loss, expected_loss) -def test_clipped_pg_loss_gspo_bsz_2(): - """Tests GSPO path in ClippedPGLossFn.""" +def test_clipped_pg_loss_gspo_batch_size_2(): + """Tests non-unit batch size GSPO path in ClippedPGLossFn.""" if not torch.cuda.is_available(): pytest.skip("No GPU available") device = "cuda" - data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(batch_size=2, device=device) + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data( + batch_size=2, device=device + ) ratio_clip = 0.2 cfg = { @@ -1214,14 +1217,16 @@ def test_clipped_pg_loss_gspo_bsz_2(): "disable_ppo_ratio": False, "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False, - "use_sequence_level_importance_ratios": True, + "sequence_level_importance_ratios": True, "token_level_loss": False, } loss_fn = ClippedPGLossFn(cfg) adv_masked = torch.tensor([[1.0, -1.0, 2.0], [1.0, -1.0, 2.0]], device=device) # Use non-zero prev_lp to allow ratios > 1 with valid curr_lp <= 0 - prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]], device=device) + prev_lp_masked = torch.tensor( + [[-1.0, -1.0, -1.0], [-2.0, -2.0, -2.0]], device=device + ) # Target Curr logprobs (masked pos 1, 2, 3) - design for clipping # Target ratios: 0.5 (<0.8), 1.0 (in [0.8, 1.2]), 1.5 (>1.2) # Curr = log(Ratio) + Prev @@ -1235,38 +1240,53 @@ def test_clipped_pg_loss_gspo_bsz_2(): # --- Hand Calculation --- log_ratios = curr_lp_masked - prev_lp_masked - ratios = torch.exp(log_ratios) # approx [0.5, 1.0, 1.5] + seq_log_ratios_mean = torch.mean(log_ratios, dim=-1).unsqueeze(-1) + ratios = seq_log_ratios_mean.exp().repeat(1, 3) assert torch.allclose( - ratios, torch.tensor([[0.5, 1.0, 1.5], [0.5, 1.0, 1.5]], device=device), rtol=1e-3 + ratios, + torch.tensor( + [[0.9086, 0.9086, 0.9086], [2.4697, 2.4697, 2.4697]], device=device + ), + rtol=1e-3, ) - ratios_clamped = torch.clamp( - ratios, 1.0 - ratio_clip, 1.0 + ratio_clip - ) # [0.8, 1.0, 1.2] + ratios_clamped = torch.clamp(ratios, 1.0 - ratio_clip, 1.0 + ratio_clip) assert torch.allclose( - ratios_clamped, torch.tensor([[0.8, 1.0, 1.2], [0.8, 1.0, 1.2]], device=device), rtol=1e-3 + ratios_clamped, + torch.tensor([[0.9086, 0.9086, 0.9086], [1.2, 1.2, 1.2]], device=device), + rtol=1e-3, ) - loss1 = -adv_masked * ratios # approx -[1*0.5, -1*1.0, 2*1.5] = [-0.5, 1.0, -3.0] + loss1 = -adv_masked * ratios assert torch.allclose( - loss1, torch.tensor([[-0.5, 1.0, -3.0], [-0.5, 1.0, -3.0]], device=device), rtol=1e-3 + loss1, + torch.tensor( + [[-0.9086, 0.9086, -1.8171], [-2.4697, 2.4697, -4.9394]], device=device + ), + rtol=1e-3, ) - loss2 = -adv_masked * ratios_clamped # -[1*0.8, -1*1.0, 2*1.2] = [-0.8, 1.0, -2.4] + loss2 = -adv_masked * ratios_clamped assert torch.allclose( - loss2, torch.tensor([[-0.8, 1.0, -2.4], [-0.8, 1.0, -2.4]], device=device), rtol=1e-3 + loss2, + torch.tensor( + [[-0.9086, 0.9086, -1.8171], [-1.2000, 1.2000, -2.4000]], device=device + ), + rtol=1e-3, ) - max_loss = torch.maximum(loss1, loss2) # approx [-0.5, 1.0, -2.4] + max_loss = torch.maximum(loss1, loss2) assert torch.allclose( - max_loss, torch.tensor([[-0.5, 1.0, -2.4], [-0.5, 1.0, -2.4]], device=device), rtol=1e-3 + max_loss, + torch.tensor( + [[-0.9086, 0.9086, -1.8171], [-1.2000, 2.4697, -2.4000]], device=device + ), + rtol=1e-3, ) - expected_loss = torch.mean( - max_loss - ) # approx (-0.5 + 1.0 - 2.4) / 3 = -1.9 / 3 = -0.6333 + expected_loss = torch.mean(max_loss) assert torch.allclose( - expected_loss, torch.tensor(-0.6333, device=device), rtol=1e-3 + expected_loss, torch.tensor(-0.4912, device=device), rtol=1e-3 ) input_ids = data["input_ids"] @@ -1278,7 +1298,9 @@ def test_clipped_pg_loss_gspo_bsz_2(): dummy_logits, data, global_valid_seqs=torch.sum(data["sample_mask"]), - global_valid_toks=torch.sum(data["sample_mask"].unsqueeze(1) * data["token_mask"]), + global_valid_toks=torch.sum( + data["sample_mask"].unsqueeze(1) * data["token_mask"] + ), ) torch.testing.assert_close(actual_loss, expected_loss) @@ -1300,7 +1322,7 @@ def test_clipped_pg_loss_gspo_importance_sampling_correction(): "disable_ppo_ratio": False, "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": True, - "use_sequence_level_importance_ratios": True, + "sequence_level_importance_ratios": True, "token_level_loss": False, } loss_fn = ClippedPGLossFn(cfg) @@ -1325,53 +1347,53 @@ def test_clipped_pg_loss_gspo_importance_sampling_correction(): # --- Hand Calculation --- # Actor Loss Calculation actor_importance_weights = torch.exp( - prev_lp_masked - gen_lp_masked + (prev_lp_masked - gen_lp_masked).sum(dim=-1).unsqueeze(-1) ) # exp([-1 - (-0.5), -1 - (-1.5), -1 - (-0.8)]) = [0.6065, 1.6487, 0.8187] assert torch.allclose( actor_importance_weights, - torch.tensor([[0.6065, 1.6487, 0.8187]], device=device), + torch.tensor([[0.8187]], device=device), rtol=1e-3, ) - ratios = torch.exp(curr_lp_masked - prev_lp_masked) # [0.5, 1.0, 1.5] + log_ratios = curr_lp_masked - prev_lp_masked + seq_log_ratios_mean = torch.mean(log_ratios, dim=-1).unsqueeze(-1) + ratios = seq_log_ratios_mean.exp().repeat(1, 3) assert torch.allclose( - ratios, torch.tensor([[0.5, 1.0, 1.5]], device=device), rtol=1e-3 + ratios, torch.tensor([[0.9086, 0.9086, 0.9086]], device=device), rtol=1e-3 ) - ratios_clamped = torch.clamp( - ratios, 1.0 - ratio_clip, 1.0 + ratio_clip - ) # [0.8, 1.0, 1.2] + ratios_clamped = torch.clamp(ratios, 1.0 - ratio_clip, 1.0 + ratio_clip) assert torch.allclose( - ratios_clamped, torch.tensor([[0.8, 1.0, 1.2]], device=device), rtol=1e-3 + ratios_clamped, + torch.tensor([[0.9086, 0.9086, 0.9086]], device=device), + rtol=1e-3, ) - loss1 = -adv_masked * ratios # [-0.5, 1.0, -3.0] + loss1 = -adv_masked * ratios assert torch.allclose( - loss1, torch.tensor([[-0.5, 1.0, -3.0]], device=device), rtol=1e-3 + loss1, torch.tensor([[-0.9086, 0.9086, -1.8171]], device=device), rtol=1e-3 ) - loss2 = -adv_masked * ratios_clamped # [-0.8, 1.0, -2.4] + loss2 = -adv_masked * ratios_clamped assert torch.allclose( - loss2, torch.tensor([[-0.8, 1.0, -2.4]], device=device), rtol=1e-3 + loss2, torch.tensor([[-0.9086, 0.9086, -1.8171]], device=device), rtol=1e-3 ) - max_loss = torch.maximum(loss1, loss2) # [-0.5, 1.0, -2.4] + max_loss = torch.maximum(loss1, loss2) assert torch.allclose( - max_loss, torch.tensor([[-0.5, 1.0, -2.4]], device=device), rtol=1e-3 + max_loss, torch.tensor([[-0.9086, 0.9086, -1.8171]], device=device), rtol=1e-3 ) - importance_weighted_max_loss = ( - actor_importance_weights * max_loss - ) # [0.6065*(-0.5), 1.6487*1.0, 0.8187*(-2.4)] = [-0.30325, 1.6487, -1.96488] + importance_weighted_max_loss = actor_importance_weights * max_loss assert torch.allclose( importance_weighted_max_loss, - torch.tensor([[-0.30325, 1.6487, -1.96488]], device=device), + torch.tensor([[-0.7439, 0.7439, -1.4877]], device=device), rtol=1e-3, ) - expected_actor_loss = torch.mean(importance_weighted_max_loss) # -0.2065 + expected_actor_loss = torch.mean(importance_weighted_max_loss) assert torch.allclose( - expected_actor_loss, torch.tensor(-0.2065, device=device), rtol=1e-3 + expected_actor_loss, torch.tensor(-0.4959, device=device), rtol=1e-3 ) input_ids = data["input_ids"] From c69ae82ae353871addd65e7d05b2f4e912ce6c25 Mon Sep 17 00:00:00 2001 From: Peter Jin Date: Wed, 20 Aug 2025 17:33:41 -0700 Subject: [PATCH 19/19] Loop var name. Format. Signed-off-by: Peter Jin --- tests/unit/algorithms/test_loss_functions.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index d36324616e..764a849ed3 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -389,27 +389,29 @@ def _create_exact_logits( # We need to set logits for indices i=0..S-2 of the sliced logits tensor. # These correspond to target logprobs at indices 0..S-2 of target_curr_lp_masked. num_effective_pos = target_curr_lp_masked.shape[1] - for k, i in itertools.product(range(batch_size), range(num_effective_pos)): + for batch_idx, i in itertools.product(range(batch_size), range(num_effective_pos)): logit_idx = i # Index in the sliced logits tensor (dummy_logits[:, 0:S-1, :]) data_idx = i + 1 # Index in the original input_ids to find the target token - target_token_id = input_ids[k, data_idx].item() + target_token_id = input_ids[batch_idx, data_idx].item() # Keep target_lp as a 0-dim tensor for torch ops - target_lp = target_curr_lp_masked[k, i] + target_lp = target_curr_lp_masked[batch_idx, i] # Handle target_lp = 0 case separately if torch.isclose(target_lp, torch.tensor(0.0, device=device)): - dummy_logits[k, logit_idx, target_token_id] = 100.0 # Large positive logit + dummy_logits[batch_idx, logit_idx, target_token_id] = ( + 100.0 # Large positive logit + ) elif target_lp < 0: # Set target token logit to 0 - dummy_logits[k, logit_idx, target_token_id] = 0.0 + dummy_logits[batch_idx, logit_idx, target_token_id] = 0.0 # Set one distractor token logit using the formula distractor_token_id = (target_token_id + 1) % vocab_size # Ensure distractor isn't same as target if vocab_size=1 (edge case) if distractor_token_id == target_token_id: distractor_token_id = (target_token_id + 2) % vocab_size distractor_logit = torch.log(torch.exp(-target_lp) - 1.0) - dummy_logits[k, logit_idx, distractor_token_id] = distractor_logit + dummy_logits[batch_idx, logit_idx, distractor_token_id] = distractor_logit else: # target_lp > 0 is not supported by this method raise ValueError( "Target log probability must be negative or zero for this construction"