32 changes: 30 additions & 2 deletions docs/source/en/model_doc/bamba.md
@@ -39,7 +39,7 @@ Checkout all Bamba-9B model checkpoints [here](https://github.com/foundation-mod
<!---
## Usage Tips

Tips:

- The architecture is based on Mamba-2 models.

@@ -63,7 +63,7 @@ response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```


## Padding-Free Training

Bamba supports padding-free training, in which distinct training examples are concatenated
together while the inputs are still processed as though they belonged to separate batches. When
the examples have varying lengths, padding-free training can provide significant speedups and
memory savings compared to batching the examples together with padding, since the compute and
memory otherwise wasted on pad tokens are avoided entirely. The performance gains depend on factors
such as the model and the data distribution, but throughput gains of up to [~2x are commonly
seen](https://github.com/huggingface/transformers/pull/35861#issue-2807873129).

Using padding-free training with Bamba requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d`
packages, and the following arguments must be passed to the model in addition to `input_ids` and
`labels` (a hand-built illustration follows the list):
* `position_ids: torch.LongTensor`: The position index of each token in each sequence.
* `seq_idx: torch.IntTensor`: The index of each sequence in the batch.
* Each of the [`FlashAttentionKwargs`]:
  * `cu_seq_lens_q: torch.LongTensor`: The cumulative sequence lengths of all queries.
  * `cu_seq_lens_k: torch.LongTensor`: The cumulative sequence lengths of all keys.
  * `max_length_q: int`: The longest query length in the batch.
  * `max_length_k: int`: The longest key length in the batch.
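
For example, packing two sequences of lengths 3 and 2 into a single row corresponds to inputs along
the following lines (a hand-built sketch with made-up token ids; in practice the collator described
below generates these for you):

```python
import torch

# Two sequences of lengths 3 and 2 packed into one row (token ids are illustrative only).
input_ids     = torch.tensor([[11, 12, 13, 21, 22]])
labels        = torch.tensor([[11, 12, 13, 21, 22]])
position_ids  = torch.tensor([[0, 1, 2, 0, 1]])          # restarts at each sequence boundary
seq_idx       = torch.tensor([[0, 0, 0, 1, 1]], dtype=torch.int32)
cu_seq_lens_q = cu_seq_lens_k = torch.tensor([0, 3, 5])  # boundaries of the packed sequences
max_length_q  = max_length_k = 3                         # length of the longest packed sequence
```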

The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] can be used
to programmatically generate the above set of additional arguments using `return_seq_idx=True` and
`return_flash_attn_kwargs=True`. See [this blog post](https://huggingface.co/blog/packing-with-FA2)
for additional information.
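
A minimal end-to-end sketch is shown below. It assumes the fast-path packages listed above are
installed and a CUDA device is available, and it uses a placeholder checkpoint id (swap in the
Bamba checkpoint you are actually working with):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorWithFlattening

model_id = "ibm-ai-platform/Bamba-9B"  # placeholder checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")

# Each feature is an independent example; the collator flattens them into a single padding-free
# batch and emits position_ids, seq_idx, and the FlashAttentionKwargs described above.
collator = DataCollatorWithFlattening(return_seq_idx=True, return_flash_attn_kwargs=True)
features = [
    {"input_ids": tokenizer(text)["input_ids"]}
    for text in ["The largest ocean on Earth is", "Mamba layers scale linearly with sequence length"]
]
batch = collator(features)

# max_length_q / max_length_k are plain ints, so only tensors are moved to the device.
batch = {k: v.to(model.device) if torch.is_tensor(v) else v for k, v in batch.items()}
loss = model(**batch).loss
```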


[[autodoc]] BambaForCausalLM
- forward

This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
54 changes: 45 additions & 9 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -24,7 +24,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Tuple, Union
from functools import partial
from typing import Callable, Optional, Tuple, TypedDict, Union

import torch
from torch import nn
@@ -61,6 +62,31 @@
logger = logging.get_logger(__name__)


class BambaFlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
Use cases include padding-free training and fewer `torch.compile` graph breaks.

Attributes:
cu_seq_lens_q (`torch.LongTensor`):
Cumulative sequence lengths for the query states.
cu_seq_lens_k (`torch.LongTensor`):
Cumulative sequence lengths for the key states.
max_length_q (`int`):
Maximum sequence length for the query states.
max_length_k (`int`):
Maximum sequence length for the key states.
seq_idx (`torch.IntTensor`):
Index of each packed sequence.
"""

cu_seq_lens_q: torch.LongTensor
cu_seq_lens_k: torch.LongTensor
max_length_q: int
max_length_k: int
seq_idx: torch.IntTensor


# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
"""
@@ -487,6 +513,7 @@ def cuda_kernels_forward(
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
):
# 1. Gated MLP's linear projection
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
@@ -569,7 +596,7 @@ def cuda_kernels_forward(
A,
D=self.D,
chunk_size=self.chunk_size,
seq_idx=None, # was seq_idx
seq_idx=seq_idx,
activation=self.activation,
rmsnorm_weight=self.norm.weight,
rmsnorm_eps=self.norm.variance_epsilon,
@@ -610,6 +637,7 @@ def cuda_kernels_forward(
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=seq_idx,
).transpose(1, 2)

hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
@@ -629,7 +657,7 @@ def cuda_kernels_forward(
chunk_size=self.chunk_size,
D=self.D,
z=None,
seq_idx=None,
seq_idx=seq_idx,
return_final_states=True,
dt_bias=self.dt_bias,
dt_softplus=True,
@@ -863,9 +891,15 @@ def forward(
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
**kwargs,
Collaborator: prob unpack here

Contributor Author: Unpack the **kwargs? These **kwargs are just a catch-all for the non-seq_idx kwargs in BambaFlashAttentionKwargs because the BambaMixer layer only uses seq_idx, while BambaAttention uses the rest.

Do you want me to do **kwargs: Unpack[FlashAttentionKwargs] even though the kwargs are unused?

):
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
if seq_idx is not None:
raise NotImplementedError(
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
)
dtype = hidden_states.dtype
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
@@ -939,7 +973,7 @@ def forward(
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
**kwargs: Unpack[BambaFlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -959,8 +993,8 @@ def forward(
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
padding-free training and/or improve torch.compile performance.
"""

residual = hidden_states
@@ -974,6 +1008,7 @@ def forward(
cache_params=past_key_value,
cache_position=cache_position,
attention_mask=attention_mask,
**kwargs,
)
self_attn_weights = None
elif self.layer_type == "attention":
@@ -1076,7 +1111,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # NOOP kwargs, for now
**kwargs: Unpack[BambaFlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -1128,7 +1163,7 @@ def forward(

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
partial(decoder_layer.__call__, **kwargs),
hidden_states,
layer_mask,
position_ids,
@@ -1148,6 +1183,7 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)

hidden_states = layer_outputs[0]
61 changes: 51 additions & 10 deletions src/transformers/models/bamba/modular_bamba.py
@@ -19,7 +19,8 @@
# limitations under the License.
"""PyTorch Bamba model."""

from typing import Optional, Tuple, Union
from functools import partial
from typing import Optional, Tuple, TypedDict, Union

import torch
import torch.utils.checkpoint
@@ -46,7 +47,12 @@
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, can_return_tuple, logging
from ...processing_utils import Unpack
from ...utils import (
auto_docstring,
can_return_tuple,
logging,
)
from ...utils.import_utils import is_causal_conv1d_available, is_flash_attn_2_available, is_mamba_2_ssm_available
from .configuration_bamba import BambaConfig

@@ -71,6 +77,31 @@
logger = logging.get_logger(__name__)


class BambaFlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
Use cases include padding-free training and fewer `torch.compile` graph breaks.

Attributes:
cu_seq_lens_q (`torch.LongTensor`):
Cumulative sequence lengths for the query states.
cu_seq_lens_k (`torch.LongTensor`):
Cumulative sequence lengths for the key states.
max_length_q (`int`):
Maximum sequence length for the query states.
max_length_k (`int`):
Maximum sequence length for the key states.
seq_idx (`torch.IntTensor`):
Index of each packed sequence.
"""

cu_seq_lens_q: torch.LongTensor
cu_seq_lens_k: torch.LongTensor
max_length_q: int
max_length_k: int
seq_idx: torch.IntTensor


# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
"""
@@ -278,6 +309,7 @@ def cuda_kernels_forward(
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
):
# 1. Gated MLP's linear projection
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
@@ -360,7 +392,7 @@ def cuda_kernels_forward(
A,
D=self.D,
chunk_size=self.chunk_size,
seq_idx=None, # was seq_idx
seq_idx=seq_idx,
activation=self.activation,
rmsnorm_weight=self.norm.weight,
rmsnorm_eps=self.norm.variance_epsilon,
@@ -401,6 +433,7 @@ def cuda_kernels_forward(
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=seq_idx,
).transpose(1, 2)

hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
@@ -420,7 +453,7 @@ def cuda_kernels_forward(
chunk_size=self.chunk_size,
D=self.D,
z=None,
seq_idx=None,
seq_idx=seq_idx,
return_final_states=True,
dt_bias=self.dt_bias,
dt_softplus=True,
@@ -654,9 +687,15 @@ def forward(
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
**kwargs,
):
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
if seq_idx is not None:
raise NotImplementedError(
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
)
dtype = hidden_states.dtype
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
@@ -701,7 +740,7 @@ def forward(
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
**kwargs: Unpack[BambaFlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -721,8 +760,8 @@ def forward(
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
padding-free training and/or improve torch.compile performance.
"""

residual = hidden_states
@@ -736,6 +775,7 @@ def forward(
cache_params=past_key_value,
cache_position=cache_position,
attention_mask=attention_mask,
**kwargs,
)
self_attn_weights = None
elif self.layer_type == "attention":
@@ -838,7 +878,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # NOOP kwargs, for now
**kwargs: Unpack[BambaFlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -890,7 +930,7 @@ def forward(

if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
partial(decoder_layer.__call__, **kwargs),
hidden_states,
layer_mask,
position_ids,
@@ -910,6 +950,7 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
Contributor (suggested change: remove `**kwargs,`): Unnecessary?

)

hidden_states = layer_outputs[0]