diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index 2191629f1a2d..ed1c4cbfc312 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -328,8 +328,6 @@ def __init__(self, config: JambaConfig, layer_idx):
         self.activation = config.hidden_act
         self.act = ACT2FN[config.hidden_act]
 
-        self.use_fast_kernels = config.use_mamba_kernels
-
         # projection of the input hidden states
         self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias)
         # selective projection used to make dt, B and C input dependent
@@ -369,8 +367,7 @@ def __init__(self, config: JambaConfig, layer_idx):
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
-                " is None. To install follow https://github.com/state-spaces/mamba/#installation and"
-                " https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
+                " is None. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d."
             )
 
     def cuda_kernels_forward(
@@ -571,11 +568,17 @@ def forward(
         cache_params: HybridMambaAttentionDynamicCache | None = None,
         attention_mask: torch.LongTensor | None = None,
     ):
-        if self.use_fast_kernels:
-            if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
-                raise ValueError(
-                    "Fast Mamba kernels are not available. Make sure to they are installed and that the mamba module is on a CUDA device"
-                )
+        if self.config.use_mamba_kernels and (
+            not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type
+        ):
+            logger.warning_once(
+                "Fast Mamba kernels are not available. Make sure that they are installed "
+                "and that the mamba module is on a CUDA device. Turning off the fast path "
+                "`config.use_mamba_kernels=False` and falling back to the slow path."
+            )
+            self.config.use_mamba_kernels = False
+
+        if self.config.use_mamba_kernels:
             return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
         return self.slow_forward(hidden_states, cache_params, attention_mask)
 
diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py
index e5aeeba56627..f395136837ec 100644
--- a/src/transformers/models/jamba/modular_jamba.py
+++ b/src/transformers/models/jamba/modular_jamba.py
@@ -220,8 +220,6 @@ def __init__(self, config: JambaConfig, layer_idx):
         self.activation = config.hidden_act
         self.act = ACT2FN[config.hidden_act]
 
-        self.use_fast_kernels = config.use_mamba_kernels
-
         # projection of the input hidden states
         self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias)
         # selective projection used to make dt, B and C input dependent
@@ -261,8 +259,7 @@ def __init__(self, config: JambaConfig, layer_idx):
         if not is_fast_path_available:
             logger.warning_once(
                 "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
-                " is None. To install follow https://github.com/state-spaces/mamba/#installation and"
-                " https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
+                " is None. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d."
             )
 
     def cuda_kernels_forward(
@@ -463,11 +460,17 @@ def forward(
         cache_params: HybridMambaAttentionDynamicCache | None = None,
         attention_mask: torch.LongTensor | None = None,
     ):
-        if self.use_fast_kernels:
-            if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
-                raise ValueError(
-                    "Fast Mamba kernels are not available. Make sure to they are installed and that the mamba module is on a CUDA device"
-                )
+        if self.config.use_mamba_kernels and (
+            not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type
+        ):
+            logger.warning_once(
+                "Fast Mamba kernels are not available. Make sure that they are installed "
+                "and that the mamba module is on a CUDA device. Turning off the fast path "
+                "`config.use_mamba_kernels=False` and falling back to the slow path."
+            )
+            self.config.use_mamba_kernels = False
+
+        if self.config.use_mamba_kernels:
             return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
         return self.slow_forward(hidden_states, cache_params, attention_mask)
 