-
Notifications
You must be signed in to change notification settings - Fork 31.7k
Cache: don't throw warnings on gemma2 when instantiating a new cache
#33595
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
gemma2 when instantiating a new cache
| def get_seq_length(self, layer_idx: Optional[int] = 0): | ||
| return None | ||
| # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's | ||
| # limit the check to the first batch member and head dimension. | ||
| # TODO: deprecate this function in favor of `cache_position` | ||
| return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
HybridCache is a StaticCache with alternating sliding window layers. The method to retrieve the cache length is copy/paste from StaticCache
We will want to use another method in the future, but let's leave this as a copy of StaticCache for now. This method is needed in the updated gemma 2.
| raise ValueError("When `past_key_values` is passed, `cache_position` must be too") | ||
|
|
||
| # Probably a forward call with caching, so we set up cache for one call only | ||
| if use_cache and past_key_values is None and not self.training: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two changes here, both to be consistent with other models:
self.trainingshould not control whether we instantiate a cache- If a user respects the types in the docs,
past_key_valuesis either aCacheor we instantiate a new one for the user without warnings
| dtype=inputs_embeds.dtype, | ||
| ) | ||
|
|
||
| if cache_position is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
copy/paste from llama (and other Cache-supporting models)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okey, this should always work actually since the seq length gets layer_idx=0. Just one question, isn't it a bit misleading if some layers will have get_seq_length() number of tokens while others no more than sliding window length?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zucchini-nlp yes, if get_seq_length gets called on the wrong layer we will have problems! I'm going to add an exception if it gets called on layer_idx != 0 (I doubt we need it).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okey sounds good, as long as the function of get_seq_length is transparent for users, to reduce number of cache-related question we get 😄
|
|
||
| if use_cache and past_key_values is None and not self.training: | ||
| past_key_values = DynamicCache.from_legacy_cache(past_key_values) | ||
| if use_cache and not isinstance(past_key_values, Cache): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
copy/paste from llama (and other Cache-supporting models)
| def test_model_outputs_equivalence(self, **kwargs): | ||
| pass | ||
|
|
||
| @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
without this parameterized, the intended overwriting was not happening
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
LysandreJik
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you! Please merge once @zucchini-nlp has approved as she knows this code more than I.
cc @BenjaminBossan as well
zucchini-nlp
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for cleaning up warnings! Left one question about HybridCache, since I was reluctant to add seq-length for that cache type where lengths are not consistent over layers
| dtype=inputs_embeds.dtype, | ||
| ) | ||
|
|
||
| if cache_position is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okey, this should always work actually since the seq length gets layer_idx=0. Just one question, isn't it a bit misleading if some layers will have get_seq_length() number of tokens while others no more than sliding window length?
|
I'm not qualified to review this but thanks for addressing this so quickly. |
What does this PR do?
Related to #33541
The warning in question should only be thrown in the case we are converting from a legacy cache, which will be deprecated soon. Gemma 2 doesn't support the legacy cache format, so no warning should ever be thrown :)
In the process, updates a few related inconsistencies.
✅ slow
gemma2tests ran locally. There are a few failures (also present on main). Some failures were fixed in this PR.