diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 66c71105ff..5b49db2b38 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -20,6 +20,7 @@ loss_fn: # (default off) loss formulation improvements (docs/guides/grpo.md#loss) use_on_policy_kl_approximation: false use_importance_sampling_correction: false + sequence_level_importance_ratios: false token_level_loss: true checkpointing: diff --git a/examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml b/examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml new file mode 100644 index 0000000000..35bd140c1a --- /dev/null +++ b/examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml @@ -0,0 +1,146 @@ +# GRPO Algorithm Configuration +grpo: + num_prompts_per_step: 128 + num_generations_per_prompt: 8 + max_rollout_turns: 1 # for multi-turn rollouts. Math Environments just have 1 turn (answering the question) + max_num_steps: 1000000 + normalize_rewards: true + use_leave_one_out_baseline: true + val_period: 10 + val_at_start: false + max_val_samples: 480 + val_batch_size: 32 + seed: 42 + +loss_fn: + reference_policy_kl_penalty: 0.0 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 + ratio_clip_c: null + # (default off) loss formulation improvements (docs/guides/grpo.md#loss) + use_on_policy_kl_approximation: false + use_importance_sampling_correction: false + sequence_level_importance_ratios: true + token_level_loss: false + +checkpointing: + enabled: true + checkpoint_dir: "results/grpo" + metric_name: "val_reward" + higher_is_better: true + keep_top_k: 10 + save_period: 10 + checkpoint_must_save_by: null + +policy: + model_name: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + tokenizer: + name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default + train_global_batch_size: 64 + train_micro_batch_size: 1 + generation_batch_size: 32 # Only used when generating using HF backend + 
logprob_batch_size: 4 + max_total_sequence_length: 8192 + precision: "bfloat16" + + dtensor_cfg: + enabled: true + cpu_offload: true + sequence_parallel: true + activation_checkpointing: true + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + + dynamic_batching: + enabled: False + + sequence_packing: + enabled: False + + # makes the training sequence length divisible by the tensor parallel size + # this is useful for sequence parallel training + make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} + max_grad_norm: 1.0 + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 2.0e-6 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + # when using Dtensor, we need to set foreach + # and fused to False + foreach: False + fused: False + + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.1 + end_factor: 1.0 + total_iters: 50 + - name: "torch.optim.lr_scheduler.ConstantLR" + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: [50] + + generation: + backend: "vllm" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: false + precision: ${policy.precision} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: ${policy.max_total_sequence_length} + enforce_eager: True + colocated: + # true: generation shares training GPUs + # false: uses dedicated generation resources + enabled: true + # only relevant when enabled is false + resources: + gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1 + num_nodes: null # Decides number of nodes to be dedicated to generation + +data: + max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len + prompt_file: 
"examples/prompts/cot.txt" + system_prompt_file: null + dataset_name: "DeepScaler" + shuffle: true + +env: + math: + num_workers: 16 + +logger: + log_dir: "logs" # Base directory for all logs + num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal + wandb_enabled: false + tensorboard_enabled: false + mlflow_enabled: false + monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard + wandb: + project: "grpo-dev" + name: "grpo-dev-logger" + tensorboard: {} + mlflow: + experiment_name: "grpo-dev" + run_name: "grpo-dev-logger" + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + +cluster: + gpus_per_node: 8 + num_nodes: 1 diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 5f47ef3087..eb907a61de 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, TypedDict, TypeVar +from typing import Any, NotRequired, Optional, TypedDict, TypeVar import torch @@ -37,6 +37,11 @@ class ClippedPGLossConfig(TypedDict): use_on_policy_kl_approximation: bool use_importance_sampling_correction: bool token_level_loss: bool + # If True, apply the off-policy importance-sampling correction at the + # sequence level (one weight per generated sample), as in GSPO. + # If False (default), correction is applied at the token level as in the + # original GRPO paper. 
+ sequence_level_importance_ratios: NotRequired[bool] class ClippedPGLossDataDict(TypedDict): @@ -60,6 +65,7 @@ class ClippedPGLossFn(LossFunction): - PPO (Clipped) - https://arxiv.org/abs/1707.06347 - GRPO - https://arxiv.org/abs/2402.03300 - REINFORCE/RLOO (set disable_ppo_ratio = True and ignores ratio_clip_min/ratio_clip_max) - https://arxiv.org/abs/2402.14740 + - GSPO (set sequence_level_importance_ratios = True and token_level_loss = False) - https://arxiv.org/abs/2507.18071 Formula: L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref) @@ -101,10 +107,18 @@ def __init__(self, cfg: ClippedPGLossConfig): self.use_importance_sampling_correction = cfg[ "use_importance_sampling_correction" ] - + # Whether to compute importance weights per-sequence instead of per-token. + self.sequence_level_importance_ratios = cfg.get( + "sequence_level_importance_ratios", + False, + ) self.loss_type = ( LossType.TOKEN_LEVEL if cfg["token_level_loss"] else LossType.SEQUENCE_LEVEL ) + if self.sequence_level_importance_ratios: + assert self.loss_type == LossType.SEQUENCE_LEVEL, ( + "sequence-level importance sampling (e.g. GSPO) is mutually exclusive with token-level loss" + ) def __call__( self, @@ -204,7 +218,17 @@ def __call__( # Calculate clipped loss function if ppo ratio is enabled. 
if not self.disable_ppo_ratio: - ratios = (curr_logprobs - prev_logprobs).exp() + log_ratios = curr_logprobs - prev_logprobs + if self.sequence_level_importance_ratios: + seq_log_ratio_mean = masked_mean( + log_ratios, + token_mask, + dim=-1, + ).unsqueeze(-1) + seq_ratio = seq_log_ratio_mean.exp() + ratios = seq_ratio.repeat(1, advantages.shape[1]) + else: + ratios = log_ratios.exp() ratios_clamped = ratios.clamp( 1.0 - self.ratio_clip_min, 1.0 + self.ratio_clip_max ) @@ -228,11 +252,29 @@ def __call__( advantages < 0, torch.min(clip_loss, loss3), clip_loss ) + # ------------------------------------------------------------- + # Off-policy (actor) importance-sampling correction + # ------------------------------------------------------------- # See: docs/guides/grpo.md#importance-sampling-correction - actor_importance_weights = torch.exp(prev_logprobs - generation_logprobs) - actor_importance_weights = torch.nan_to_num( - actor_importance_weights, nan=0.0, posinf=0.0, neginf=0.0 - ) + if self.sequence_level_importance_ratios: + # importance weight w_i = exp(Σ_t (log π_actor − log π_behaviour)) + seq_lp_diff = ((prev_logprobs - generation_logprobs) * mask).sum(dim=-1) + actor_importance_weights = torch.exp(seq_lp_diff).detach() + actor_importance_weights = torch.nan_to_num( + actor_importance_weights, nan=0.0, posinf=0.0, neginf=0.0 + ) + # Broadcast to token dimension so we can reuse existing reduction + actor_importance_weights_expanded = actor_importance_weights.unsqueeze(-1) + else: + # Token-level correction + actor_importance_weights_expanded = torch.exp( + prev_logprobs - generation_logprobs + ) + actor_importance_weights_expanded = torch.nan_to_num( + actor_importance_weights_expanded, nan=0.0, posinf=0.0, neginf=0.0 + ) + actor_importance_weights = actor_importance_weights_expanded + del actor_importance_weights_expanded if self.use_importance_sampling_correction: importance_weights_to_use = actor_importance_weights else: @@ -255,12 +297,20 @@ def __call__( 
global_normalization_factor=global_valid_seqs, ) + # Metric: sampling importance ratio (mean over samples) # See: docs/guides/grpo.md#sampling-importance-ratio - sample_importance_ratio = masked_mean( - actor_importance_weights, - mask, - global_normalization_factor=global_valid_toks, - ) + if self.sequence_level_importance_ratios: + sample_importance_ratio = masked_mean( + actor_importance_weights, + sample_mask, + global_normalization_factor=global_valid_seqs, + ) + else: + sample_importance_ratio = masked_mean( + actor_importance_weights, + mask, + global_normalization_factor=global_valid_toks, + ) # Approximating entropy as E_{s ~ \pi_{gen}(s)}[-(\pi_{curr}/\pi_{gen})log(\pi_{curr}(s))] # See more details and other metrics in docs/guides/grpo.md#metrics diff --git a/tests/test_suites/llm/grpo-gspo-deepscaler-1.5b-8K.sh b/tests/test_suites/llm/grpo-gspo-deepscaler-1.5b-8K.sh new file mode 100755 index 0000000000..ce2adb1c51 --- /dev/null +++ b/tests/test_suites/llm/grpo-gspo-deepscaler-1.5b-8K.sh @@ -0,0 +1,67 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=40 +MAX_STEPS=40 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=240 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo_math.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] 
| select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/token_mult_prob_error"]) < 1.1' \ + "data['train/token_mult_prob_error']['$MAX_STEPS'] < 1.1" +fi + +# TODO: enable in subsequent PR to do a quick accuracy check +## Convert 8k checkpoint +#uv run examples/converters/convert_dcp_to_hf.py \ +# --config=$CKPT_DIR/step_${MAX_STEPS}/config.yaml \ +# --dcp-ckpt-path=$CKPT_DIR/step_${MAX_STEPS}/policy/weights \ +# --hf-ckpt-path=$CKPT_DIR/gspo-deepscaler-8k-${MAX_STEPS}-hf +# +## Run eval +#uv run examples/run_eval.py \ +# generation.model_name=$CKPT_DIR/gspo-deepscaler-8k-${MAX_STEPS}-hf \ +# data.prompt_file=examples/prompts/cot.txt \ +# generation.vllm_cfg.max_model_len=32768 2>&1 | tee ${RUN_LOG}.aime-8k +# +#cat ${RUN_LOG}.aime-8k | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"score": \1}/' > ${RUN_LOG}-8k-metric.json +# +#uv run tests/check_metrics.py ${RUN_LOG}-8k-metric.json \ +# 'data["score"] >= 0.25' \ +# +##uv run examples/run_eval.py \ +## generation.model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ +## data.prompt_file=examples/prompts/cot.txt \ +## generation.vllm_cfg.max_model_len=32768 2>&1 | tee ${RUN_LOG}.aime-baseline +# +##cat ${RUN_LOG}.aime-baseline | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"score": \1}/' > ${RUN_LOG}-baseline-metric.json +# +##uv run tests/check_metrics.py ${RUN_LOG}-baseline-metric.json \ +## 'data["score"] == 0.2' \ diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index b38557f2e8..300d6a2755 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -12,7 +12,6 @@ tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.sh # Megatron tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.sh -tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-megatron.sh # Functional 32b run 
tests/test_suites/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.sh @@ -29,6 +28,9 @@ tests/test_suites/llm/grpo-deepscaler-1.5b-16K.sh tests/test_suites/llm/grpo-deepscaler-1.5b-24K.sh tests/test_suites/llm/grpo-deepscaler-1.5b-8K.sh +# Deepscaler (GSPO) +tests/test_suites/llm/grpo-gspo-deepscaler-1.5b-8K.sh + # GRPO math test run (32K context mcore) tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh diff --git a/tests/test_suites/release.txt b/tests/test_suites/release.txt index bd117a83f8..b9db763409 100644 --- a/tests/test_suites/release.txt +++ b/tests/test_suites/release.txt @@ -2,6 +2,9 @@ # GRPO # ######## +# Megatron (Qwen/Qwen2.5-7B-Instruct) +tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-megatron.sh + # Long 8b run tests/test_suites/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.sh diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 65ed834625..764a849ed3 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import itertools + import pytest import torch @@ -371,41 +373,45 @@ def _setup_clipped_pg_test_data(batch_size=1, seq_len=4, vocab_size=8, device="c } ) # Return seq_len and vocab_size needed by tests - return data, seq_len, vocab_size + return data, batch_size, seq_len, vocab_size # Helper to create logits that yield specific target log probs after log_softmax -def _create_exact_logits(target_curr_lp_masked, input_ids, seq_len, vocab_size, device): +def _create_exact_logits( + target_curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device +): """Constructs logits such that log_softmax results in target_curr_lp_masked.""" dummy_logits = torch.full( - (1, seq_len, vocab_size), -100.0, device=device + (batch_size, seq_len, vocab_size), -100.0, device=device ) # Start very low # Loss fn uses logits[:, :-1] and gathers based on next_tokens = input_ids[:, 1:] # We need to set logits for indices i=0..S-2 of the sliced logits tensor. # These correspond to target logprobs at indices 0..S-2 of target_curr_lp_masked. 
num_effective_pos = target_curr_lp_masked.shape[1] - for i in range(num_effective_pos): + for batch_idx, i in itertools.product(range(batch_size), range(num_effective_pos)): logit_idx = i # Index in the sliced logits tensor (dummy_logits[:, 0:S-1, :]) data_idx = i + 1 # Index in the original input_ids to find the target token - target_token_id = input_ids[0, data_idx].item() + target_token_id = input_ids[batch_idx, data_idx].item() # Keep target_lp as a 0-dim tensor for torch ops - target_lp = target_curr_lp_masked[0, i] + target_lp = target_curr_lp_masked[batch_idx, i] # Handle target_lp = 0 case separately if torch.isclose(target_lp, torch.tensor(0.0, device=device)): - dummy_logits[0, logit_idx, target_token_id] = 100.0 # Large positive logit + dummy_logits[batch_idx, logit_idx, target_token_id] = ( + 100.0 # Large positive logit + ) elif target_lp < 0: # Set target token logit to 0 - dummy_logits[0, logit_idx, target_token_id] = 0.0 + dummy_logits[batch_idx, logit_idx, target_token_id] = 0.0 # Set one distractor token logit using the formula distractor_token_id = (target_token_id + 1) % vocab_size # Ensure distractor isn't same as target if vocab_size=1 (edge case) if distractor_token_id == target_token_id: distractor_token_id = (target_token_id + 2) % vocab_size distractor_logit = torch.log(torch.exp(-target_lp) - 1.0) - dummy_logits[0, logit_idx, distractor_token_id] = distractor_logit + dummy_logits[batch_idx, logit_idx, distractor_token_id] = distractor_logit else: # target_lp > 0 is not supported by this method raise ValueError( "Target log probability must be negative or zero for this construction" @@ -420,7 +426,7 @@ def test_clipped_pg_loss_ppo_clipping(): pytest.skip("No GPU available") device = "cuda" - data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) ratio_clip = 0.2 cfg = { @@ -486,7 +492,7 @@ def test_clipped_pg_loss_ppo_clipping(): input_ids 
= data["input_ids"] dummy_logits = _create_exact_logits( - curr_lp_masked, input_ids, seq_len, vocab_size, device + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) actual_loss, _ = loss_fn( @@ -505,7 +511,7 @@ def test_clipped_pg_loss_reinforce_mode(): pytest.skip("No GPU available") device = "cuda" - data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) cfg = { "disable_ppo_ratio": True, @@ -539,7 +545,7 @@ def test_clipped_pg_loss_reinforce_mode(): input_ids = data["input_ids"] dummy_logits = _create_exact_logits( - curr_lp_masked, input_ids, seq_len, vocab_size, device + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) actual_loss, _ = loss_fn( @@ -560,7 +566,7 @@ def test_clipped_pg_loss_kl_penalty(): pytest.skip("No GPU available") device = "cuda" - data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) # --- Test Setup --- kl_beta = 0.1 @@ -607,7 +613,7 @@ def test_clipped_pg_loss_kl_penalty(): input_ids = data["input_ids"] dummy_logits = _create_exact_logits( - curr_lp_masked, input_ids, seq_len, vocab_size, device + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) actual_loss, _ = loss_fn( @@ -632,7 +638,7 @@ def test_clipped_pg_loss_masking(): device = "cuda" # Use original loss function for masking tests, as it involves interactions # that the Testable class might obscure slightly. 
- data, seq_len, vocab_size = _setup_clipped_pg_test_data( + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data( batch_size=batch_size, seq_len=seq_len, device=device ) # Need some realistic-ish logits and logprobs for masking test @@ -734,7 +740,7 @@ def test_clipped_pg_loss_zero_mask(): pytest.skip("No GPU available") device = "cuda" - data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) # Need dummy logits dummy_logits = torch.randn(1, seq_len, vocab_size, device=device) @@ -772,7 +778,7 @@ def test_clipped_pg_loss_on_policy_kl_importance_sampling(): pytest.skip("No GPU available") device = "cuda" - data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) ratio_clip = 0.2 kl_beta = 0.1 @@ -905,7 +911,7 @@ def test_clipped_pg_loss_on_policy_kl_importance_sampling(): input_ids = data["input_ids"] dummy_logits = _create_exact_logits( - curr_lp_masked, input_ids, seq_len, vocab_size, device + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) actual_loss, _ = loss_fn( @@ -954,7 +960,7 @@ def test_clipped_pg_loss_dual_clip(): pytest.skip("No GPU available") device = "cuda" - data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) ratio_clip = 0.2 ratio_clip_c = 3.0 @@ -1034,7 +1040,7 @@ def test_clipped_pg_loss_dual_clip(): input_ids = data["input_ids"] dummy_logits = _create_exact_logits( - curr_lp_masked, input_ids, seq_len, vocab_size, device + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) actual_loss, _ = loss_fn( @@ -1054,7 +1060,7 @@ def test_clipped_pg_loss_entropy(): pytest.skip("No GPU available") device = "cuda" - data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + data, batch_size, 
seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) cfg = { "ratio_clip_min": 0.2, @@ -1092,7 +1098,7 @@ def test_clipped_pg_loss_entropy(): ) # torch.mean because default mask applies to these 3 terms dummy_logits = _create_exact_logits( - curr_lp_masked, data["input_ids"], seq_len, vocab_size, device + curr_lp_masked, data["input_ids"], batch_size, seq_len, vocab_size, device ) _, metrics = loss_fn( dummy_logits, @@ -1107,3 +1113,300 @@ def test_clipped_pg_loss_entropy(): rtol=1e-3, atol=1e-5, ) + + +def test_clipped_pg_loss_gspo(): + """Tests GSPO path in ClippedPGLossFn.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + ratio_clip = 0.2 + cfg = { + "ratio_clip_min": ratio_clip, + "ratio_clip_max": ratio_clip, + "ratio_clip_c": None, + "reference_policy_kl_penalty": 0.0, # Disable KL + "disable_ppo_ratio": False, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, + "sequence_level_importance_ratios": True, + "token_level_loss": False, + } + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) + # Use non-zero prev_lp to allow ratios > 1 with valid curr_lp <= 0 + prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + # Target Curr logprobs (masked pos 1, 2, 3) - design for clipping + # Target ratios: 0.5 (<0.8), 1.0 (in [0.8, 1.2]), 1.5 (>1.2) + # Curr = log(Ratio) + Prev + curr_lp_masked = torch.tensor( + [[-1.69315, -1.0, -0.59453]], device=device + ) # approx log(0.5)-1, log(1)-1, log(1.5)-1 + + # Fill full tensors (only need first dim for B=1) + data["advantages"][0, 1:] = adv_masked + data["prev_logprobs"][0, 1:] = prev_lp_masked + + # --- Hand Calculation --- + log_ratios = curr_lp_masked - prev_lp_masked + seq_log_ratios_mean = torch.mean(log_ratios, dim=-1).unsqueeze(-1) + ratios = seq_log_ratios_mean.exp().repeat(1, 
3) + assert torch.allclose( + ratios, torch.tensor([[0.9086, 0.9086, 0.9086]], device=device), rtol=1e-3 + ) + + ratios_clamped = torch.clamp(ratios, 1.0 - ratio_clip, 1.0 + ratio_clip) + assert torch.allclose( + ratios_clamped, + torch.tensor([[0.9086, 0.9086, 0.9086]], device=device), + rtol=1e-3, + ) + + loss1 = -adv_masked * ratios + assert torch.allclose( + loss1, torch.tensor([[-0.9086, 0.9086, -1.8171]], device=device), rtol=1e-3 + ) + + loss2 = -adv_masked * ratios_clamped + assert torch.allclose( + loss2, torch.tensor([[-0.9086, 0.9086, -1.8171]], device=device), rtol=1e-3 + ) + + max_loss = torch.maximum(loss1, loss2) + assert torch.allclose( + max_loss, torch.tensor([[-0.9086, 0.9086, -1.8171]], device=device), rtol=1e-3 + ) + + expected_loss = torch.mean(max_loss) + assert torch.allclose( + expected_loss, torch.tensor(-0.6057, device=device), rtol=1e-3 + ) + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device + ) + + actual_loss, _ = loss_fn( + dummy_logits, + data, + global_valid_seqs=torch.sum(data["sample_mask"]), + global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + ) + torch.testing.assert_close(actual_loss, expected_loss) + + +def test_clipped_pg_loss_gspo_batch_size_2(): + """Tests non-unit batch size GSPO path in ClippedPGLossFn.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data( + batch_size=2, device=device + ) + + ratio_clip = 0.2 + cfg = { + "ratio_clip_min": ratio_clip, + "ratio_clip_max": ratio_clip, + "ratio_clip_c": None, + "reference_policy_kl_penalty": 0.0, # Disable KL + "disable_ppo_ratio": False, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, + "sequence_level_importance_ratios": True, + "token_level_loss": False, + } + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = 
torch.tensor([[1.0, -1.0, 2.0], [1.0, -1.0, 2.0]], device=device) +    # Use non-zero prev_lp to allow ratios > 1 with valid curr_lp <= 0 +    prev_lp_masked = torch.tensor( +        [[-1.0, -1.0, -1.0], [-2.0, -2.0, -2.0]], device=device +    ) +    # Target Curr logprobs (masked pos 1, 2, 3) - design for clipping +    # Target ratios: 0.5 (<0.8), 1.0 (in [0.8, 1.2]), 1.5 (>1.2) +    # Curr = log(Ratio) + Prev +    curr_lp_masked = torch.tensor( +        [[-1.69315, -1.0, -0.59453], [-1.69315, -1.0, -0.59453]], device=device +    )  # approx log(0.5)-1, log(1)-1, log(1.5)-1 + +    # Fill full tensors for both batch rows (B=2) +    data["advantages"][:, 1:] = adv_masked +    data["prev_logprobs"][:, 1:] = prev_lp_masked + +    # --- Hand Calculation --- +    log_ratios = curr_lp_masked - prev_lp_masked +    seq_log_ratios_mean = torch.mean(log_ratios, dim=-1).unsqueeze(-1) +    ratios = seq_log_ratios_mean.exp().repeat(1, 3) +    assert torch.allclose( +        ratios, +        torch.tensor( +            [[0.9086, 0.9086, 0.9086], [2.4697, 2.4697, 2.4697]], device=device +        ), +        rtol=1e-3, +    ) + +    ratios_clamped = torch.clamp(ratios, 1.0 - ratio_clip, 1.0 + ratio_clip) +    assert torch.allclose( +        ratios_clamped, +        torch.tensor([[0.9086, 0.9086, 0.9086], [1.2, 1.2, 1.2]], device=device), +        rtol=1e-3, +    ) + +    loss1 = -adv_masked * ratios +    assert torch.allclose( +        loss1, +        torch.tensor( +            [[-0.9086, 0.9086, -1.8171], [-2.4697, 2.4697, -4.9394]], device=device +        ), +        rtol=1e-3, +    ) + +    loss2 = -adv_masked * ratios_clamped +    assert torch.allclose( +        loss2, +        torch.tensor( +            [[-0.9086, 0.9086, -1.8171], [-1.2000, 1.2000, -2.4000]], device=device +        ), +        rtol=1e-3, +    ) + +    max_loss = torch.maximum(loss1, loss2) +    assert torch.allclose( +        max_loss, +        torch.tensor( +            [[-0.9086, 0.9086, -1.8171], [-1.2000, 2.4697, -2.4000]], device=device +        ), +        rtol=1e-3, +    ) + +    expected_loss = torch.mean(max_loss) +    assert torch.allclose( +        expected_loss, torch.tensor(-0.4912, device=device), rtol=1e-3 +    ) + +    input_ids = data["input_ids"] +    dummy_logits =
_create_exact_logits( +        curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device +    ) + +    actual_loss, _ = loss_fn( +        dummy_logits, +        data, +        global_valid_seqs=torch.sum(data["sample_mask"]), +        global_valid_toks=torch.sum( +            data["sample_mask"].unsqueeze(1) * data["token_mask"] +        ), +    ) +    torch.testing.assert_close(actual_loss, expected_loss) + + +def test_clipped_pg_loss_gspo_importance_sampling_correction(): +    """Tests GSPO w/ importance sampling correction in ClippedPGLossFn.""" +    if not torch.cuda.is_available(): +        pytest.skip("No GPU available") + +    device = "cuda" +    data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + +    ratio_clip = 0.2 +    cfg = { +        "ratio_clip_min": ratio_clip, +        "ratio_clip_max": ratio_clip, +        "ratio_clip_c": None, +        "reference_policy_kl_penalty": 0.0,  # Disable KL +        "disable_ppo_ratio": False, +        "use_on_policy_kl_approximation": False, +        "use_importance_sampling_correction": True, +        "sequence_level_importance_ratios": True, +        "token_level_loss": False, +    } +    loss_fn = ClippedPGLossFn(cfg) + +    adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) +    prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) +    curr_lp_masked = torch.tensor( +        [[-1.69315, -1.0, -0.59453]], device=device +    )  # approx log(0.5)-1, log(1)-1, log(1.5)-1 + +    ref_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + +    # For Importance Sampling +    gen_lp_masked = torch.tensor([[-0.5, -1.5, -0.8]], device=device) + +    # Fill full tensors +    data["advantages"][0, 1:] = adv_masked +    data["prev_logprobs"][0, 1:] = prev_lp_masked +    data["generation_logprobs"][0, 1:] = gen_lp_masked +    data["reference_policy_logprobs"][0, 1:] = ref_lp_masked + +    # --- Hand Calculation --- +    # Actor Loss Calculation +    actor_importance_weights = torch.exp( +        (prev_lp_masked - gen_lp_masked).sum(dim=-1).unsqueeze(-1) +    )  # exp(sum([-0.5, 0.5, -0.2])) = exp(-0.2) = 0.8187 (sequence-level weight) +    assert torch.allclose( +
actor_importance_weights, + torch.tensor([[0.8187]], device=device), + rtol=1e-3, + ) + + log_ratios = curr_lp_masked - prev_lp_masked + seq_log_ratios_mean = torch.mean(log_ratios, dim=-1).unsqueeze(-1) + ratios = seq_log_ratios_mean.exp().repeat(1, 3) + assert torch.allclose( + ratios, torch.tensor([[0.9086, 0.9086, 0.9086]], device=device), rtol=1e-3 + ) + + ratios_clamped = torch.clamp(ratios, 1.0 - ratio_clip, 1.0 + ratio_clip) + assert torch.allclose( + ratios_clamped, + torch.tensor([[0.9086, 0.9086, 0.9086]], device=device), + rtol=1e-3, + ) + + loss1 = -adv_masked * ratios + assert torch.allclose( + loss1, torch.tensor([[-0.9086, 0.9086, -1.8171]], device=device), rtol=1e-3 + ) + + loss2 = -adv_masked * ratios_clamped + assert torch.allclose( + loss2, torch.tensor([[-0.9086, 0.9086, -1.8171]], device=device), rtol=1e-3 + ) + + max_loss = torch.maximum(loss1, loss2) + assert torch.allclose( + max_loss, torch.tensor([[-0.9086, 0.9086, -1.8171]], device=device), rtol=1e-3 + ) + + importance_weighted_max_loss = actor_importance_weights * max_loss + assert torch.allclose( + importance_weighted_max_loss, + torch.tensor([[-0.7439, 0.7439, -1.4877]], device=device), + rtol=1e-3, + ) + + expected_actor_loss = torch.mean(importance_weighted_max_loss) + assert torch.allclose( + expected_actor_loss, torch.tensor(-0.4959, device=device), rtol=1e-3 + ) + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device + ) + + actual_loss, _ = loss_fn( + dummy_logits, + data, + global_valid_seqs=torch.sum(data["sample_mask"]), + global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + ) + torch.testing.assert_close(actual_loss, expected_actor_loss, atol=1e-4, rtol=1e-3)