
Conversation

@gante (Member) commented Mar 19, 2024

What does this PR do?

Reintroduces support for partial 4D masks in Llama (and other models that support the static cache).

Fixes #29525

Thank you @poedator for a clean description and a test case -- I was unaware our previous versions supported this use of 4D attention masks 🤗

@gante gante requested a review from amyeroberts March 19, 2024 13:06
@HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts (Contributor) left a comment

Thanks for fixing this and the tests ❤️

@poedator (Contributor) commented

@gante, thank you for the PR!
I tested it with more examples and it works OK.

I have a couple of suggestions:

  • add a few lines to the documentation on how to submit custom attention masks properly.

  • add the ability to pass a custom attention mask through unchanged if its shape matches self.causal_mask.shape. Ideally, do this before any line where causal_mask is re-created, to save time and RAM. My use case is a static mask for the draft model in speculative decoding, where only its elements get updated; I'd like this to pass torch.compile. (See the sketch after this list.)

  • if you are open to extreme time savings (~200 µs for a 4k attention_mask), it may even be assumed that the custom mask already has the right dtype, with min_dtype in place of zeros and 0 in place of ones, avoiding conversions like (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype.

@gante (Member, Author) commented Mar 19, 2024

Hi @poedator 👋

We have other bug fixes and more impactful features in our pipeline, so I'll not work on your suggestions (at least not for now) :) However, we're always open to PRs!

@gante merged commit 4294f0c into huggingface:main on Mar 19, 2024
@gante deleted the partial_4d_masks branch on March 19, 2024 at 17:32
ArthurZucker pushed a commit that referenced this pull request Mar 20, 2024
* partial 4d masks

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

---------

Co-authored-by: amyeroberts <[email protected]>

Development

Successfully merging this pull request may close these issues.

custom 4d attention masks broken by #28937 (#29525)
