Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions unsloth/models/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,10 +442,10 @@ def pre_patch():
return

@staticmethod
def post_patch(model, tokenizer):
def post_patch(model, tokenizer, correct_dtype = None):
# Gemma does not downcast RoPE
model, tokenizer = patch_model_and_tokenizer(
model, tokenizer, downcast_rope = False
model, tokenizer, downcast_rope = False, correct_dtype = correct_dtype
)

# Add 1 to weight
Expand Down
4 changes: 2 additions & 2 deletions unsloth/models/gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,10 +613,10 @@ def pre_patch():
return

@staticmethod
def post_patch(model, tokenizer):
def post_patch(model, tokenizer, correct_dtype = None):
# Gemma does not downcast RoPE
model, tokenizer = patch_model_and_tokenizer(
model, tokenizer, downcast_rope = False
model, tokenizer, downcast_rope = False, correct_dtype = correct_dtype
)

# Add 1 to weight
Expand Down
2 changes: 1 addition & 1 deletion unsloth/models/granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def pre_patch():
return

@staticmethod
def post_patch(model, tokenizer):
def post_patch(model, tokenizer, correct_dtype = None):
# Torch.compile fails on embedding matrix??
# Workaround randomly fixes it for torch versions < 2.2
model.model.embed_tokens = torch.nn.Embedding.from_pretrained(
Expand Down
8 changes: 5 additions & 3 deletions unsloth/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2483,7 +2483,9 @@ def from_pretrained(
)

model, tokenizer = patch_tokenizer(model, tokenizer)
model, tokenizer = model_patcher.post_patch(model, tokenizer)
model, tokenizer = model_patcher.post_patch(
model, tokenizer, correct_dtype = dtype
)

# Patch up QKV / O and MLP
for idx, layer in enumerate(model.model.layers):
Expand Down Expand Up @@ -2666,9 +2668,9 @@ def from_pretrained(
return model, tokenizer

@staticmethod
def post_patch(model, tokenizer):
def post_patch(model, tokenizer, correct_dtype = None):
model, tokenizer = patch_model_and_tokenizer(
model, tokenizer, downcast_rope = True
model, tokenizer, downcast_rope = True, correct_dtype = correct_dtype
)
return model, tokenizer

Expand Down
12 changes: 12 additions & 0 deletions unsloth/models/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,6 +1244,18 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
flags = re.DOTALL,
)

# Remove TRL's unconditional bfloat16 cast of trainable params (added in
# TRL 0.26.0). TRL hardcodes bfloat16 for QLoRA per the original paper's
# recommendation, but this is wrong: it ignores the user's requested dtype
# and breaks GradScaler when training with fp16=True. Unsloth already
# handles adapter dtype correctly via patch_model_and_tokenizer, so the
# entire block is unnecessary. For GRPOTrainer the enclosing peft init
# block is already removed above, making this a no-op for GRPO.
RLTrainer_source = RLTrainer_source.replace(
'if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False):',
"if False:",
)
Comment on lines +1254 to +1257

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using `str.replace` for monkey-patching is brittle: if the target string in TRL changes even slightly (e.g., extra whitespace), the replacement silently does nothing. Using `re.sub` with a pattern that tolerates flexible whitespace would make this patch more robust against future changes in the TRL library.

Suggested change
RLTrainer_source = RLTrainer_source.replace(
"param.data = param.data.to(torch.bfloat16)",
"param.data = param.data.to(torch.float32 if getattr(args, 'fp16', False) else torch.bfloat16)",
)
RLTrainer_source = re.sub(
r"param\.data\s*=\s*param\.data\.to\(torch\.bfloat16\)",
"param.data = param.data.to(torch.float32 if getattr(args, 'fp16', False) else torch.bfloat16)",
RLTrainer_source,
)


if RLTrainer_name == "SFTTrainer":
original_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask"]'
new_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask","labels"]'
Expand Down