diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 20b515c711..b2e6a45aa0 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2071,6 +2071,19 @@ def unsloth_fast_generate( ): # If the model starts out in training mode, restore training mode after generation restore_training_mode = self.training + # why: snapshot the actual GC mode value (e.g. "unsloth") before for_inference + # clears it, so the post-generate restore preserves the caller's configured GC + # mode rather than collapsing it to a plain bool. + use_gradient_checkpointing = next( + ( + v + for v in ( + getattr(m, "gradient_checkpointing", False) for m in self.modules() + ) + if v + ), + False, + ) FastLlamaModel.for_inference(self) @@ -2156,7 +2169,10 @@ def unsloth_fast_generate( # pass if restore_training_mode: - FastLlamaModel.for_training(self) + FastLlamaModel.for_training( + self, + use_gradient_checkpointing = use_gradient_checkpointing, + ) return output diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 8521b72602..7b7c3ac1a4 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -186,6 +186,20 @@ def unwrap_model_for_generation( @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): + # why: snapshot before TRL's unwrap context manager, which calls + # gradient_checkpointing_disable() before yielding; preserve the actual + # mode value (e.g. "unsloth") rather than collapsing it to a bool, so + # the finally restore matches the caller's configured GC mode. + use_gradient_checkpointing = next( + ( + v + for v in ( + getattr(m, "gradient_checkpointing", False) for m in model.modules() + ) + if v + ), + False, + ) with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: # Put the model in inference mode. FastLanguageModel.for_inference(model) @@ -207,7 +221,10 @@ def generate_with_clone(*args, **kwargs): finally: # Restore generate and return unwrapped_model.generate = original_generate - FastLanguageModel.for_training(model) + FastLanguageModel.for_training( + model, + use_gradient_checkpointing = use_gradient_checkpointing, + ) from transformers import Trainer from transformers.trainer_pt_utils import nested_detach diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 53ef25f95c..0f9a324d5b 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -723,7 +723,7 @@ def grpo_trainer__generate_and_score_completions(function_name, function): # Left pad prompt before calculation old and ref hidden states left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(prompt_completion_ids, logits_to_keep, self.processing_class.pad_token_id) max_left_pad = torch.max(left_pad_tokens_per_prompt).item() - self.model.for_training()""" + self.model.for_training(use_gradient_checkpointing=getattr(self.args, 'gradient_checkpointing', True))""" function = function.replace(line_to_replace, replacement_lines)