FIX: Fixes unexpected behaviour for Llava / LLama & AWQ Fused modules + revert #30070 at the same time #30317
Changes from 3 commits
@@ -229,6 +229,8 @@ def fuse_awq_modules(model, quantization_config):
     else:
         raise ValueError("Fusing is only supported for the AutoAWQ backend")
 
+    fused_attention_modules = []
+
     for name, module in model.named_modules():
         if modules_to_not_convert is not None:
             if any(module_name_to_not_convert in name for module_name_to_not_convert in modules_to_not_convert):

@@ -241,7 +243,27 @@ def fuse_awq_modules(model, quantization_config):
         _fuse_awq_mlp(model, name, modules_to_fuse["mlp"], module, QuantFusedMLP)
 
         # Replace attention layers
-        _fuse_awq_attention_layers(model, module, modules_to_fuse, name, QuantAttentionFused)
+        attention_has_been_fused = _fuse_awq_attention_layers(
+            model, module, modules_to_fuse, name, QuantAttentionFused
+        )
+
+        if attention_has_been_fused:
+            fused_attention_modules.append(name)
+
+    # For AWQ fused + Llama we need to set `config._attn_implementation` = "custom" to avoid unexpected behavior and pass a
+    # `None` attention mask to the fused attention modules, as the attention mask is now dropped by our models and dealt with
+    # by the `AttentionMaskConverter` module.
[Reviewer comments on lines +253 to +255: a collaborator suggested rewording this comment ("nit but more understandable").]
+    if len(fused_attention_modules) > 0:
+        fused_attention_parent_modules = {
+            fused_attention_module.split(".")[0] for fused_attention_module in fused_attention_modules
+        }
+        for module_name, module in model.named_modules():
+            if any(
+                module_name in fused_attention_parent_module
+                for fused_attention_parent_module in fused_attention_parent_modules
+            ):
+                if hasattr(module, "config") and hasattr(module.config, "_attn_implementation"):
+                    module.config._attn_implementation = "custom"
 
     return model
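To make the parent-module bookkeeping above concrete: only the first dotted component of each fused attention module's name is kept, and every module under one of those parents has its config switched to the "custom" attention implementation. The snippet below is a small illustration written for this review, not code from the PR; the Llava-style module names are assumed:

```python
# Illustration only (not from the PR): deriving the parent modules whose configs are switched
# to `_attn_implementation = "custom"`. Module names are assumed Llava-style for the example.
fused_attention_modules = [
    "language_model.model.layers.0.self_attn",
    "language_model.model.layers.1.self_attn",
]
fused_attention_parent_modules = {name.split(".")[0] for name in fused_attention_modules}
print(fused_attention_parent_modules)  # {'language_model'}
```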
@@ -332,8 +354,10 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
     """
     from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 
+    module_has_been_fused = False
+
     if len(modules_to_fuse["attention"]) == 0:
-        return
+        return module_has_been_fused
 
     if hasattr(module, modules_to_fuse["attention"][0]):
         # First, we pack the QKV layers together

@@ -394,6 +418,9 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
         setattr(parent, child_name, fused_attention_layer.to(previous_device))
 
         del q_proj, k_proj, v_proj, o_proj
+        module_has_been_fused = True
+
+    return module_has_been_fused
 
 
 def post_init_awq_exllama_modules(model, exllama_config):
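For context, this is roughly how the fused path is exercised from user code. The snippet below is a hedged usage sketch built on the public `AwqConfig` / `AutoModelForCausalLM` API; the checkpoint id and the `fuse_max_seq_len` value are illustrative assumptions, not taken from this PR:

```python
# Usage sketch (assumptions: the checkpoint id and fuse_max_seq_len are illustrative).
from transformers import AutoModelForCausalLM, AwqConfig

quantization_config = AwqConfig(
    bits=4,
    fuse_max_seq_len=512,  # maximum sequence length the fused attention modules will support
    do_fuse=True,          # triggers fuse_awq_modules() at load time
)

model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7B-AWQ",  # any AWQ checkpoint with a supported architecture
    quantization_config=quantization_config,
    device_map="cuda:0",
)

# After fusing, the patched fuse_awq_modules() should have flipped the attention implementation,
# so the model now passes attention_mask=None down to the AWQ fused attention modules.
print(model.config._attn_implementation)  # expected: "custom"
```

Which sub-model config ends up carrying the flag depends on the architecture (for Llava it is the language model's config), so treat the printed value as a sanity check rather than a guaranteed contract.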
@@ -234,6 +234,59 @@ def _unmask_unattended(
         return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
 
+    @staticmethod
+    def _ignore_causal_mask_sdpa(
+        attention_mask: Optional[torch.Tensor],
+        inputs_embeds: torch.Tensor,
+        past_key_values_length: int,
+        sliding_window: Optional[int] = None,
+    ) -> bool:
+        """
+        Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
+
+        In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
+        `key_value_length == query_length`, we rather rely on SDPA's `is_causal` argument to use causal/non-causal masks,
+        allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
+        """
+
+        batch_size, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
+        key_value_length = query_length + past_key_values_length
+
+        is_tracing = (
+            torch.jit.is_tracing()
+            or isinstance(inputs_embeds, torch.fx.Proxy)
+            or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+        )
+
+        ignore_causal_mask = False
+
+        if attention_mask is None:
+            # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export`
+            # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
+            # Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag.
+            #
+            # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`).
+            if sliding_window is None or key_value_length < sliding_window:
+                ignore_causal_mask = not is_tracing
+        elif sliding_window is None or key_value_length < sliding_window:
+            if len(attention_mask.shape) == 4:
+                expected_shape = (batch_size, 1, query_length, key_value_length)
+                if tuple(attention_mask.shape) != expected_shape:
+                    raise ValueError(
+                        f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+                    )
+            elif not is_tracing and torch.all(attention_mask == 1):
+                if query_length == 1 or key_value_length == query_length:
+                    # For query_length == 1, causal attention and bi-directional attention are the same.
+                    ignore_causal_mask = True
+
+                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
+                # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
+                # Reference: https://github.com/pytorch/pytorch/issues/108108
+                # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
+
+        return ignore_causal_mask
+
+
 def _prepare_4d_causal_attention_mask(
     attention_mask: Optional[torch.Tensor],
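The decision encoded in `_ignore_causal_mask_sdpa` is easier to reason about with a toy version. The sketch below is a simplification written for illustration only: it drops the tracing, 4D-mask and sliding-window branches, and `can_drop_mask` is a hypothetical helper name, not part of this PR:

```python
# Simplified sketch of the `_ignore_causal_mask_sdpa` decision (illustrative; tracing,
# 4D-mask and sliding-window handling are intentionally left out).
from typing import Optional

import torch


def can_drop_mask(attention_mask: Optional[torch.Tensor], query_length: int, past_length: int) -> bool:
    key_value_length = query_length + past_length
    if attention_mask is None:
        return True  # nothing to encode in a mask, SDPA's is_causal argument can take over
    if not torch.all(attention_mask == 1):
        return False  # real padding is present, the mask carries information
    # All-ones mask: only safe when causal and bi-directional attention coincide (q_len == 1)
    # or when there is no past (key_value_length == query_length).
    return query_length == 1 or key_value_length == query_length


# Prefill without padding: q_len == kv_len, the causal mask can be left to SDPA's is_causal.
print(can_drop_mask(torch.ones(2, 8, dtype=torch.long), query_length=8, past_length=0))   # True
# Decoding one token with a cache: q_len == 1, causal and bi-directional attention coincide.
print(can_drop_mask(torch.ones(2, 9, dtype=torch.long), query_length=1, past_length=8))   # True
# Chunked prefill with a cache: q_len > 1 and kv_len != q_len, the mask must be materialized.
print(can_drop_mask(torch.ones(2, 12, dtype=torch.long), query_length=4, past_length=8))  # False
```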
@@ -305,7 +358,6 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
     attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
 
     key_value_length = input_shape[-1] + past_key_values_length
-    _, query_length = input_shape
 
     # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
     # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.

@@ -316,6 +368,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
         or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
     )
 
+<<<<<<< HEAD
     ignore_causal_mask = False
 
     if attention_mask is None:
@@ -351,6 +404,14 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
             # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
             # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
             # Reference: https://github.com/pytorch/pytorch/issues/108108
+=======
+    ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
+        attention_mask=attention_mask,
+        inputs_embeds=inputs_embeds,
+        past_key_values_length=past_key_values_length,
+        sliding_window=sliding_window,
+    )
+>>>>>>> parent of acab997bef... Revert "Re-enable SDPA's FA2 path (#30070)" (#30314)
 
     if ignore_causal_mask:
         expanded_4d_mask = None
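The property this code path ultimately relies on can be checked directly against PyTorch: when the mask is ignored, calling SDPA with `attn_mask=None, is_causal=True` must give the same result as an explicitly materialized additive 4D causal mask. The snippet below is an illustrative sanity check written for this review, not part of the PR:

```python
# Sanity check (illustrative, not from the PR): with q_len == kv_len and no padding, dropping the
# mask and using is_causal=True matches an explicit additive 4D causal mask.
import torch
import torch.nn.functional as F

batch, heads, q_len, head_dim = 2, 4, 8, 16
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, q_len, head_dim)
v = torch.randn(batch, heads, q_len, head_dim)

# Explicit additive causal mask, shaped (batch, 1, q_len, kv_len) like the one built by AttentionMaskConverter.
min_dtype = torch.finfo(q.dtype).min
causal_4d = torch.full((q_len, q_len), min_dtype).triu(diagonal=1)[None, None, :, :].expand(batch, 1, q_len, q_len)

out_with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_4d)
out_is_causal = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# The two paths should agree up to numerical noise, which is why the mask can safely be ignored here.
print(torch.allclose(out_with_mask, out_is_causal, atol=1e-6))  # expected: True
```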