diff --git a/pyproject.toml b/pyproject.toml index 8e735775e7..51c540c8e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,10 +37,10 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.6.4", + "unsloth_zoo>=2025.6.5", "packaging", "tyro", - "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2", + "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3", "datasets>=3.4.1", "sentencepiece>=0.2.0", "tqdm", @@ -381,10 +381,10 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.6.4", + "unsloth_zoo>=2025.6.5", "packaging", "tyro", - "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2", + "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3", "datasets>=3.4.1", "sentencepiece>=0.2.0", "tqdm", diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 05c05ae9ff..e03f50baa9 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.6.5" +__version__ = "2025.6.6" __all__ = [ "SUPPORTS_BFLOAT16", diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b96302c5f7..9b0f4e4aef 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -247,10 +247,10 @@ def _move_model_to_vllm(self, *args, **kwargs): return None # Edit _get_per_token_logps to handle mixed precision def grpo_trainer__get_per_token_logps(function_name, function): - if function_name != "_get_per_token_logps": return function + if function_name != "_get_per_token_logps": return function - def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, calc_logprob_flag = None): - if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0' and not calc_logprob_flag: + def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): + if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': return None # Unsloth efficient GRPO # Otherwise, calculate normally: if not hasattr(self, '_autocast_dtype'): @@ -260,9 +260,13 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1" with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded - hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits - #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred - return hidden_states + logits = model( + input_ids = input_ids, + attention_mask = attention_mask, + logits_to_keep = logits_to_keep + 1, + ).logits + # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + return logits # input_ids = input_ids[:, -logits_to_keep:] # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. # See https://github.com/huggingface/trl/issues/2770 @@ -331,19 +335,24 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() - if "old_per_token_logps" in inputs.keys(): - old_hidden_states = inputs["old_per_token_logps"] - else: - old_hidden_states = None - + old_hidden_states = inputs.get("old_per_token_logps", None) input_ids = input_ids[:, -logits_to_keep:] + + # Get logit softcapping and logit scale + logit_softcapping = getattr(model.config, "final_logit_softcapping", 0) # Gemma + if logit_softcapping is None: logit_softcapping = 0 + logit_scale_multiply = getattr(model.config, "logit_scale", 0) # Cohere + if logit_scale_multiply is None: logit_scale_multiply = 0 + logit_scale_divide = getattr(model.config, "logits_scaling", 0) # Granite + if logit_scale_divide is None: logit_scale_divide = 0 + + if per_token_logps is not None: if ref_per_token_logps is not None: ref_per_token_logps = ref_per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred - per_token_logps = per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred - + loss, completion_length, mean_kl = grpo_compute_loss_slow( ref_per_token_logps, per_token_logps, @@ -358,16 +367,19 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch max_completion_length = self.args.max_completion_length, delta = self.args.delta, temperature = self.args.temperature, + logit_softcapping = logit_softcapping, + logit_scale_multiply = logit_scale_multiply, + logit_scale_divide = logit_scale_divide, ) else: if hasattr(self.args, "loss_type"): loss, completion_length, mean_kl = grpo_accumulated_loss( - self, - _input_ids, - logits_to_keep, - completion_mask, - advantages, - old_hidden_states, + trainer = self, + input_ids = _input_ids, + logits_to_keep = logits_to_keep, + completion_mask = completion_mask, + advantages = advantages, + old_hidden_states = old_hidden_states, n_chunks = self.args.unsloth_num_chunks, loss_type = self.args.loss_type, epsilon_low = self.epsilon_low, @@ -375,26 +387,33 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch max_completion_length = self.args.max_completion_length, delta = self.args.delta, temperature = self.args.temperature, + logit_softcapping = logit_softcapping, + logit_scale_multiply = logit_scale_multiply, + logit_scale_divide = logit_scale_divide, + attention_mask = attention_mask, ) else: # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 loss, completion_length, mean_kl = grpo_accumulated_loss( - self, - _input_ids, - logits_to_keep, - completion_mask, - advantages, - old_hidden_states, + trainer = self, + input_ids = _input_ids, + logits_to_keep = logits_to_keep, + completion_mask = completion_mask, + advantages = advantages, + old_hidden_states = old_hidden_states, n_chunks = self.args.unsloth_num_chunks, temperature = self.args.temperature, + logit_softcapping = logit_softcapping, + logit_scale_multiply = logit_scale_multiply, + logit_scale_divide = logit_scale_divide, + attention_mask = attention_mask, ) - + pass + pass # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() - # mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() # self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) - if "train" in self._metrics: mode = "eval" if self.control.should_evaluate else "train" self._metrics[mode]["completion_length"].append(completion_length.item()) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index a140ccdfe0..b0f485d90e 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -720,6 +720,8 @@ def _for_inference(m): embeddings = model.get_output_embeddings() if hasattr(embeddings, "training"): embeddings.training = False pass + # Must disable returning hidden states in the case for GRPO + os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0" return model pass