Skip to content

Conversation

@gante
Copy link
Contributor

@gante gante commented Sep 19, 2024

What does this PR do?

Related to #33541

The warning in question should only be thrown in the case we are converting from a legacy cache, which will be deprecated soon. Gemma 2 doesn't support the legacy cache format, so no warning should ever be thrown :)

In the process, updates a few related inconsistencies.


✅ slow gemma2 tests ran locally. There are a few failures (also present on main). Some failures were fixed in this PR.

@gante gante changed the title Cache: don't throw warnings on gemma 2 when instantiating a new cache Cache: don't throw warnings on gemma2 when instantiating a new cache Sep 19, 2024
Comment on lines 1662 to 1666
def get_seq_length(self, layer_idx: Optional[int] = 0):
return None
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: deprecate this function in favor of `cache_position`
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HybridCache is a StaticCache with alternating sliding window layers. The method to retrieve the cache length is copy/paste from StaticCache

We will want to use another method in the future, but let's leave this as a copy of StaticCache for now. This method is needed in the updated gemma 2.

raise ValueError("When `past_key_values` is passed, `cache_position` must be too")

# Probably a forward call with caching, so we set up cache for one call only
if use_cache and past_key_values is None and not self.training:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two changes here, both to be consistent with other models:

  1. self.training should not control whether we instantiate a cache
  2. If a user respects the types in the docs, past_key_values is either a Cache or we instantiate a new one for the user without warnings

dtype=inputs_embeds.dtype,
)

if cache_position is None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy/paste from llama (and other Cache-supporting models)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okey, this should always work actually since the seq length gets layer_idx=0. Just one question, isn't it a bit misleading if some layers will have get_seq_length() number of tokens while others no more than sliding window length?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zucchini-nlp yes, if get_seq_length gets called on the wrong layer we will have problems! I'm going to add an exception if it gets called on layer_idx != 0 (I doubt we need it).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okey sounds good, as long as the function of get_seq_length is transparent for users, to reduce number of cache-related question we get 😄


if use_cache and past_key_values is None and not self.training:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if use_cache and not isinstance(past_key_values, Cache):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy/paste from llama (and other Cache-supporting models)

def test_model_outputs_equivalence(self, **kwargs):
pass

@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

without this parameterized, the intended overwriting was not happening

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! Please merge once @zucchini-nlp has approved as she knows this code more than I.

cc @BenjaminBossan as well

Copy link
Member

@zucchini-nlp zucchini-nlp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for cleaning up warnings! Left one question about HybridCache, since I was reluctant to add seq-length for that cache type where lengths are not consistent over layers

dtype=inputs_embeds.dtype,
)

if cache_position is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okey, this should always work actually since the seq length gets layer_idx=0. Just one question, isn't it a bit misleading if some layers will have get_seq_length() number of tokens while others no more than sliding window length?

@BenjaminBossan
Copy link
Member

I'm not qualified to review this but thanks for addressing this so quickly.

@gante gante merged commit 52920b5 into huggingface:main Sep 19, 2024
@gante gante deleted the gemma2_warning branch September 19, 2024 16:42
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants