-
-
Notifications
You must be signed in to change notification settings - Fork 6k
[trl] Trl v0.28 (and above) rl fixes #4156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
cb6fb1b
a863d0f
690fecf
35721ee
c75d8aa
785e308
deafd91
50618b3
5bc60d6
0af28ae
f7e7e82
dc55ac7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -24,6 +24,7 @@ | |||||
| import re | ||||||
| import torch | ||||||
| import inspect | ||||||
| import linecache | ||||||
| from collections import defaultdict | ||||||
| from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding | ||||||
| from unsloth_zoo.utils import Version | ||||||
|
|
@@ -264,12 +265,30 @@ def grpo_trainer__generate_single_turn(function_name, function): | |||||
| # Remove the reload_weights collective RPC call from the generate function's source | ||||||
| # function = function.replace('self.llm.collective_rpc("reload_weights")', "") | ||||||
| # The regex below does the same thing but is more flexible and can handle single or double quotes | ||||||
| # This is for older versions. | ||||||
| function = re.sub( | ||||||
| r"self\.llm\.collective_rpc\(\s*(['\"])reload_weights\1\s*\)", | ||||||
| "", | ||||||
| function, | ||||||
| ) | ||||||
|
|
||||||
| # Current TRL versions call vllm_generation.sync_weights() every step. | ||||||
| # When Unsloth fast inference LoRA is active, weights are already shared. | ||||||
| sync_weights_block = re.compile( | ||||||
| r"(?P<indent>[ \t]*)with profiling_context\(self,\s*(['\"])sync_weights\2\s*\):\n" | ||||||
| r"(?P=indent)[ \t]+self\.vllm_generation\.sync_weights\(\)\n", | ||||||
| re.MULTILINE, | ||||||
| ) | ||||||
|
|
||||||
| def remove_sync_weights_block(match): | ||||||
| indent = match.group("indent") | ||||||
| return ( | ||||||
| f"{indent}# Unsloth fast inference LoRA shares weights with vLLM already.\n" | ||||||
| f"{indent}# Skipping per-step vLLM sync_weights().\n" | ||||||
| ) | ||||||
|
|
||||||
| function = sync_weights_block.sub(remove_sync_weights_block, function) | ||||||
|
|
||||||
| # TRL 0.24.0-0.25.1 truncation regression fix | ||||||
| # | ||||||
| # TRL 0.22.2-0.23.1 used smart truncation via truncate_with_protected_tokens(): | ||||||
|
|
@@ -1013,17 +1032,61 @@ def compute_loss( | |||||
|
|
||||||
| max_left_pad = inputs.get("max_left_pad", 0) | ||||||
| if per_token_logps is not None: | ||||||
| loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = ( | ||||||
| grpo_compute_loss_slow( | ||||||
| ref_logps, | ||||||
| per_token_logps, | ||||||
| old_logps, | ||||||
| input_ids, | ||||||
| ( | ||||||
| loss, | ||||||
| completion_length, | ||||||
| mean_kl, | ||||||
| delta, | ||||||
| flat_is_ratio, | ||||||
| coef_1, | ||||||
| completion_mask, | ||||||
| ) = grpo_compute_loss_slow( | ||||||
| ref_logps, | ||||||
| per_token_logps, | ||||||
| old_logps, | ||||||
| input_ids, | ||||||
| completion_mask, | ||||||
| self.beta, | ||||||
| advantages, | ||||||
| pixel_values = pixel_values, | ||||||
| image_grid_thw = image_grid_thw, | ||||||
| loss_type = self.args.loss_type, | ||||||
| importance_sampling_level = self.importance_sampling_level, | ||||||
| epsilon_low = self.epsilon_low, | ||||||
| epsilon_high = self.epsilon_high, | ||||||
| max_completion_length = self.args.max_completion_length, | ||||||
| delta = self.args.delta, | ||||||
| temperature = self.args.temperature, | ||||||
| max_left_pad = max_left_pad, | ||||||
| logit_softcapping = logit_softcapping, | ||||||
| logit_scale_multiply = logit_scale_multiply, | ||||||
| logit_scale_divide = logit_scale_divide, | ||||||
| num_items_in_batch = num_items_in_batch, | ||||||
| current_gradient_accumulation_steps = current_gradient_accumulation_steps, | ||||||
| num_processes = num_processes, | ||||||
| sampling_per_token_logps = sampling_per_token_logps, | ||||||
| ) | ||||||
| else: | ||||||
| if hasattr(self.args, "loss_type"): | ||||||
| ( | ||||||
| loss, | ||||||
| completion_length, | ||||||
| mean_kl, | ||||||
| delta, | ||||||
| flat_is_ratio, | ||||||
| coef_1, | ||||||
| completion_mask, | ||||||
| self.beta, | ||||||
| advantages, | ||||||
| ) = grpo_accumulated_loss( | ||||||
| trainer = self, | ||||||
| input_ids = _input_ids, | ||||||
| pixel_values = pixel_values, | ||||||
| image_grid_thw = image_grid_thw, | ||||||
| logits_to_keep = logits_to_keep, | ||||||
| completion_mask = completion_mask, | ||||||
| advantages = advantages, | ||||||
| old_logps = old_logps, | ||||||
| ref_logps = ref_logps, | ||||||
| n_chunks = self.args.unsloth_num_chunks, | ||||||
| loss_type = self.args.loss_type, | ||||||
| importance_sampling_level = self.importance_sampling_level, | ||||||
| epsilon_low = self.epsilon_low, | ||||||
|
|
@@ -1035,61 +1098,31 @@ def compute_loss( | |||||
| logit_softcapping = logit_softcapping, | ||||||
| logit_scale_multiply = logit_scale_multiply, | ||||||
| logit_scale_divide = logit_scale_divide, | ||||||
| attention_mask = attention_mask, | ||||||
| num_items_in_batch = num_items_in_batch, | ||||||
| current_gradient_accumulation_steps = current_gradient_accumulation_steps, | ||||||
| num_processes = num_processes, | ||||||
| sampling_per_token_logps = sampling_per_token_logps, | ||||||
| ) | ||||||
| ) | ||||||
| else: | ||||||
| if hasattr(self.args, "loss_type"): | ||||||
| loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = ( | ||||||
| else: | ||||||
| # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 | ||||||
| loss, completion_length, mean_kl, coef_1, completion_mask = ( | ||||||
| grpo_accumulated_loss( | ||||||
| trainer = self, | ||||||
| input_ids = _input_ids, | ||||||
| pixel_values = pixel_values, | ||||||
| image_grid_thw = image_grid_thw, | ||||||
| logits_to_keep = logits_to_keep, | ||||||
| completion_mask = completion_mask, | ||||||
| advantages = advantages, | ||||||
| old_logps = old_logps, | ||||||
| ref_logps = ref_logps, | ||||||
| n_chunks = self.args.unsloth_num_chunks, | ||||||
| loss_type = self.args.loss_type, | ||||||
| importance_sampling_level = self.importance_sampling_level, | ||||||
| epsilon_low = self.epsilon_low, | ||||||
| epsilon_high = self.epsilon_high, | ||||||
| max_completion_length = self.args.max_completion_length, | ||||||
| delta = self.args.delta, | ||||||
| temperature = self.args.temperature, | ||||||
| max_left_pad = max_left_pad, | ||||||
| logit_softcapping = logit_softcapping, | ||||||
| logit_scale_multiply = logit_scale_multiply, | ||||||
| logit_scale_divide = logit_scale_divide, | ||||||
| attention_mask = attention_mask, | ||||||
| num_items_in_batch = num_items_in_batch, | ||||||
| current_gradient_accumulation_steps = current_gradient_accumulation_steps, | ||||||
| num_processes = num_processes, | ||||||
| sampling_per_token_logps = sampling_per_token_logps, | ||||||
| ) | ||||||
| ) | ||||||
| else: | ||||||
| # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 | ||||||
| loss, completion_length, mean_kl, coef_1 = grpo_accumulated_loss( | ||||||
| trainer = self, | ||||||
| input_ids = _input_ids, | ||||||
| logits_to_keep = logits_to_keep, | ||||||
| completion_mask = completion_mask, | ||||||
| advantages = advantages, | ||||||
| old_logps = old_logps, | ||||||
| ref_logps = ref_logps, | ||||||
| n_chunks = self.args.unsloth_num_chunks, | ||||||
| temperature = self.args.temperature, | ||||||
| logit_softcapping = logit_softcapping, | ||||||
| logit_scale_multiply = logit_scale_multiply, | ||||||
| logit_scale_divide = logit_scale_divide, | ||||||
| attention_mask = attention_mask, | ||||||
| ) | ||||||
| if "train" in self._metrics: | ||||||
| mode = "eval" if self.control.should_evaluate else "train" | ||||||
| self._metrics[mode]["completion_length"].append(completion_length.item()) | ||||||
|
|
@@ -1338,3 +1371,157 @@ def openenv_vllm_reload_weights(): | |||||
|
|
||||||
|
|
||||||
| RL_ADDITIONAL_FUNCTIONS["openenv"].append(openenv_vllm_reload_weights) | ||||||
|
|
||||||
|
|
||||||
| def vllm_generation_init_patch(): | ||||||
| # trl moved vllm stuff to trl/generation/vllm_generation.py | ||||||
| # We need to patch it to not instantiate another vLLM instance if we already have one with fast_inference | ||||||
| # Edit the TRL source directly and install the patched function in the TRL module. | ||||||
| # https://github.com/huggingface/trl/commit/0eb66d8f2fc63b3d00d8dbc18f99c3f48750bd16 | ||||||
| # This exists in trl versions 0.28.0 and above | ||||||
|
|
||||||
| if importlib.util.find_spec("trl") is None: | ||||||
| return | ||||||
| if Version(importlib_version("trl")) < Version("0.28.0"): | ||||||
| return | ||||||
|
|
||||||
| try: | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The exception handling
Suggested change
|
||||||
| import trl.generation.vllm_generation as vllm_generation | ||||||
| except (ImportError, NameError, Exception) as e: | ||||||
| logger.info(f"Unsloth: Failed to import trl.generation.vllm_generation: {e}") | ||||||
| return | ||||||
|
|
||||||
| def patch_vllm_generation_method(method_name, transform, marker, filename_suffix): | ||||||
| method = getattr(vllm_generation.VLLMGeneration, method_name, None) | ||||||
| if method is None: | ||||||
| logger.info(f"Unsloth: Could not find VLLMGeneration.{method_name}") | ||||||
| return False | ||||||
|
|
||||||
| try: | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Catching the broad
Suggested change
|
||||||
| src = inspect.getsource(method) | ||||||
| except Exception as e: | ||||||
| logger.info( | ||||||
| f"Unsloth: Could not get source of VLLMGeneration.{method_name}: {e}" | ||||||
| ) | ||||||
| return False | ||||||
|
|
||||||
| src = textwrap.dedent(src) | ||||||
| if marker in src: | ||||||
| return True | ||||||
|
|
||||||
| src = transform(src) | ||||||
| filename = f"<unsloth_trl_vllm_generation_{filename_suffix}_patch>" | ||||||
| source_lines = [line + "\n" for line in src.splitlines()] | ||||||
| linecache.cache[filename] = ( | ||||||
| len(src), | ||||||
| None, | ||||||
| source_lines, | ||||||
| filename, | ||||||
| ) | ||||||
|
|
||||||
| local_ns = {} | ||||||
| exec(compile(src, filename, "exec"), vllm_generation.__dict__, local_ns) | ||||||
| setattr(vllm_generation.VLLMGeneration, method_name, local_ns[method_name]) | ||||||
| return True | ||||||
|
|
||||||
| # Patch init to remove vLLM.LLM instantiation | ||||||
| def patch_init_vllm(src): | ||||||
| pattern = re.compile( | ||||||
| r"(?P<llm_block>^(?P<indent>[ \t]*)self\.llm\s*=\s*LLM\s*\(\n(?:.*\n)*?^(?P=indent)\))", | ||||||
| re.MULTILINE, | ||||||
| ) | ||||||
|
|
||||||
| def replace_llm_block(match): | ||||||
| indent = match.group("indent") | ||||||
| llm_block = textwrap.dedent(match.group("llm_block")) | ||||||
| return ( | ||||||
| f"{indent}if hasattr(model, 'vllm_engine'):\n" | ||||||
| f"{indent} # Unsloth already inits vLLM in fast inference mode. Do not redo :)\n" | ||||||
| f"{indent} self.llm = model.vllm_engine\n" | ||||||
| f"{indent} self.unsloth_fast_inference_lora = True\n" | ||||||
| f"{indent}else:\n" + textwrap.indent(llm_block, indent + " ") | ||||||
| ) | ||||||
|
|
||||||
| patched_src, num_replacements = pattern.subn(replace_llm_block, src, count = 1) | ||||||
| if num_replacements == 0: | ||||||
| raise RuntimeError( | ||||||
| "Unsloth: Warning - regex did not match, VLLMGeneration._init_vllm patch may have failed" | ||||||
| ) | ||||||
| return patched_src | ||||||
|
|
||||||
| # has some sync_weights or reload rpc calls. | ||||||
| # we patched the grpo_trainer to strip them for prev versions | ||||||
| # Ref: grpo_trainer__generate_single_turn above around L270-280 | ||||||
| def patch_sync_weights(src): | ||||||
| pattern = re.compile( | ||||||
| r"^(?P<def_line>def sync_weights\(self\):\n)(?P<body>(?:.*\n)*)", | ||||||
| re.MULTILINE, | ||||||
| ) | ||||||
|
|
||||||
| def replace_sync_weights(match): | ||||||
| body = match.group("body") | ||||||
| guard = ( | ||||||
| " if getattr(self, 'unsloth_fast_inference_lora', False):\n" | ||||||
| " # Unsloth fast inference LoRA shares weights with vLLM already.\n" | ||||||
| " return\n\n" | ||||||
| ) | ||||||
| return match.group("def_line") + guard + body | ||||||
|
|
||||||
| patched_src, num_replacements = pattern.subn(replace_sync_weights, src, count = 1) | ||||||
| if num_replacements == 0: | ||||||
| raise RuntimeError( | ||||||
| "Unsloth: Warning - regex did not match, VLLMGeneration.sync_weights patch may have failed" | ||||||
| ) | ||||||
| return patched_src | ||||||
|
|
||||||
| def patch_generate(src): | ||||||
| pattern = re.compile( | ||||||
| r"^(?P<indent>[ \t]*)self\.llm\.collective_rpc\(\s*(['\"])reload_weights\2\s*\)\s*$", | ||||||
| re.MULTILINE, | ||||||
| ) | ||||||
|
|
||||||
| def replace_reload_weights(match): | ||||||
| indent = match.group("indent") | ||||||
| return f'{indent}pass # self.llm.collective_rpc("reload_weights")' | ||||||
|
|
||||||
| patched_src, num_replacements = pattern.subn( | ||||||
| replace_reload_weights, src, count = 1 | ||||||
| ) | ||||||
| if num_replacements == 0: | ||||||
| raise RuntimeError( | ||||||
| "Unsloth: Warning - regex did not match, VLLMGeneration.generate patch may have failed" | ||||||
| ) | ||||||
| return patched_src | ||||||
|
|
||||||
| try: | ||||||
| init_patched = patch_vllm_generation_method( | ||||||
| "_init_vllm", | ||||||
| patch_init_vllm, | ||||||
| "self.unsloth_fast_inference_lora = True", | ||||||
| "init_vllm", | ||||||
| ) | ||||||
| sync_patched = patch_vllm_generation_method( | ||||||
| "sync_weights", | ||||||
| patch_sync_weights, | ||||||
| "if getattr(self, 'unsloth_fast_inference_lora', False):", | ||||||
| "sync_weights", | ||||||
| ) | ||||||
| generate_patched = patch_vllm_generation_method( | ||||||
| "generate", | ||||||
| patch_generate, | ||||||
| 'pass # self.llm.collective_rpc("reload_weights")', | ||||||
| "generate", | ||||||
| ) | ||||||
| except RuntimeError as e: | ||||||
| logger.warning(str(e)) | ||||||
| return | ||||||
|
|
||||||
| if init_patched: | ||||||
| logger.info("Unsloth: Patched trl VLLMGeneration._init_vllm") | ||||||
| if sync_patched: | ||||||
| logger.info("Unsloth: Patched trl VLLMGeneration.sync_weights") | ||||||
| if generate_patched: | ||||||
| logger.info("Unsloth: Patched trl VLLMGeneration.generate") | ||||||
|
|
||||||
|
|
||||||
| RL_ADDITIONAL_FUNCTIONS["vllm_generation"].append(vllm_generation_init_patch) | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change removes the `self.vllm_generation.sync_weights()` call from `_generate_single_turn` for all GRPO executions, not just when Unsloth fast-inference LoRA is active. In the non-fast-inference path (where `_init_vllm` creates a separate `LLM(...)` instance), that sync is what keeps vLLM generation weights aligned with the training model each step; skipping it causes rollouts to be generated from stale parameters, which can corrupt training signals. The later `unsloth_fast_inference_lora` guard in `VLLMGeneration.sync_weights` does not mitigate this because the call site is removed entirely.
Useful? React with 👍 / 👎.