-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Trl 0.27.0 update #3965
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Trl 0.27.0 update #3965
Changes from all commits
acf90f0
886664e
707fa38
cf24add
37c64bd
44beb30
48c3c19
c2a6fc1
b470807
e09f716
58a1fd4
d8a7c7d
f903be4
e7db06e
832c03b
26977f5
fa354fc
f6002a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -27,6 +27,7 @@ | |||||||||||||||||
| from collections import defaultdict | ||||||||||||||||||
| from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding | ||||||||||||||||||
| from unsloth_zoo.utils import Version | ||||||||||||||||||
| from trl import __version__ as trl_version_raw | ||||||||||||||||||
| from importlib.metadata import version as importlib_version | ||||||||||||||||||
| from unsloth_zoo.log import logger | ||||||||||||||||||
| from unsloth_zoo.device_type import device_synchronize | ||||||||||||||||||
|
|
@@ -57,6 +58,14 @@ | |||||||||||||||||
| "triton.cudagraphs": False, | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| try: | ||||||||||||||||||
| trl_version = Version(trl_version_raw) | ||||||||||||||||||
| except Exception: | ||||||||||||||||||
| try: | ||||||||||||||||||
| trl_version = Version(importlib_version("trl")) | ||||||||||||||||||
| except Exception: | ||||||||||||||||||
| trl_version = Version("0.0.0") | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| # Check untrained tokens | ||||||||||||||||||
| def sft_trainer_fix_untrained_tokens(call_args, extra_args): | ||||||||||||||||||
|
|
@@ -434,99 +443,6 @@ def grpo_trainer__generate_and_score_completions(function_name, function): | |||||||||||||||||
| _target_line + _metadata_extraction, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Unsloth: Skip prepare_multimodal_messages when prompts are pre-templated strings. | ||||||||||||||||||
| # When notebooks pre-apply apply_chat_template(), prompts become strings with image tokens | ||||||||||||||||||
| # already embedded. Calling prepare_multimodal_messages on strings crashes with TypeError. | ||||||||||||||||||
| # Skipping it keeps prompts as strings so TRL uses the non-conversational path, which | ||||||||||||||||||
| # ensures completions are strings and reward functions work correctly. | ||||||||||||||||||
| string_to_find_vision = """ if images is not None: | ||||||||||||||||||
| prompts = [ | ||||||||||||||||||
| prepare_multimodal_messages(prompt, image_list) | ||||||||||||||||||
| for prompt, image_list in zip(prompts, images, strict=True) | ||||||||||||||||||
| ]""" | ||||||||||||||||||
|
|
||||||||||||||||||
| replacement_string_vision = """ if images is not None: | ||||||||||||||||||
| # Unsloth: skip prepare_multimodal_messages for pre-templated string prompts | ||||||||||||||||||
| if not prompts or not isinstance(prompts[0], str): | ||||||||||||||||||
| prompts = [ | ||||||||||||||||||
| prepare_multimodal_messages(prompt, image_list) | ||||||||||||||||||
| for prompt, image_list in zip(prompts, images, strict=True) | ||||||||||||||||||
| ]""" | ||||||||||||||||||
|
|
||||||||||||||||||
| function = function.replace(string_to_find_vision, replacement_string_vision) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Unsloth: Skip apply_chat_template in the forward_kwargs block for pre-templated | ||||||||||||||||||
| # string prompts. When prompts are already strings (from notebooks that pre-applied | ||||||||||||||||||
| # apply_chat_template), calling it again crashes because strings aren't dicts. | ||||||||||||||||||
| # We use prompts directly as prompts_text instead. | ||||||||||||||||||
|
|
||||||||||||||||||
| # TRL 0.26.2+ variant (has tools=self.tools) | ||||||||||||||||||
| string_to_find_fwd = """ if images is not None: | ||||||||||||||||||
| prompts_text = [ | ||||||||||||||||||
| apply_chat_template( | ||||||||||||||||||
| {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs | ||||||||||||||||||
| )["prompt"] | ||||||||||||||||||
| for prompt in prompts | ||||||||||||||||||
| ]""" | ||||||||||||||||||
|
|
||||||||||||||||||
| replacement_string_fwd = """ if images is not None: | ||||||||||||||||||
| # Unsloth: skip apply_chat_template for pre-templated string prompts | ||||||||||||||||||
| if prompts and isinstance(prompts[0], str): | ||||||||||||||||||
| prompts_text = prompts | ||||||||||||||||||
| else: | ||||||||||||||||||
| prompts_text = [ | ||||||||||||||||||
| apply_chat_template( | ||||||||||||||||||
| {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs | ||||||||||||||||||
| )["prompt"] | ||||||||||||||||||
| for prompt in prompts | ||||||||||||||||||
| ]""" | ||||||||||||||||||
|
|
||||||||||||||||||
| function = function.replace(string_to_find_fwd, replacement_string_fwd) | ||||||||||||||||||
|
|
||||||||||||||||||
| # TRL 0.25.x variant (no tools parameter) | ||||||||||||||||||
| string_to_find_fwd_old = """ if images is not None: | ||||||||||||||||||
| prompts_text = [ | ||||||||||||||||||
| apply_chat_template( | ||||||||||||||||||
| {"prompt": prompt}, self.processing_class, **self.chat_template_kwargs | ||||||||||||||||||
| )["prompt"] | ||||||||||||||||||
| for prompt in prompts | ||||||||||||||||||
| ]""" | ||||||||||||||||||
|
|
||||||||||||||||||
| replacement_string_fwd_old = """ if images is not None: | ||||||||||||||||||
| # Unsloth: skip apply_chat_template for pre-templated string prompts | ||||||||||||||||||
| if prompts and isinstance(prompts[0], str): | ||||||||||||||||||
| prompts_text = prompts | ||||||||||||||||||
| else: | ||||||||||||||||||
| prompts_text = [ | ||||||||||||||||||
| apply_chat_template( | ||||||||||||||||||
| {"prompt": prompt}, self.processing_class, **self.chat_template_kwargs | ||||||||||||||||||
| )["prompt"] | ||||||||||||||||||
| for prompt in prompts | ||||||||||||||||||
| ]""" | ||||||||||||||||||
|
|
||||||||||||||||||
| function = function.replace(string_to_find_fwd_old, replacement_string_fwd_old) | ||||||||||||||||||
|
|
||||||||||||||||||
| # TRL 0.25.1 single-line variant (no tools, single-line apply_chat_template call) | ||||||||||||||||||
| string_to_find_fwd_single = """ if images is not None: | ||||||||||||||||||
| prompts_text = [ | ||||||||||||||||||
| apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"] | ||||||||||||||||||
| for prompt in prompts | ||||||||||||||||||
| ]""" | ||||||||||||||||||
|
|
||||||||||||||||||
| replacement_string_fwd_single = """ if images is not None: | ||||||||||||||||||
| # Unsloth: skip apply_chat_template for pre-templated string prompts | ||||||||||||||||||
| if prompts and isinstance(prompts[0], str): | ||||||||||||||||||
| prompts_text = prompts | ||||||||||||||||||
| else: | ||||||||||||||||||
| prompts_text = [ | ||||||||||||||||||
| apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"] | ||||||||||||||||||
| for prompt in prompts | ||||||||||||||||||
| ]""" | ||||||||||||||||||
|
|
||||||||||||||||||
| function = function.replace( | ||||||||||||||||||
| string_to_find_fwd_single, replacement_string_fwd_single | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # This path is for TRL 0.24.0 images is a variable exclusive to this version | ||||||||||||||||||
| string_to_find = """ if images is not None: | ||||||||||||||||||
| output["num_images"] = num_images""" | ||||||||||||||||||
|
|
@@ -543,6 +459,17 @@ def grpo_trainer__generate_and_score_completions(function_name, function): | |||||||||||||||||
|
|
||||||||||||||||||
| function = function.replace(string_to_find, replacement_string) | ||||||||||||||||||
|
|
||||||||||||||||||
| if trl_version >= Version("0.25.0"): | ||||||||||||||||||
| # We replace the call using 'completions' with one using 'completions_text' | ||||||||||||||||||
| string_to_find = " rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)" | ||||||||||||||||||
| replacement_string = ( | ||||||||||||||||||
| " if images is not None:\n" | ||||||||||||||||||
| " rewards_per_func = self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)\n" | ||||||||||||||||||
| " else:\n" | ||||||||||||||||||
| " rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)" | ||||||||||||||||||
| ) | ||||||||||||||||||
| function = function.replace(string_to_find, replacement_string) | ||||||||||||||||||
|
|
||||||||||||||||||
| if "wake_up()" not in function: | ||||||||||||||||||
| # Sleep functionality has been added to trl in v0.23.0. We do not want to redo this. | ||||||||||||||||||
| # https://github.com/huggingface/trl/commit/edbe8234bc7e528f72ac76607de9d3e4753e2709 | ||||||||||||||||||
|
|
@@ -1072,7 +999,7 @@ def compute_loss( | |||||||||||||||||
|
|
||||||||||||||||||
| max_left_pad = inputs.get("max_left_pad", 0) | ||||||||||||||||||
| if per_token_logps is not None: | ||||||||||||||||||
| loss, completion_length, mean_kl, delta, flat_is_ratio = ( | ||||||||||||||||||
| loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = ( | ||||||||||||||||||
| grpo_compute_loss_slow( | ||||||||||||||||||
| ref_logps, | ||||||||||||||||||
| per_token_logps, | ||||||||||||||||||
|
|
@@ -1102,7 +1029,7 @@ def compute_loss( | |||||||||||||||||
| ) | ||||||||||||||||||
| else: | ||||||||||||||||||
| if hasattr(self.args, "loss_type"): | ||||||||||||||||||
| loss, completion_length, mean_kl, delta, flat_is_ratio = ( | ||||||||||||||||||
| loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = ( | ||||||||||||||||||
| grpo_accumulated_loss( | ||||||||||||||||||
| trainer = self, | ||||||||||||||||||
| input_ids = _input_ids, | ||||||||||||||||||
|
|
@@ -1134,7 +1061,7 @@ def compute_loss( | |||||||||||||||||
| ) | ||||||||||||||||||
| else: | ||||||||||||||||||
| # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 | ||||||||||||||||||
| loss, completion_length, mean_kl = grpo_accumulated_loss( | ||||||||||||||||||
| loss, completion_length, mean_kl, coef_1 = grpo_accumulated_loss( | ||||||||||||||||||
| trainer = self, | ||||||||||||||||||
|
Comment on lines
1063
to
1065
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The backward-compatibility branch now unpacks four values from Useful? React with 👍 / 👎. |
||||||||||||||||||
| input_ids = _input_ids, | ||||||||||||||||||
| logits_to_keep = logits_to_keep, | ||||||||||||||||||
|
|
@@ -1149,7 +1076,6 @@ def compute_loss( | |||||||||||||||||
| logit_scale_divide = logit_scale_divide, | ||||||||||||||||||
| attention_mask = attention_mask, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| if "train" in self._metrics: | ||||||||||||||||||
| mode = "eval" if self.control.should_evaluate else "train" | ||||||||||||||||||
| self._metrics[mode]["completion_length"].append(completion_length.item()) | ||||||||||||||||||
|
|
@@ -1211,6 +1137,53 @@ def compute_loss( | |||||||||||||||||
| .item() | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| completion_token_count = completion_mask.sum().clamp(min = 1.0) | ||||||||||||||||||
|
|
||||||||||||||||||
| def masked_batch_mean(x): | ||||||||||||||||||
| if x.shape[1] == 1: # when importance_sampling_level == "sequence" | ||||||||||||||||||
| return x.mean() | ||||||||||||||||||
| else: | ||||||||||||||||||
| return (x * completion_mask).sum() / completion_token_count | ||||||||||||||||||
|
|
||||||||||||||||||
| if advantages.dim() == 1: | ||||||||||||||||||
| advantages = advantages.unsqueeze(1) | ||||||||||||||||||
|
|
||||||||||||||||||
| if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]: | ||||||||||||||||||
| # Compute the clipped probability ratios | ||||||||||||||||||
| is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0) | ||||||||||||||||||
| is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0) | ||||||||||||||||||
| is_region_clipped = is_low_clipped | is_high_clipped | ||||||||||||||||||
|
|
||||||||||||||||||
| low_clip = masked_batch_mean(is_low_clipped.float()) | ||||||||||||||||||
| high_clip = masked_batch_mean(is_high_clipped.float()) | ||||||||||||||||||
| clip_ratio = masked_batch_mean(is_region_clipped.float()) | ||||||||||||||||||
|
|
||||||||||||||||||
| gathered_low_clip = self.accelerator.gather(low_clip) | ||||||||||||||||||
| self._metrics[mode]["clip_ratio/low_mean"].append( | ||||||||||||||||||
| gathered_low_clip.nanmean().item() | ||||||||||||||||||
|
Comment on lines
+1160
to
+1163
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The functions
Suggested change
|
||||||||||||||||||
| ) | ||||||||||||||||||
| self._metrics[mode]["clip_ratio/low_min"].append( | ||||||||||||||||||
| nanmin(gathered_low_clip).item() | ||||||||||||||||||
| ) | ||||||||||||||||||
| gathered_high_clip = self.accelerator.gather(high_clip) | ||||||||||||||||||
| self._metrics[mode]["clip_ratio/high_mean"].append( | ||||||||||||||||||
| gathered_high_clip.nanmean().item() | ||||||||||||||||||
| ) | ||||||||||||||||||
| self._metrics[mode]["clip_ratio/high_max"].append( | ||||||||||||||||||
| nanmax(gathered_high_clip).item() | ||||||||||||||||||
| ) | ||||||||||||||||||
| gathered_clip_ratio = self.accelerator.gather(clip_ratio) | ||||||||||||||||||
| self._metrics[mode]["clip_ratio/region_mean"].append( | ||||||||||||||||||
| gathered_clip_ratio.nanmean().item() | ||||||||||||||||||
| ) | ||||||||||||||||||
| elif self.loss_type == "cispo": | ||||||||||||||||||
| is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0) | ||||||||||||||||||
| cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) | ||||||||||||||||||
| gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio) | ||||||||||||||||||
| self._metrics[mode]["cispo_clip_ratio"].append( | ||||||||||||||||||
| gathered_cispo_clip_ratio.nanmean().item() | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| return loss | ||||||||||||||||||
|
|
||||||||||||||||||
| function = inspect.getsource(compute_loss) | ||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using a broad
except Exception:can hide unexpected errors and make debugging harder. It's better to catch more specific exceptions. For example,Version()might raise aValueError, andimportlib_version()can raisePackageNotFoundError(which you'd need to import). Using more specific exceptions would make this code more robust and maintainable.