Conversation

@zucchini-nlp
Member

What does this PR do?

Fixes #32945. The root cause is that Phi3 previously prepared the 4D attention mask with the sliding window applied, while the new _update_causal_mask did not take the sliding window into account. This PR fixes that and adds a test.
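
For illustration (a toy sketch, not the code in this PR): with a sliding window, the causal mask must additionally block keys that sit more than sliding_window positions behind the query, which a plain causal mask would still allow.

```python
import torch


# Toy sketch, not the transformers implementation: build an additive causal
# mask and optionally restrict it to a sliding window. Exact boundary
# conventions (whether the window includes the current token) may differ.
def toy_causal_mask(seq_len, dtype=torch.float32, sliding_window=None):
    min_dtype = torch.finfo(dtype).min
    idx = torch.arange(seq_len)
    allowed = idx[None, :] <= idx[:, None]  # no attending to future keys
    if sliding_window is not None:
        # additionally forbid keys more than `sliding_window` steps in the past
        allowed &= (idx[:, None] - idx[None, :]) < sliding_window
    mask = torch.full((seq_len, seq_len), min_dtype, dtype=dtype)
    mask[allowed] = 0.0
    return mask


print(toy_causal_mask(5, sliding_window=3))  # lower-left corner stays masked
```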

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@gante left a comment

Thank you for opening the PR with the fix 💪 A few questions and comments

min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
if using_sliding_window_cache:
Contributor

we now have a significant code block to determine target_length, but target_length is not used directly in _update_causal_mask

suggestion to avoid this disconnection (rough sketch after the list):

  1. make SlidingWindowCache store the config.sliding_window it receives at init time
  2. move the logic that computes target_length inside _prepare_4d_causal_attention_mask_with_cache_position
  3. _prepare_4d_causal_attention_mask_with_cache_position no longer receives target_length nor config, as they can be retrieved from past_key_values
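
A minimal sketch of what that flow could look like (hypothetical names and signature, just to illustrate the suggested direction; the real helper takes more arguments):

```python
# Hypothetical sketch of the suggestion above, not actual transformers code:
# the cache remembers its sliding window at init time, and the mask helper
# derives target_length from past_key_values instead of receiving it.
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask, input_tensor, cache_position, past_key_values
):
    sequence_length = input_tensor.shape[1]
    sliding_window = getattr(past_key_values, "sliding_window", None)
    if sliding_window is not None:
        # step 1: SlidingWindowCache stored config.sliding_window when created
        target_length = max(sequence_length, sliding_window)
    else:
        # other cache types expose their capacity directly
        target_length = past_key_values.get_max_length()
    # steps 2/3: target_length is computed here, so callers no longer pass it
    # (nor the config); the rest of the 4D mask construction would follow.
    ...
```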

Member Author

For the target length, maybe we should merge #32421 first, where the max length becomes easily accessible through the cache regardless of cache type. We earlier discussed returning the max cache shape, and the sliding-window cache class has a maximum capacity.

And then, yes, we can move that part into _prepare_4d_causal_attention_mask_with_cache_position, but I think we'd better do the move in all model classes for general consistency.

Contributor

Agreed with the plan -- shall we leave a TODO for us in this PR, linking to your comment?

Member Author

yes, I'll add a TODO for us so we can make the required change in all models at once, to not mix different updates in one PR :)

dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min

if isinstance(past_key_values, SlidingWindowCache):
Contributor

(same comment as above here)

Collaborator

@ArthurZucker left a comment

awesome test!

@zucchini-nlp
Member Author

Done, I think this PR is ready to be reviewed/merged.

Depending on which PR is merged first, this one or the linked one, I will rebase and apply the necessary changes. Then I'll add the TODO comment about moving the target_length logic into _prepare_attention_mask.

Collaborator

@ArthurZucker left a comment

Thanks, let's try to abstract this a tad, as we generally want to avoid differentiating between cache classes in the modeling code!

Collaborator

this LGTM! thanks for the thorough test

Comment on lines 1189 to 1195
if isinstance(past_key_values, SlidingWindowCache):
    sliding_window = (
        self.config.sliding_window if self.config.sliding_window is not None else sequence_length
    )
    target_length = max(sequence_length, sliding_window)
else:
    target_length = past_key_values.get_max_length()
Collaborator

I don't understand why we need to do this: if there is a SlidingWindowCache, it was initialized from the config and thus has the correct sliding_window.

Then, SlidingWindowCache.get_max_length could take sequence_length as input to return the max, which would avoid having these checks here. WDYT?

Member Author

Yes, it should, but the PR that exposes the max length on SlidingWindowCache is on its way and not merged yet.

So, to bring you up to date on the discussions with @gante, which are currently spread across different PR comments: some cache classes now do not have a max_length (e.g. SlidingWindowCache). As noted in the code, a sliding-window cache technically has no max length and works on a rolling basis. But what we want to check in transformers is the "maximum capacity of the cache instance", independently of how the cache handles new tokens that go beyond that capacity.

So, in a different PR I renamed the method to get_max_cache_shape, which is more straightforward, and added get_max_cache_shape for SlidingWindowCache. We'll do a simple deprecation cycle, as we did for the static cache's "max_batch_size". Until the linked PR is merged, I am copying this piece of code from Mistral and using it in Phi3. I have it noted and will handle it depending on which PR gets merged first :)
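
For reference, a rough sketch of the idea (a hypothetical class, not the implementation in the linked PR): the cache remembers its maximum capacity at init time and exposes it through get_max_cache_shape, so the modeling code never needs isinstance checks.

```python
# Rough sketch only, assuming the cache stores its capacity at init time;
# the actual SlidingWindowCache in transformers may differ in the details.
class SlidingWindowCacheSketch:
    def __init__(self, sliding_window: int, max_cache_len: int):
        # A sliding-window cache can never hold more than `sliding_window`
        # tokens, even if a larger max_cache_len is requested.
        self.max_cache_len = min(sliding_window, max_cache_len)

    def get_max_cache_shape(self) -> int:
        # "Maximum capacity of the cache instance", regardless of how the
        # cache rolls over once that capacity is reached.
        return self.max_cache_len
```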

Collaborator

alright sounds good. I just don't want us to add too much complexity to the code! 🤗

Contributor

@gante left a comment

LGTM, but the changes will likely fail because of a recently merged PR (things need to be moved, see comment)

_CONFIG_FOR_DOC = "MistralConfig"


def _prepare_4d_causal_attention_mask_with_cache_position(
Contributor

Because of #33677, this function is part of the model class -- I think you will have to move the diff there, otherwise tests may fail on main

(see the diff in that PR for Llama, it should be similar to the changes you need to do here)
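
In other words (a bare-bones sketch; see the Llama diff in #33677 for the actual signature and body), the former module-level helper now lives on the model class as a static method:

```python
# Sketch only: the standalone function becomes a @staticmethod on the model,
# so per-model tweaks (like the sliding-window handling in this PR) live there.
class ModelSketch:
    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(*args, **kwargs):
        # same 4D-mask construction as the old module-level function
        ...
```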

@zucchini-nlp
Member Author

Rebased on main and updated accordingly by moving prepare_causal_mask into XXXModel. I also noticed that Phi3Moe was added while this PR was in progress and it is the same as Phi3, so I propagated the changes there too.

Will merge tomorrow if no comments remain :)

@Cyrilvallez
Member

Cyrilvallez commented Oct 9, 2024

Hey @zucchini-nlp, while working on #33619 I had issues with the 4D masks and just found this PR - however, it is not only an issue for Phi3! From what I could see, the following models have exactly the same issue (AttentionMaskConverter._ignore_causal_mask_sdpa() does not check for the sliding_window, resulting in wrong masks, and neither does _prepare_4d_causal_attention_mask_with_cache_position()); a sketch of the missing check follows the list.

  • Mimi
  • Mixtral
  • PhiMoe
  • Qwen2
  • Qwen2Moe
  • Qwen2VL
  • Starcoder2

Let me know if you can fix it or if you want me to jump on it.
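
To make the failure mode concrete, here is a sketch of the kind of check that is missing (a hypothetical helper, not the AttentionMaskConverter API): once the attended length exceeds the sliding window, the mask cannot be skipped, because SDPA's is_causal=True alone would still let queries see tokens outside the window.

```python
from typing import Optional

import torch


# Hypothetical illustration of the missing check, not transformers code.
def can_skip_causal_mask(
    attention_mask: Optional[torch.Tensor],
    key_length: int,
    sliding_window: Optional[int],
) -> bool:
    if attention_mask is not None and not bool(attention_mask.all()):
        # real padding present: an explicit 4D mask is needed anyway
        return False
    if sliding_window is not None and key_length > sliding_window:
        # tokens beyond the window must be masked out explicitly
        return False
    # otherwise a plain causal mask (or is_causal=True) is enough
    return True
```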

@zucchini-nlp
Member Author

@Cyrilvallez oh I see, didn't know we had more models that support sliding window. I can propagate changes to other models, sure :)

@Cyrilvallez
Member

Yes! I think I listed them all but you can maybe double-check so that all of them get correctly fixed 🤗

Collaborator

@ArthurZucker left a comment

ping me when this is merged, that way I can put it in a patch!

@zucchini-nlp
Member Author

@ArthurZucker done! I had to change the tests for the Qwen2 models because otherwise we don't get the same results for a long padded input as for the base input; applying the sliding mask results in minor differences.

Collaborator

@ArthurZucker left a comment

Thanks, the test modifications look good to me; if run-slow was green, let's go! 🔥

@zucchini-nlp zucchini-nlp merged commit adea675 into huggingface:main Oct 10, 2024
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* fix phi3 attn fir sliding window

* fix tests

* address most comment

* style

* update after rebase

* add more models

* fix tests

Successfully merging this pull request may close these issues.

Regression in generating text with Phi-3-mini-4k-instruct with a long prompt (gibberish in v4.42+)
