diff --git a/pyproject.toml b/pyproject.toml index c504530d91..8e735775e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.6.3", + "unsloth_zoo>=2025.6.4", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2", @@ -381,7 +381,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.6.3", + "unsloth_zoo>=2025.6.4", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2", diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2de324ea76..889bbd4807 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -486,6 +486,21 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): arguments = re.sub(x, y, arguments) pass + # Fix GRPO beta default as 0.001 TRL used to be 0.04, now 0.00! + # https://github.com/huggingface/trl/pull/3516 + # https://verl.readthedocs.io/en/latest/examples/config.html + if trainer_file == "grpo_trainer": + replacements = { + "beta" : 0.001, + } + for k, v in replacements.items(): + x = f"{k}( = [^,\n]{{1,}})?,\n" + y = f"'{v}'" if type(v) is str else f"{v}" + y = f"{k} = {y},\n" + arguments = re.sub(x, y, arguments) + pass + pass + # Warn on too large or too small learning rate if " learning_rate" in call_args: learning_rate_check = \ @@ -553,6 +568,17 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += check_num_generations pass + # Check temperature must not be <= 0. Also stop if >= 10 + if "temperature" in call_args: + check_temperature = \ + "if temperature <= 0:\n"\ + " raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')\n"\ + "elif temperature >= 10:\n"\ + " raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')\n"\ + "\n" + extra_args += check_temperature + pass + # Edit config with anything extra if trainer_file in RL_CONFIG_CHANGES: process_extra_args = RL_CONFIG_CHANGES[trainer_file] diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index df95f73fc5..18f7720562 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -269,6 +269,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, # See https://github.com/huggingface/trl/issues/2770 # logits = logits[:, -logits_to_keep:] # return logits + # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details + # logits = logits / self.temperature # logps = selective_log_softmax(logits, input_ids) # row_indices, col_indices = torch.where(logps < -20) @@ -325,7 +327,6 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch else: ref_per_token_logps = None # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 - # x - x.detach() allows for preserving gradients from x advantages = inputs["advantages"] # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) @@ -335,10 +336,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch old_hidden_states = inputs["old_per_token_logps"] else: old_hidden_states = None + input_ids = input_ids[:, -logits_to_keep:] if per_token_logps is not None: - ref_per_token_logps = ref_per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + if ref_per_token_logps is not None: + ref_per_token_logps = ref_per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + per_token_logps = per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred loss, completion_length, mean_kl = grpo_compute_loss_slow( @@ -354,6 +358,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch epsilon_high = self.epsilon_high, max_completion_length = self.args.max_completion_length, delta = self.args.delta, + temperature = self.args.temperature, ) else: if hasattr(self.args, "loss_type"): @@ -370,6 +375,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch epsilon_high = self.epsilon_high, max_completion_length = self.args.max_completion_length, delta = self.args.delta, + temperature = self.args.temperature, ) else: # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 @@ -381,6 +387,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch advantages, old_hidden_states, n_chunks = self.args.unsloth_num_chunks, + temperature = self.args.temperature, ) # Log the metrics