@@ -282,6 +282,7 @@ def _torch_cached_ssm_transform_fake(
return torch.empty_like(
hidden_states,
memory_format=torch.contiguous_format,
+ dtype=torch.float32,
)


2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/auto_deploy/custom_ops/torch_mamba.py
@@ -194,4 +194,4 @@ def _torch_ssm_transform_meta(
time_step_limit: List[float],
chunk_size: int,
) -> torch.Tensor:
- return torch.empty_like(hidden_states)
+ return torch.empty_like(hidden_states, dtype=torch.float32)
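
Both dtype fixes above enforce the same contract: a fake/meta implementation must report the output dtype the real kernel actually produces (float32 here), otherwise dtype propagation during export disagrees with eager execution. A minimal sketch of that contract using the public torch.library API and a hypothetical demo::ssm_transform op (not the TRT-LLM ops above):

import torch

@torch.library.custom_op("demo::ssm_transform", mutates_args=())
def demo_ssm_transform(hidden_states: torch.Tensor) -> torch.Tensor:
    # Stand-in for a real kernel that computes and returns in float32.
    return hidden_states.float() * 2.0

@demo_ssm_transform.register_fake
def _(hidden_states: torch.Tensor) -> torch.Tensor:
    # Mirror the real kernel's output dtype, as the patched fake/meta functions do.
    return torch.empty_like(hidden_states, dtype=torch.float32)
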
15 changes: 15 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py
@@ -150,6 +150,18 @@ def _bamba_model_update_mamba_mask(self, attention_mask, cache_position):
return None


def _bamba_model_update_causal_mask(
self,
attention_mask,
input_tensor,
cache_position,
past_key_values,
output_attentions,
):
# Force attention to use causal mode without explicit masks
return None


# NOTE: this would need to be applied earlier than other patches, since the `_init_weights` (which
# is called by `post_init`) is called before we run `forward`.
def _bamba_pretrained_model_init_weights(self, module):
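
Returning None from _update_causal_mask, like the _update_mamba_mask patch above, keeps explicit attention masks out of the exported graph so the attention layers can take the mask-free causal SDPA path that the grouped-attention patterns further below (the is_causal=True variants) recognize. For unpadded inputs the implicit causal flag and an explicit lower-triangular mask are equivalent; a small check of that assumption using the public SDPA API rather than the auto_deploy ops:

import torch
import torch.nn.functional as F

# Illustrative check, not part of this diff: with no padding, is_causal=True
# matches an explicit lower-triangular boolean mask.
b, h, s, d = 2, 4, 16, 32
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))

out_implicit = F.scaled_dot_product_attention(q, k, v, is_causal=True)

causal_mask = torch.tril(torch.ones(s, s, dtype=torch.bool))
out_explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

torch.testing.assert_close(out_implicit, out_explicit)
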
@@ -182,17 +194,20 @@ class BambaModelPatch(BaseExportPatch):
def _apply_patch(self):
self.original_values["BambaMixer.torch_forward"] = BambaMixer.torch_forward
self.original_values["BambaModel._update_mamba_mask"] = BambaModel._update_mamba_mask
self.original_values["BambaModel._update_causal_mask"] = BambaModel._update_causal_mask
# NOTE: there is no original `HybridMambaAttentionDynamicCache.__bool__` to save; the patch adds it and `_revert_patch` deletes it.
# self.original_values["BambaPreTrainedModel._init_weights"] = BambaPreTrainedModel._init_weights

BambaMixer.torch_forward = _bamba_mixer_torch_forward
BambaModel._update_mamba_mask = _bamba_model_update_mamba_mask
BambaModel._update_causal_mask = _bamba_model_update_causal_mask
HybridMambaAttentionDynamicCache.__bool__ = _cache_bool
# BambaPreTrainedModel._init_weights = _bamba_pretrained_model_init_weights

def _revert_patch(self):
BambaMixer.torch_forward = self.original_values["BambaMixer.torch_forward"]
BambaModel._update_mamba_mask = self.original_values["BambaModel._update_mamba_mask"]
BambaModel._update_causal_mask = self.original_values["BambaModel._update_causal_mask"]
del HybridMambaAttentionDynamicCache.__bool__
# BambaPreTrainedModel._init_weights = self.original_values[
# "BambaPreTrainedModel._init_weights"
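
The apply/revert pair above follows the usual save-override-restore discipline; __bool__ is deleted rather than restored on revert, presumably because the cache class does not define it originally. A generic sketch of the same discipline as a context manager (not the BaseExportPatch API, whose lifecycle is assumed to call _apply_patch before export and _revert_patch afterwards):

from contextlib import contextmanager

@contextmanager
def monkey_patch(owner, name, replacement):
    """Temporarily replace `owner.<name>`, restoring or deleting it on exit."""
    missing = object()
    original = getattr(owner, name, missing)
    setattr(owner, name, replacement)
    try:
        yield
    finally:
        if original is missing:
            # The attribute did not exist before the patch, so remove it again.
            delattr(owner, name)
        else:
            setattr(owner, name, original)
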
107 changes: 106 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/transform/library/attention.py
@@ -334,6 +334,82 @@ def _grouped_attn_replacement_5(q, k, v, n_rep, attn_mask):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(q, k, v, attn_mask)


def _grouped_attn_pattern_6(q, k, v, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=False,
scale=scale,
enable_gqa=True,
)


def _grouped_attn_replacement_6(q, k, v, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
)


def _grouped_attn_pattern_7(q, k, v, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=True,
scale=scale,
enable_gqa=True,
)


def _grouped_attn_replacement_7(q, k, v, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
)


def _grouped_attn_pattern_8(q, k, v, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q,
k,
v,
attn_mask=None,
dropout_p=dropout_p,
is_causal=False,
scale=scale,
enable_gqa=True,
)


def _grouped_attn_replacement_8(q, k, v, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=scale
)


def _grouped_attn_pattern_9(q, k, v, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q,
k,
v,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True,
scale=scale,
enable_gqa=True,
)


def _grouped_attn_replacement_9(q, k, v, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True, scale=scale
)
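
Patterns 6 through 9 cover the enable_gqa=True spellings of torch_attention_sdpa, with and without an explicit mask and for both causal settings, and rewrite each to the grouped op. As a reminder of what grouped-query attention computes, here is an equivalence check with the public SDPA API, assuming the grouped op's semantics mirror enable_gqa (each KV head serves n_heads // n_kv_heads query heads):

import torch
import torch.nn.functional as F

# GQA example: 8 query heads share 2 KV heads (n_rep = 4).
b, n_heads, n_kv_heads, s, d = 2, 8, 2, 16, 32
q = torch.randn(b, n_heads, s, d)
k = torch.randn(b, n_kv_heads, s, d)
v = torch.randn(b, n_kv_heads, s, d)

# enable_gqa (PyTorch >= 2.5) broadcasts the KV heads across the query heads.
out_gqa = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

# Equivalent formulation: repeat each KV head explicitly, then run plain SDPA.
n_rep = n_heads // n_kv_heads
out_ref = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(n_rep, dim=1),
    v.repeat_interleave(n_rep, dim=1),
    is_causal=True,
)
torch.testing.assert_close(out_gqa, out_ref)
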


@TransformRegistry.register("match_repeat_kv")
class MatchRepeatKV(BaseTransform):
"""
@@ -434,6 +510,7 @@ def register_grouped_attention(patterns: ADPatternMatcherPass):
dummy_args_1 = [q, k1, v1, n_rep, attn_mask, dropout, scale]
dummy_args_2 = [q, k1, v1, attn_mask, dropout, scale]
dummy_args_3 = [q, k1, v1, n_rep, attn_mask]
dummy_args_4 = [q, k1, v1, dropout, scale]

register_ad_pattern(
search_fn=_grouped_attn_pattern_1,
@@ -477,6 +554,35 @@ def register_grouped_attention(patterns: ADPatternMatcherPass):
scalar_workaround={"n_rep": n_rep},
)

register_ad_pattern(
search_fn=_grouped_attn_pattern_6,
replace_fn=_grouped_attn_replacement_6,
patterns=patterns,
dummy_args=dummy_args_2,
scalar_workaround={"scale": scale, "dropout_p": dropout},
)
register_ad_pattern(
search_fn=_grouped_attn_pattern_7,
replace_fn=_grouped_attn_replacement_7,
patterns=patterns,
dummy_args=dummy_args_2,
scalar_workaround={"scale": scale, "dropout_p": dropout},
)
register_ad_pattern(
search_fn=_grouped_attn_pattern_8,
replace_fn=_grouped_attn_replacement_8,
patterns=patterns,
dummy_args=dummy_args_4,
scalar_workaround={"scale": scale, "dropout_p": dropout},
)
register_ad_pattern(
search_fn=_grouped_attn_pattern_9,
replace_fn=_grouped_attn_replacement_9,
patterns=patterns,
dummy_args=dummy_args_4,
scalar_workaround={"scale": scale, "dropout_p": dropout},
)

num_grouped_patterns = _apply_pattern(gm, "Grouped Attention", register_grouped_attention)
if num_grouped_patterns == 0:
ad_logger.warning(
@@ -529,7 +635,6 @@ def _apply(

# List of SDPA operations to look for
sdpa_ops = {
- torch.ops.auto_deploy.torch_attention_sdpa,
torch.ops.auto_deploy.torch_attention_grouped_sdpa,
}
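
With torch_attention_sdpa removed from sdpa_ops, this later transform only looks for the grouped op, presumably relying on the patterns above to have rewritten every remaining SDPA call first. A small, hypothetical debugging helper to verify that on an exported graph (assumes the auto_deploy custom ops are registered, i.e. tensorrt_llm is importable):

import torch
from torch.fx import GraphModule

def count_plain_sdpa_nodes(gm: GraphModule) -> int:
    """Count call_function nodes that still target the non-grouped SDPA op."""
    target = torch.ops.auto_deploy.torch_attention_sdpa.default
    return sum(
        1
        for node in gm.graph.nodes
        if node.op == "call_function" and node.target is target
    )

# Expected to return 0 after the grouped-attention pattern pass has run.
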

@@ -52,6 +52,10 @@ def test_bamba_patches(
},
)

torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)

factory = llm_args.create_factory()
model = factory.build_model("meta")
tokenizer = factory.init_tokenizer()