Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion unsloth/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,6 +2071,19 @@ def unsloth_fast_generate(
):
# If the model starts out in training mode, restore training mode after generation
restore_training_mode = self.training
# why: snapshot the actual GC mode value (e.g. "unsloth") before for_inference
# clears it, so the post-generate restore preserves the caller's configured GC
# mode rather than collapsing it to a plain bool.
use_gradient_checkpointing = next(
(
v
for v in (
getattr(m, "gradient_checkpointing", False) for m in self.modules()
)
if v
),
False,
)

FastLlamaModel.for_inference(self)

Expand Down Expand Up @@ -2156,7 +2169,10 @@ def unsloth_fast_generate(
# pass

if restore_training_mode:
FastLlamaModel.for_training(self)
FastLlamaModel.for_training(
self,
use_gradient_checkpointing = use_gradient_checkpointing,
)

return output

Expand Down
19 changes: 18 additions & 1 deletion unsloth/models/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,20 @@ def unwrap_model_for_generation(

@contextmanager
def unsloth_unwrap_model_for_generation(model, *args, **kwargs):
# why: snapshot before TRL's unwrap context manager, which calls
# gradient_checkpointing_disable() before yielding; preserve the actual
# mode value (e.g. "unsloth") rather than collapsing it to a bool, so
# the finally restore matches the caller's configured GC mode.
use_gradient_checkpointing = next(
(
v
for v in (
getattr(m, "gradient_checkpointing", False) for m in model.modules()
)
if v
),
False,
)
with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model:
# Put the model in inference mode.
FastLanguageModel.for_inference(model)
Expand All @@ -207,7 +221,10 @@ def generate_with_clone(*args, **kwargs):
finally:
# Restore generate and return
unwrapped_model.generate = original_generate
FastLanguageModel.for_training(model)
FastLanguageModel.for_training(
model,
use_gradient_checkpointing = use_gradient_checkpointing,
)

from transformers import Trainer
from transformers.trainer_pt_utils import nested_detach
Expand Down
2 changes: 1 addition & 1 deletion unsloth/models/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
# Left pad prompt before calculation old and ref hidden states
left_pad_tokens_per_prompt = calculate_pad_tokens_in_prompt(prompt_completion_ids, logits_to_keep, self.processing_class.pad_token_id)
max_left_pad = torch.max(left_pad_tokens_per_prompt).item()
self.model.for_training()"""
self.model.for_training(use_gradient_checkpointing=getattr(self.args, 'gradient_checkpointing', True))"""

function = function.replace(line_to_replace, replacement_lines)

Expand Down
Loading