Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 52 additions & 3 deletions unsloth_zoo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def filter(self, x): return not (self.text in x.getMessage())
"LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING", # Gemma3 create_masks_for_generate
"create_causal_mask(**mask_kwargs)", # Gemma3 create_masks_for_generate
"compute_mup_vector", # used in falcon h1 init and not needed to compile + inductor complains
"segment_sum", # falcon h1
"apply_mask_to_padding_states", # falcon h1
"reshape_into_chunks", # falcon h1
"pad_tensor_by_size", # falcon h1
]

_license_header = """
Expand Down Expand Up @@ -1784,8 +1788,7 @@ def compile_timm_models(UNSLOTH_ENABLE_LOGGING, torch_compile_options):
pass
pass


def compile_causal_conv1d():
def compile_causal_conv1d(UNSLOTH_ENABLE_LOGGING=False):
# For Liquid, Falcon and other Mamba type models
# We disable compiling on them!
try:
Expand All @@ -1794,8 +1797,42 @@ def compile_causal_conv1d():
torch.compiler.disable(causal_conv1d.causal_conv1d_fn, recursive = True)
causal_conv1d.causal_conv1d_update = \
torch.compiler.disable(causal_conv1d.causal_conv1d_update, recursive = True)
if UNSLOTH_ENABLE_LOGGING:
print(f"Unsloth: Disabled compiling causal_conv1d")
return True
except Exception as e:
print(e, str(e))
if UNSLOTH_ENABLE_LOGGING:
print(f"Unsloth: Failed compiling causal_conv1d")
return False
pass

def compile_mamba_ssm(UNSLOTH_ENABLE_LOGGING=False):
# For Liquid, Falcon and other Mamba type models
# We disable compiling on them!
try:
import mamba_ssm
mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined = \
torch.compiler.disable(
mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined,
recursive = True
)
mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined = \
torch.compiler.disable(
mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined,
recursive = True
)
mamba_ssm.ops.triton.selective_state_update.selective_state_update = \
torch.compiler.disable(
mamba_ssm.ops.triton.selective_state_update.selective_state_update,
recursive = True
)
if UNSLOTH_ENABLE_LOGGING:
print(f"Unsloth: Disabled compiling mamba_ssm")
return True
except:
if UNSLOTH_ENABLE_LOGGING:
print(f"Unsloth: Failed compiling mamba_ssm")
return False
pass

Expand Down Expand Up @@ -1892,7 +1929,8 @@ def replaced_tqdm(*args, **kwargs):
compile_timm_models(UNSLOTH_ENABLE_LOGGING, torch_compile_options)

# Disable compiling mamba type models
has_causal_conv1d = compile_causal_conv1d()
has_causal_conv1d = compile_causal_conv1d(UNSLOTH_ENABLE_LOGGING)
has_mamba_ssm = compile_mamba_ssm(UNSLOTH_ENABLE_LOGGING)

# Return logits
UNSLOTH_RETURN_LOGITS = "0" if not return_logits else "1"
Expand Down Expand Up @@ -1936,6 +1974,17 @@ def replaced_tqdm(*args, **kwargs):
)
pass

# If mamba type, but no fast causal functions, warn!
if not has_mamba_ssm and \
("mamba_chunk_scan_combined" in full_source or "mamba_split_conv1d_scan_combined" in full_source or "selective_state_update" in full_source):
print(
"**********\n"\
"Unsloth: Please install `mamba_ssm` to speed up Mamba training via `pip install mamba_ssm`\n"\
"If you don't, training will still work, just might be slower for Mamba type models.\n"\
"**********\n"
)
pass

# Get class LlamaAttention(nn.Module)
torch_modules = re.findall(r"class ([^\s]{1,})\(.+?\.Module\)", full_source)
# Also get class LlamaSdpaAttention(LlamaAttention)
Expand Down