diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 370f70864d..2e24c3587e 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -43,6 +43,7 @@ loss_fn: # Async GRPO requires importance sampling correction enabled # Set to true when async_grpo.enabled is true use_importance_sampling_correction: false + truncated_importance_sampling_ratio: null sequence_level_importance_ratios: false token_level_loss: true diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml index beec9ab870..6489b15f15 100644 --- a/examples/configs/vlm_grpo_3B.yaml +++ b/examples/configs/vlm_grpo_3B.yaml @@ -39,6 +39,7 @@ loss_fn: # (default off) loss formulation improvements (docs/guides/grpo.md#loss) use_on_policy_kl_approximation: false use_importance_sampling_correction: false + truncated_importance_sampling_ratio: null token_level_loss: true checkpointing: diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml index 0c033319e5..f06da39f00 100644 --- a/examples/configs/vlm_grpo_3B_megatron.yaml +++ b/examples/configs/vlm_grpo_3B_megatron.yaml @@ -35,6 +35,7 @@ loss_fn: ratio_clip_c: null use_on_policy_kl_approximation: false use_importance_sampling_correction: false + truncated_importance_sampling_ratio: null token_level_loss: true checkpointing: enabled: true diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 2a3038ddbd..3e7f9baa94 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -42,6 +42,7 @@ class ClippedPGLossConfig(TypedDict): ratio_clip_c: float use_on_policy_kl_approximation: bool use_importance_sampling_correction: bool + truncated_importance_sampling_ratio: float | None token_level_loss: bool # If True, apply the off-policy importance-sampling correction at the # sequence level (one weight per generated sample), as in GSPO. @@ -113,6 +114,9 @@ def __init__(self, cfg: ClippedPGLossConfig): self.use_importance_sampling_correction = cfg[ "use_importance_sampling_correction" ] + self.truncated_importance_sampling_ratio = cfg[ + "truncated_importance_sampling_ratio" + ] # Whether to compute importance weights per-sequence instead of per-token. self.sequence_level_importance_ratios = cfg.get( "sequence_level_importance_ratios", @@ -125,6 +129,13 @@ def __init__(self, cfg: ClippedPGLossConfig): assert self.loss_type == LossType.SEQUENCE_LEVEL, ( "sequence-level importance sampling (e.g. GSPO) is mutually exclusive with token-level loss" ) + if self.truncated_importance_sampling_ratio is not None: + assert self.use_importance_sampling_correction, ( + "truncated_importance_sampling_ratio is only supported when use_importance_sampling_correction is True" + ) + assert self.truncated_importance_sampling_ratio > 0, ( + "truncated_importance_sampling_ratio should be positive" + ) def __call__( self, @@ -280,6 +291,12 @@ def __call__( actor_importance_weights_expanded = torch.nan_to_num( actor_importance_weights_expanded, nan=0.0, posinf=0.0, neginf=0.0 ) + # TIS see https://fengyao.notion.site/off-policy-rl + if self.truncated_importance_sampling_ratio is not None: + actor_importance_weights_expanded = torch.clamp( + actor_importance_weights_expanded, + max=self.truncated_importance_sampling_ratio, + ) actor_importance_weights = actor_importance_weights_expanded del actor_importance_weights_expanded if self.use_importance_sampling_correction: diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index 62f4153fdf..34ed2ef88f 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -889,6 +889,8 @@ def val_iter(self): "ratio_clip_c": 1.0, "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False, + "truncated_importance_sampling_ratio": None, + "sequence_level_importance_ratios": False, "token_level_loss": True, } ) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 3978331795..3f93f36442 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools +from copy import deepcopy import pytest import torch from nemo_rl.algorithms.loss_functions import ( + ClippedPGLossConfig, ClippedPGLossFn, DistillationLossFn, DPOLossFn, @@ -25,6 +27,19 @@ from nemo_rl.algorithms.utils import masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict +basic_pg_loss_test_config: ClippedPGLossConfig = { + "ratio_clip_min": 0.2, + "ratio_clip_max": 0.2, + "ratio_clip_c": None, + "disable_ppo_ratio": False, + "reference_policy_kl_penalty": 0.0, # Disable KL + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, + "truncated_importance_sampling_ratio": None, # Disable TIS + "sequence_level_importance_ratios": False, + "token_level_loss": True, +} + def setup_dpo_loss_test_data(vocab_size=16, batch_size=1): seq_len = 4 @@ -429,17 +444,7 @@ def test_clipped_pg_loss_ppo_clipping(): device = "cuda" data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) - ratio_clip = 0.2 - cfg = { - "ratio_clip_min": ratio_clip, - "ratio_clip_max": ratio_clip, - "ratio_clip_c": None, - "reference_policy_kl_penalty": 0.0, # Disable KL - "disable_ppo_ratio": False, - "use_on_policy_kl_approximation": False, - "use_importance_sampling_correction": False, - "token_level_loss": True, - } + cfg = basic_pg_loss_test_config loss_fn = ClippedPGLossFn(cfg) adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) @@ -463,7 +468,7 @@ def test_clipped_pg_loss_ppo_clipping(): ) ratios_clamped = torch.clamp( - ratios, 1.0 - ratio_clip, 1.0 + ratio_clip + ratios, 1.0 - cfg["ratio_clip_min"], 1.0 + cfg["ratio_clip_max"] ) # [0.8, 1.0, 1.2] assert torch.allclose( ratios_clamped, torch.tensor([[0.8, 1.0, 1.2]], device=device), rtol=1e-3 @@ -514,16 +519,10 @@ def test_clipped_pg_loss_reinforce_mode(): device = "cuda" data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) - cfg = { - "disable_ppo_ratio": True, - "reference_policy_kl_penalty": 0.0, - "ratio_clip_min": 0.0, # Placeholder, ignored - "ratio_clip_max": 0.0, # Placeholder, ignored - "ratio_clip_c": None, - "use_on_policy_kl_approximation": False, - "use_importance_sampling_correction": False, - "token_level_loss": True, - } + cfg = deepcopy(basic_pg_loss_test_config) + cfg["disable_ppo_ratio"] = True + cfg["ratio_clip_min"] = 0.0 + cfg["ratio_clip_max"] = 0.0 loss_fn = ClippedPGLossFn(cfg) adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) @@ -570,17 +569,8 @@ def test_clipped_pg_loss_kl_penalty(): data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) # --- Test Setup --- - kl_beta = 0.1 - cfg = { - "reference_policy_kl_penalty": kl_beta, - "ratio_clip_min": 0.2, - "ratio_clip_max": 0.2, - "ratio_clip_c": None, - "disable_ppo_ratio": False, - "use_on_policy_kl_approximation": False, - "use_importance_sampling_correction": False, - "token_level_loss": True, - } + cfg = deepcopy(basic_pg_loss_test_config) + cfg["reference_policy_kl_penalty"] = 0.1 loss_fn = ClippedPGLossFn(cfg) adv_masked = torch.tensor([[0.0, 0.0, 0.0]], device=device) @@ -609,7 +599,7 @@ def test_clipped_pg_loss_kl_penalty(): expected_kl_mean, torch.tensor(0.362, device=device), rtol=1e-3 ) - expected_loss = kl_beta * expected_kl_mean # 0.0362 + expected_loss = cfg["reference_policy_kl_penalty"] * expected_kl_mean # 0.0362 assert torch.allclose(expected_loss, torch.tensor(0.0362, device=device), rtol=1e-3) input_ids = data["input_ids"] @@ -652,16 +642,8 @@ def test_clipped_pg_loss_masking(): # Make advantages non-zero data["advantages"] = torch.randn_like(data["advantages"]) + 1.0 - cfg = { - "ratio_clip_min": 0.2, - "ratio_clip_max": 0.2, - "ratio_clip_c": None, - "reference_policy_kl_penalty": 0.1, - "disable_ppo_ratio": False, - "use_on_policy_kl_approximation": False, - "use_importance_sampling_correction": False, - "token_level_loss": True, - } + cfg = deepcopy(basic_pg_loss_test_config) + cfg["reference_policy_kl_penalty"] = 0.1 loss_fn = ClippedPGLossFn(cfg) # Use original loss fn # --- Test 1: Token Mask --- @@ -745,16 +727,8 @@ def test_clipped_pg_loss_zero_mask(): # Need dummy logits dummy_logits = torch.randn(1, seq_len, vocab_size, device=device) - cfg = { - "ratio_clip_min": 0.2, - "ratio_clip_max": 0.2, - "ratio_clip_c": None, - "reference_policy_kl_penalty": 0.1, - "disable_ppo_ratio": False, - "use_on_policy_kl_approximation": False, - "use_importance_sampling_correction": False, - "token_level_loss": True, - } + cfg = deepcopy(basic_pg_loss_test_config) + cfg["reference_policy_kl_penalty"] = 0.1 loss_fn = ClippedPGLossFn(cfg) # Use original loss fn # Set token mask to all zeros @@ -781,19 +755,9 @@ def test_clipped_pg_loss_on_policy_kl_importance_sampling(): device = "cuda" data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) - ratio_clip = 0.2 - kl_beta = 0.1 - - cfg = { - "ratio_clip_min": ratio_clip, - "ratio_clip_max": ratio_clip, - "ratio_clip_c": None, - "reference_policy_kl_penalty": kl_beta, - "disable_ppo_ratio": False, - "use_on_policy_kl_approximation": True, - "use_importance_sampling_correction": True, - "token_level_loss": True, - } + cfg = deepcopy(basic_pg_loss_test_config) + cfg["use_on_policy_kl_approximation"] = True + cfg["use_importance_sampling_correction"] = True loss_fn = ClippedPGLossFn(cfg) adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) @@ -830,7 +794,7 @@ def test_clipped_pg_loss_on_policy_kl_importance_sampling(): ) ratios_clamped = torch.clamp( - ratios, 1.0 - ratio_clip, 1.0 + ratio_clip + ratios, 1.0 - cfg["ratio_clip_min"], 1.0 + cfg["ratio_clip_max"] ) # [0.8, 1.0, 1.2] assert torch.allclose( ratios_clamped, torch.tensor([[0.8, 1.0, 1.2]], device=device), rtol=1e-3 @@ -904,7 +868,9 @@ def test_clipped_pg_loss_on_policy_kl_importance_sampling(): expected_kl_mean = torch.mean( importance_weighted_kl_term_per_token ) # mean([0.09308, 0.0, 0.08855]) = 0.060543 - expected_kl_loss = kl_beta * expected_kl_mean # 0.1 * 0.060543 = 0.0060543 + expected_kl_loss = ( + cfg["reference_policy_kl_penalty"] * expected_kl_mean + ) # 0.1 * 0.060543 = 0.0060543 expected_total_loss = ( expected_actor_loss + expected_kl_loss @@ -924,6 +890,137 @@ def test_clipped_pg_loss_on_policy_kl_importance_sampling(): torch.testing.assert_close(actual_loss, expected_total_loss, atol=1e-4, rtol=1e-3) +@pytest.mark.parametrize("sequence_level_importance_ratios", [True, False]) +def test_clipped_pg_loss_on_policy_truncated_importance_sampling( + sequence_level_importance_ratios, +): + """Tests PPO loss with truncated importance sampling enabled.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + cfg = deepcopy(basic_pg_loss_test_config) + cfg["use_importance_sampling_correction"] = True + cfg["truncated_importance_sampling_ratio"] = 0.8 + if sequence_level_importance_ratios: + cfg["sequence_level_importance_ratios"] = True + cfg["token_level_loss"] = False + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) + prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + # approx log(0.5)-1, log(1)-1, log(1.5)-1 + curr_lp_masked = torch.tensor([[-1.69315, -1.0, -0.59453]], device=device) + ref_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + # for importance sampling + gen_lp_masked = torch.tensor([[-0.5, -1.5, -0.8]], device=device) + + # Fill full tensors + data["advantages"][0, 1:] = adv_masked + data["prev_logprobs"][0, 1:] = prev_lp_masked + data["generation_logprobs"][0, 1:] = gen_lp_masked + data["reference_policy_logprobs"][0, 1:] = ref_lp_masked + + # --- Hand Calculation --- + + # sequence-level: [[0.9086, 0.9086, 0.9086]] + # token-level: [[0.5, 1.0, 1.5]] + if sequence_level_importance_ratios: + log_ratios = curr_lp_masked - prev_lp_masked + seq_log_ratios_mean = torch.mean(log_ratios, dim=-1).unsqueeze(-1) + ratios = seq_log_ratios_mean.exp().repeat(1, adv_masked.shape[1]) + else: + ratios = torch.exp(curr_lp_masked - prev_lp_masked) + + # sequence-level: [[0.9086, 0.9086, 0.9086]] + # token-level: [[0.8, 1.0, 1.2]] + clip_min = cfg["ratio_clip_min"] + clip_max = cfg["ratio_clip_max"] + ratios_clamped = torch.clamp(ratios, 1.0 - clip_min, 1.0 + clip_max) + + # sequence-level: [[-0.9086, 0.9086, -1.8171]] + # token-level: [[-0.5, 1.0, -3.0]] + loss1 = -adv_masked * ratios + + # sequence-level: [[-0.9086, 0.9086, -1.8171]] + # token-level: [[-0.8, 1.0, -2.4]] + loss2 = -adv_masked * ratios_clamped + + # sequence-level: [[-0.9086, 0.9086, -1.8171]] + # token-level: [[-0.5, 1.0, -2.4]] + max_loss = torch.maximum(loss1, loss2) + if sequence_level_importance_ratios: + assert torch.allclose( + max_loss, + torch.tensor([[-0.9086, 0.9086, -1.8171]], device=device), + rtol=1e-3, + ) + else: + assert torch.allclose( + max_loss, + torch.tensor([[-0.5, 1.0, -2.4]], device=device), + rtol=1e-3, + ) + + # sequence-level: [[0.8187]] + # token-level: [[0.6065, 1.6487, 0.8187]] + if sequence_level_importance_ratios: + actor_importance_weights = torch.exp( + (prev_lp_masked - gen_lp_masked).sum(dim=-1).unsqueeze(-1) + ) + else: + actor_importance_weights = torch.exp(prev_lp_masked - gen_lp_masked) + + # sequence-level: [[0.8000]] + # token-level: [[0.6065, 0.8000, 0.8000]] + truncated_actor_importance_weights = torch.clamp( + actor_importance_weights, max=cfg["truncated_importance_sampling_ratio"] + ) + + # sequence-level: [[-0.7268, 0.7268, -1.4537]] + # token-level: [[-0.3033, 0.8000, -1.9200]] + importance_weighted_max_loss = truncated_actor_importance_weights * max_loss + if sequence_level_importance_ratios: + assert torch.allclose( + importance_weighted_max_loss, + torch.tensor([[-0.7268, 0.7268, -1.4537]], device=device), + rtol=1e-3, + ) + else: + assert torch.allclose( + importance_weighted_max_loss, + torch.tensor([[-0.3033, 0.8000, -1.9200]], device=device), + rtol=1e-3, + ) + + # sequence-level: -0.4846 + # token-level: -0.4744 + expected_loss = torch.mean(importance_weighted_max_loss) + if sequence_level_importance_ratios: + assert torch.allclose( + expected_loss, torch.tensor(-0.4846, device=device), rtol=1e-3 + ) + else: + assert torch.allclose( + expected_loss, torch.tensor(-0.4744, device=device), rtol=1e-3 + ) + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device + ) + + actual_loss, _ = loss_fn( + dummy_logits, + data, + global_valid_seqs=torch.sum(data["sample_mask"]), + global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + ) + torch.testing.assert_close(actual_loss, expected_loss, atol=1e-4, rtol=1e-3) + + def test_masked_mean_all_zeros(): """Test masked_mean function with all zeros mask.""" values = torch.tensor([1.0, 2.0, 3.0, 4.0]) @@ -963,18 +1060,8 @@ def test_clipped_pg_loss_dual_clip(): device = "cuda" data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) - ratio_clip = 0.2 - ratio_clip_c = 3.0 - cfg = { - "ratio_clip_min": ratio_clip, - "ratio_clip_max": ratio_clip, - "ratio_clip_c": ratio_clip_c, - "reference_policy_kl_penalty": 0.0, # Disable KL - "disable_ppo_ratio": False, - "use_on_policy_kl_approximation": False, - "use_importance_sampling_correction": False, - "token_level_loss": True, - } + cfg = deepcopy(basic_pg_loss_test_config) + cfg["ratio_clip_c"] = 3.0 loss_fn = ClippedPGLossFn(cfg) # Create test data with a mix of advantages: positive, slightly negative, strongly negative @@ -998,7 +1085,7 @@ def test_clipped_pg_loss_dual_clip(): # --- Hand Calculation --- # Actor Loss Calculation ratios_clamped = torch.clamp( - ratios, 1.0 - ratio_clip, 1.0 + ratio_clip + ratios, 1.0 - cfg["ratio_clip_min"], 1.0 + cfg["ratio_clip_max"] ) # [0.8, 1.0, 1.2] assert torch.allclose( ratios_clamped, torch.tensor([[0.8, 1.0, 1.2]], device=device), rtol=1e-3 @@ -1021,7 +1108,9 @@ def test_clipped_pg_loss_dual_clip(): ) # Dual clipping - loss3 = -adv_masked * ratio_clip_c # -[1*3.0, -1*3.0, -4*3.0] = [-3.0, 3.0, 12.0] + loss3 = ( + -adv_masked * cfg["ratio_clip_c"] + ) # -[1*3.0, -1*3.0, -4*3.0] = [-3.0, 3.0, 12.0] assert torch.allclose( loss3, torch.tensor([[-3.0, 3.0, 12.0]], device=device), rtol=1e-3 ) @@ -1063,16 +1152,7 @@ def test_clipped_pg_loss_entropy(): device = "cuda" data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) - cfg = { - "ratio_clip_min": 0.2, - "ratio_clip_max": 0.2, - "ratio_clip_c": None, - "reference_policy_kl_penalty": 0.0, # Disable KL for simplicity - "disable_ppo_ratio": False, - "use_on_policy_kl_approximation": False, - "use_importance_sampling_correction": False, # This flag does not affect entropy calculation - "token_level_loss": True, - } + cfg = basic_pg_loss_test_config loss_fn = ClippedPGLossFn(cfg) # Log probs for 3 tokens (default token_mask is [0, 1, 1, 1], so 3 unmasked after slicing) @@ -1124,18 +1204,9 @@ def test_clipped_pg_loss_gspo(): device = "cuda" data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) - ratio_clip = 0.2 - cfg = { - "ratio_clip_min": ratio_clip, - "ratio_clip_max": ratio_clip, - "ratio_clip_c": None, - "reference_policy_kl_penalty": 0.0, # Disable KL - "disable_ppo_ratio": False, - "use_on_policy_kl_approximation": False, - "use_importance_sampling_correction": False, - "sequence_level_importance_ratios": True, - "token_level_loss": False, - } + cfg = deepcopy(basic_pg_loss_test_config) + cfg["sequence_level_importance_ratios"] = True + cfg["token_level_loss"] = False loss_fn = ClippedPGLossFn(cfg) adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) @@ -1160,7 +1231,9 @@ def test_clipped_pg_loss_gspo(): ratios, torch.tensor([[0.9086, 0.9086, 0.9086]], device=device), rtol=1e-3 ) - ratios_clamped = torch.clamp(ratios, 1.0 - ratio_clip, 1.0 + ratio_clip) + ratios_clamped = torch.clamp( + ratios, 1.0 - cfg["ratio_clip_min"], 1.0 + cfg["ratio_clip_max"] + ) assert torch.allclose( ratios_clamped, torch.tensor([[0.9086, 0.9086, 0.9086]], device=device), @@ -1211,18 +1284,9 @@ def test_clipped_pg_loss_gspo_batch_size_2(): batch_size=2, device=device ) - ratio_clip = 0.2 - cfg = { - "ratio_clip_min": ratio_clip, - "ratio_clip_max": ratio_clip, - "ratio_clip_c": None, - "reference_policy_kl_penalty": 0.0, # Disable KL - "disable_ppo_ratio": False, - "use_on_policy_kl_approximation": False, - "use_importance_sampling_correction": False, - "sequence_level_importance_ratios": True, - "token_level_loss": False, - } + cfg = deepcopy(basic_pg_loss_test_config) + cfg["sequence_level_importance_ratios"] = True + cfg["token_level_loss"] = False loss_fn = ClippedPGLossFn(cfg) adv_masked = torch.tensor([[1.0, -1.0, 2.0], [1.0, -1.0, 2.0]], device=device) @@ -1253,7 +1317,9 @@ def test_clipped_pg_loss_gspo_batch_size_2(): rtol=1e-3, ) - ratios_clamped = torch.clamp(ratios, 1.0 - ratio_clip, 1.0 + ratio_clip) + ratios_clamped = torch.clamp( + ratios, 1.0 - cfg["ratio_clip_min"], 1.0 + cfg["ratio_clip_max"] + ) assert torch.allclose( ratios_clamped, torch.tensor([[0.9086, 0.9086, 0.9086], [1.2, 1.2, 1.2]], device=device), @@ -1316,18 +1382,10 @@ def test_clipped_pg_loss_gspo_importance_sampling_correction(): device = "cuda" data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) - ratio_clip = 0.2 - cfg = { - "ratio_clip_min": ratio_clip, - "ratio_clip_max": ratio_clip, - "ratio_clip_c": None, - "reference_policy_kl_penalty": 0.0, # Disable KL - "disable_ppo_ratio": False, - "use_on_policy_kl_approximation": False, - "use_importance_sampling_correction": True, - "sequence_level_importance_ratios": True, - "token_level_loss": False, - } + cfg = deepcopy(basic_pg_loss_test_config) + cfg["use_importance_sampling_correction"] = True + cfg["sequence_level_importance_ratios"] = True + cfg["token_level_loss"] = False loss_fn = ClippedPGLossFn(cfg) adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) @@ -1365,7 +1423,9 @@ def test_clipped_pg_loss_gspo_importance_sampling_correction(): ratios, torch.tensor([[0.9086, 0.9086, 0.9086]], device=device), rtol=1e-3 ) - ratios_clamped = torch.clamp(ratios, 1.0 - ratio_clip, 1.0 + ratio_clip) + ratios_clamped = torch.clamp( + ratios, 1.0 - cfg["ratio_clip_min"], 1.0 + cfg["ratio_clip_max"] + ) assert torch.allclose( ratios_clamped, torch.tensor([[0.9086, 0.9086, 0.9086]], device=device), diff --git a/tests/unit/algorithms/test_sequence_packing_gradients.py b/tests/unit/algorithms/test_sequence_packing_gradients.py index 33d858fbe4..8ba8c9b65c 100644 --- a/tests/unit/algorithms/test_sequence_packing_gradients.py +++ b/tests/unit/algorithms/test_sequence_packing_gradients.py @@ -133,6 +133,8 @@ def test_sequence_packing_gradients(self): "ratio_clip_c": 3.0, "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False, + "truncated_importance_sampling_ratio": None, + "sequence_level_importance_ratios": False, "token_level_loss": True, } diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index c81ae15dcf..12691f97b0 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -674,6 +674,8 @@ def test_dtensor_loss_independent_of_microbatch_size_two_gpus( "disable_ppo_ratio": False, "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False, + "truncated_importance_sampling_ratio": None, + "sequence_level_importance_ratios": False, "token_level_loss": True, } ) diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index 48b2c01dc8..2c40d333f5 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -20,7 +20,12 @@ import torch from nemo_rl.algorithms.interfaces import LossFunction -from nemo_rl.algorithms.loss_functions import ClippedPGLossFn, DPOLossFn, NLLLoss +from nemo_rl.algorithms.loss_functions import ( + ClippedPGLossConfig, + ClippedPGLossFn, + DPOLossFn, + NLLLoss, +) from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import RayVirtualCluster @@ -29,6 +34,19 @@ from nemo_rl.models.policy.lm_policy import Policy from tests.unit.test_utils import SimpleLoss +basic_pg_loss_test_config: ClippedPGLossConfig = { + "ratio_clip_min": 0.2, + "ratio_clip_max": 0.2, + "ratio_clip_c": None, + "reference_policy_kl_penalty": 0.1, + "disable_ppo_ratio": False, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, + "truncated_importance_sampling_ratio": None, + "sequence_level_importance_ratios": False, + "token_level_loss": True, +} + def create_megatron_test_config( model_name: str, @@ -788,18 +806,7 @@ def test_megatron_loss_independent_of_microbatch_size(tiny_llama_model_path): # Test loss functions nll_loss_fn = NLLLoss() - pg_loss_fn = ClippedPGLossFn( - { - "ratio_clip_min": 0.2, - "ratio_clip_max": 0.2, - "ratio_clip_c": None, - "reference_policy_kl_penalty": 0.1, - "disable_ppo_ratio": False, - "use_on_policy_kl_approximation": False, - "use_importance_sampling_correction": False, - "token_level_loss": True, - } - ) + pg_loss_fn = ClippedPGLossFn(basic_pg_loss_test_config) policy1.prepare_for_training() mbs1_nll_results = policy1.train(data, nll_loss_fn) @@ -1695,18 +1702,7 @@ def test_megatron_context_parallel_training_agreement(tiny_llama_model_path): ) # Create ClippedPG loss function - loss_fn = ClippedPGLossFn( - { - "ratio_clip_min": 0.2, - "ratio_clip_max": 0.2, - "ratio_clip_c": None, - "reference_policy_kl_penalty": 0.1, - "disable_ppo_ratio": False, - "use_on_policy_kl_approximation": False, - "use_importance_sampling_correction": False, - "token_level_loss": True, - } - ) + loss_fn = ClippedPGLossFn(basic_pg_loss_test_config) # Train non-CP model policy_no_cp.prepare_for_training()