Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion unsloth_zoo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4080,6 +4080,10 @@ def replaced_tqdm(*args, **kwargs):

patch_torch_functions()

_conv_modules = frozenset([
"Conv1d", "Conv2d", "Conv3d",
"ConvTranspose1d", "ConvTranspose2d", "ConvTranspose3d",
])
for module in _patch_functions:
try:
source = eval(f"{model_location}.torch")
Expand All @@ -4096,14 +4100,38 @@ def replaced_tqdm(*args, **kwargs):
continue

source = inspect.getsource(function.forward).rstrip()

if module in _conv_modules:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For improved robustness, consider applying the dynamic parameter-name detection to convolution modules as well. The current implementation hardcodes `input`, which is correct for the standard `torch.nn.Conv*` layers, but reusing the approach you've applied to normalization modules would make this code more resilient to custom convolution layers whose `forward` uses a different parameter name.

# Conv modules: cast input to weight dtype before the conv op,
# then cast output back to original input dtype. This prevents
# dtype mismatches under mixed-precision autocast (e.g. bf16
# weight + fp16 input crashes F.conv1d).
lines = source.split("\n")
def_line = lines[0]
body_lines = lines[1:]
first_body = next((l for l in body_lines if l.strip()), "")
body_indent = first_body[:len(first_body) - len(first_body.lstrip())]
prologue = [
body_indent + "original_dtype = input.dtype",
body_indent + "input = input.to(self.weight.dtype)",
]
source = "\n".join([def_line] + prologue + body_lines)
append_str = ".to(original_dtype)\n"
else:
# Norm modules: detect the actual parameter name (input or x)
import re as _re
m = _re.search(r"def forward\(self,\s*(\w+)", source)
Comment on lines +4122 to +4123
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The `re` module is already imported at the top of the file (line 28). Importing it again inside the loop is inefficient and unnecessary; you can use the module-level import directly.

Suggested change
import re as _re
m = _re.search(r"def forward\(self,\s*(\w+)", source)
# Use the top-level `re` import
m = re.search(r"def forward\(self,\s*(\w+)", source)

param_name = m.group(1) if m else "input"
append_str = f".to({param_name}.dtype)\n"

forward = create_new_function(
module,
source,
model_location,
functions,
prepend=_license_header
+ f"\ntorch_compile_options = {torch_compile_options}\n",
append=".to(input.dtype)\n",
append=append_str,
overwrite=False,
add_torch_compile=False,
).forward
Expand Down
Loading