From 4189d468972ccb3016d04e36eaacac62632b151e Mon Sep 17 00:00:00 2001 From: Mathew Mathew Date: Fri, 18 Jul 2025 13:20:30 -0500 Subject: [PATCH] Falcon H1 training is fp16 is unstable with the mamba kernels. NaN's appear frequently during training. To handle this situation we can force float32 when the dtype is float 16. --- unsloth_zoo/compiler.py | 55 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index c54b66ed9..b27b8d9e9 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -98,6 +98,10 @@ def filter(self, x): return not (self.text in x.getMessage()) "LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING", # Gemma3 create_masks_for_generate "create_causal_mask(**mask_kwargs)", # Gemma3 create_masks_for_generate "compute_mup_vector", # used in falcon h1 init and not needed to compile + inductor complains + "segment_sum", # falcon h1 + "apply_mask_to_padding_states", # falcon h1 + "reshape_into_chunks", # falcon h1 + "pad_tensor_by_size", # falcon h1 ] _license_header = """ @@ -1784,8 +1788,7 @@ def compile_timm_models(UNSLOTH_ENABLE_LOGGING, torch_compile_options): pass pass - -def compile_causal_conv1d(): +def compile_causal_conv1d(UNSLOTH_ENABLE_LOGGING=False): # For Liquid, Falcon and other Mamba type models # We disable compiling on them! try: @@ -1794,8 +1797,42 @@ def compile_causal_conv1d(): torch.compiler.disable(causal_conv1d.causal_conv1d_fn, recursive = True) causal_conv1d.causal_conv1d_update = \ torch.compiler.disable(causal_conv1d.causal_conv1d_update, recursive = True) + if UNSLOTH_ENABLE_LOGGING: + print(f"Unsloth: Disabled compiling causal_conv1d") + return True + except Exception as e: + print(e, str(e)) + if UNSLOTH_ENABLE_LOGGING: + print(f"Unsloth: Failed compiling causal_conv1d") + return False +pass + +def compile_mamba_ssm(UNSLOTH_ENABLE_LOGGING=False): + # For Liquid, Falcon and other Mamba type models + # We disable compiling on them! + try: + import mamba_ssm + mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined = \ + torch.compiler.disable( + mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined, + recursive = True + ) + mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined = \ + torch.compiler.disable( + mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined, + recursive = True + ) + mamba_ssm.ops.triton.selective_state_update.selective_state_update = \ + torch.compiler.disable( + mamba_ssm.ops.triton.selective_state_update.selective_state_update, + recursive = True + ) + if UNSLOTH_ENABLE_LOGGING: + print(f"Unsloth: Disabled compiling mamba_ssm") return True except: + if UNSLOTH_ENABLE_LOGGING: + print(f"Unsloth: Failed compiling mamba_ssm") return False pass @@ -1892,7 +1929,8 @@ def replaced_tqdm(*args, **kwargs): compile_timm_models(UNSLOTH_ENABLE_LOGGING, torch_compile_options) # Disable compiling mamba type models - has_causal_conv1d = compile_causal_conv1d() + has_causal_conv1d = compile_causal_conv1d(UNSLOTH_ENABLE_LOGGING) + has_mamba_ssm = compile_mamba_ssm(UNSLOTH_ENABLE_LOGGING) # Return logits UNSLOTH_RETURN_LOGITS = "0" if not return_logits else "1" @@ -1936,6 +1974,17 @@ def replaced_tqdm(*args, **kwargs): ) pass + # If mamba type, but no fast causal functions, warn! + if not has_mamba_ssm and \ + ("mamba_chunk_scan_combined" in full_source or "mamba_split_conv1d_scan_combined" in full_source or "selective_state_update" in full_source): + print( + "**********\n"\ + "Unsloth: Please install `mamba_ssm` to speed up Mamba training via `pip install mamba_ssm`\n"\ + "If you don't, training will still work, just might be slower for Mamba type models.\n"\ + "**********\n" + ) + pass + # Get class LlamaAttention(nn.Module) torch_modules = re.findall(r"class ([^\s]{1,})\(.+?\.Module\)", full_source) # Also get class LlamaSdpaAttention(LlamaAttention)