Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/transformers/models/jamba/modeling_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,6 @@ def __init__(self, config: JambaConfig, layer_idx):
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]

self.use_fast_kernels = config.use_mamba_kernels

# projection of the input hidden states
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias)
# selective projection used to make dt, B and C input dependent
Expand Down Expand Up @@ -369,8 +367,7 @@ def __init__(self, config: JambaConfig, layer_idx):
if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
" is None. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d."
)

def cuda_kernels_forward(
Expand Down Expand Up @@ -571,11 +568,17 @@ def forward(
cache_params: HybridMambaAttentionDynamicCache | None = None,
attention_mask: torch.LongTensor | None = None,
):
if self.use_fast_kernels:
if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
raise ValueError(
"Fast Mamba kernels are not available. Make sure to they are installed and that the mamba module is on a CUDA device"
)
if self.config.use_mamba_kernels and (
not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type
):
logger.warning_once(
"Fast Mamba kernels are not available. Make sure that they are installed "
"and that the mamba module is on a CUDA device. Turning off the fast path "
"`config.use_mamba_kernels=False` and falling back to the slow path."
)
self.config.use_mamba_kernels = False

if self.config.use_mamba_kernels:
return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
return self.slow_forward(hidden_states, cache_params, attention_mask)

Expand Down
21 changes: 12 additions & 9 deletions src/transformers/models/jamba/modular_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,6 @@ def __init__(self, config: JambaConfig, layer_idx):
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]

self.use_fast_kernels = config.use_mamba_kernels

# projection of the input hidden states
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias)
# selective projection used to make dt, B and C input dependent
Expand Down Expand Up @@ -261,8 +259,7 @@ def __init__(self, config: JambaConfig, layer_idx):
if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
" is None. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d."
)

def cuda_kernels_forward(
Expand Down Expand Up @@ -463,11 +460,17 @@ def forward(
cache_params: HybridMambaAttentionDynamicCache | None = None,
attention_mask: torch.LongTensor | None = None,
):
if self.use_fast_kernels:
if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
raise ValueError(
"Fast Mamba kernels are not available. Make sure to they are installed and that the mamba module is on a CUDA device"
)
if self.config.use_mamba_kernels and (
not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type
):
logger.warning_once(
"Fast Mamba kernels are not available. Make sure that they are installed "
"and that the mamba module is on a CUDA device. Turning off the fast path "
"`config.use_mamba_kernels=False` and falling back to the slow path."
)
self.config.use_mamba_kernels = False

if self.config.use_mamba_kernels:
return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
return self.slow_forward(hidden_states, cache_params, attention_mask)

Expand Down