Phi3: fix attn for sliding window #33586
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
gante
left a comment
Thank you for opening the PR with the fix 💪 A few questions and comments
```python
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
if using_sliding_window_cache:
```
we now have a significant code block to determine `target_length`, but `target_length` is not used directly in `_update_causal_mask`.

Suggestion to avoid this disconnection:

- make `SlidingWindowCache` store the `config.sliding_window` it receives at init time
- move the logic that computes `target_length` inside `_prepare_4d_causal_attention_mask_with_cache_position`
- `_prepare_4d_causal_attention_mask_with_cache_position` no longer receives `target_length` nor `config`, as they can be retrieved from `past_key_values`
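The suggested refactor could be sketched roughly as follows. This is a minimal illustration, not the actual transformers API: the class and helper names are stand-ins, and the real `_prepare_4d_causal_attention_mask_with_cache_position` does much more than compute a length.

```python
class SlidingWindowCache:
    """Toy stand-in: the cache stores the config.sliding_window it
    receives at init time, per the suggestion above."""

    def __init__(self, sliding_window: int):
        self.sliding_window = sliding_window


class StaticCache:
    """Toy stand-in for a fixed-capacity cache."""

    def __init__(self, max_length: int):
        self.max_length = max_length

    def get_max_length(self) -> int:
        return self.max_length


def target_length_from_cache(past_key_values, sequence_length: int) -> int:
    # Logic that would move inside
    # _prepare_4d_causal_attention_mask_with_cache_position: the target
    # length is recovered from past_key_values, so the function no longer
    # needs explicit target_length or config arguments.
    if isinstance(past_key_values, SlidingWindowCache):
        return max(sequence_length, past_key_values.sliding_window)
    return past_key_values.get_max_length()
```

With this shape, the modeling code only hands the cache object around, and the length bookkeeping lives in one place.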
For the target length, maybe we merge #32421 first, where the max length should be easily accessible through the cache no matter the cache type. We earlier discussed returning the max cache shape, and the sliding-window cache class has a maximum capacity.
And then yeah, we can move that part to `_prepare_4d_causal_attention_mask_with_cache_position`, but I think we'd better move it in all classes for general consistency.
Agreed with the plan -- shall we leave a TODO for us in this PR, linking to your comment?
yes, I'll add a TODO for us so we can make the required change in all models at once, to not mix different updates in one PR :)
```python
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min

if isinstance(past_key_values, SlidingWindowCache):
```
(same comment as above here)
ArthurZucker
left a comment
awesome test!
Done, I think this PR is ready to be reviewed/merged. Depending on which PR is merged first, this or the linked one, I will rebase and apply the necessary changes. Then I'll add the TODO comment for moving the `target_length` logic.
ArthurZucker
left a comment
Thanks, let's try to abstract a tad bit, as we generally would avoid differentiating cache classes in the modeling!
this LGTM! thanks for the thorough test
```python
if isinstance(past_key_values, SlidingWindowCache):
    sliding_window = (
        self.config.sliding_window if self.config.sliding_window is not None else sequence_length
    )
    target_length = max(sequence_length, sliding_window)
else:
    target_length = past_key_values.get_max_length()
```
I don't understand why we need to do this: if there is a `SlidingWindowCache`, it was init from the config and thus has a correct `sliding_window`.
Then, `SlidingWindowCache.get_max_length` should take `sequence_length` as input to return the max and avoid having these checks here, WDYT?
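The alternative proposed here could look like the following sketch. The signature is hypothetical (the actual `Cache` API in transformers differs); it only illustrates moving the check into the cache class:

```python
class SlidingWindowCache:
    """Illustrative sketch only, not the transformers implementation."""

    def __init__(self, sliding_window: int):
        # init from the config, so the value is always correct here
        self.sliding_window = sliding_window

    def get_max_length(self, sequence_length: int) -> int:
        # taking sequence_length as input lets the cache return the max
        # itself, removing the isinstance checks from the modeling code
        return max(sequence_length, self.sliding_window)
```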
Yes, it should, but the PR for getting `max_length` on `SlidingWindowCache` is on its way and not merged yet.
So, to update you on our discussions with @gante, which are currently spread across different PR comments: some cache classes now do not have a `max_length` (e.g. `SlidingWindowCache`). As commented in the code, a sliding window technically has no max length and goes on a rolling basis. But in transformers what we want to check is the "maximum capacity of the cache instance", independently of how the cache handles new tokens going beyond that capacity.
So, in a different PR I changed the naming to `get_max_cache_shape`, which is more straightforward, and added `get_max_cache_shape` for `SlidingWindowCache`. We'll do a simple deprecation cycle, as we did for the static cache's `max_batch_size`. Until the linked PR is merged, I am copying this piece of code from Mistral and using it in Phi3. I have it noted and will handle it depending on which one gets merged first :)
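The rename and deprecation described here could be sketched as follows. The method bodies are assumptions for illustration, not the merged transformers implementation:

```python
import warnings


class SlidingWindowCache:
    """Sketch: the cache has no true max length (it rolls), but it does
    have a maximum capacity, which get_max_cache_shape exposes."""

    def __init__(self, sliding_window: int):
        self.sliding_window = sliding_window

    def get_max_cache_shape(self) -> int:
        # maximum capacity of the cache instance, regardless of how
        # tokens beyond it are handled (rolling overwrite)
        return self.sliding_window

    def get_max_length(self) -> int:
        # kept through a simple deprecation cycle, mirroring what was
        # done for the static cache's max_batch_size
        warnings.warn(
            "get_max_length is deprecated; use get_max_cache_shape",
            FutureWarning,
        )
        return self.get_max_cache_shape()
```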
Alright, sounds good. I just don't want us to add too much complexity to the code! 🤗
gante
left a comment
LGTM, but the changes will likely fail because of a recently merged PR (things need to be moved, see comment)
```python
_CONFIG_FOR_DOC = "MistralConfig"


def _prepare_4d_causal_attention_mask_with_cache_position(
```
Because of #33677, this function is part of the model class -- I think you will have to move the diff there, otherwise tests may fail on main
(see the diff in that PR for Llama, it should be similar to the changes you need to do here)
Rebased on main and updated accordingly by moving `_prepare_4d_causal_attention_mask_with_cache_position` into the model class. Will be merging tomorrow if no comments remain :)
Hey @zucchini-nlp, while working on #33619 I had issues with the 4D masks and just found this PR. However, it is not only an issue for Phi3! From what I could see, the following models have the exact same issue (…)
Let me know if you can fix it or if you want me to jump on it.
@Cyrilvallez oh I see, I didn't know we had more models that support sliding window. I can propagate the changes to other models, sure :)
Yes! I think I listed them all, but you can maybe double-check so that all of them get correctly fixed 🤗
ArthurZucker
left a comment
Ping me when merged, this way I can put it in a patch!
@ArthurZucker done! I had to change tests for Qwen2 models because otherwise we won't get the same results for a long padded input as for the base input. Applying the sliding mask results in minor differences.
ArthurZucker
left a comment
Thanks, good for the test modifications. If the run-slow tests were good, let's go! 🔥
* fix phi3 attn for sliding window
* fix tests
* address most comments
* style
* update after rebase
* add more models
* fix tests
What does this PR do?
Fixes #32945. The reason is that Phi3 previously prepared the 4D attention mask with the sliding window taken into account, while the new `_update_causal_mask` didn't. This PR fixes it and adds a test.
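As a rough illustration of what the fix restores (a standalone helper for exposition, not the transformers implementation): with a sliding window, position `i` may attend only to positions `j` with `i - sliding_window < j <= i`, and everything else is filled with the dtype's minimum so softmax zeroes it out.

```python
import torch


def sliding_window_causal_mask(seq_len: int, sliding_window: int, dtype=torch.float32):
    """Build a (seq_len, seq_len) additive attention mask: 0 where
    attention is allowed, the dtype's minimum where it is blocked."""
    min_dtype = torch.finfo(dtype).min
    i = torch.arange(seq_len).unsqueeze(-1)  # query positions, column vector
    j = torch.arange(seq_len)                # key positions, row vector
    # causal (j <= i) AND inside the sliding window (j > i - sliding_window)
    allowed = (j <= i) & (j > i - sliding_window)
    return torch.where(allowed, 0.0, min_dtype).to(dtype)
```

For example, with `seq_len=4` and `sliding_window=2`, the last row allows only positions 2 and 3, whereas a plain causal mask would also allow 0 and 1.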