Cache: new Cache format in decoder-only models #31421
The first diff below is against the CodeGen modeling code:

```diff
@@ -22,6 +22,7 @@
 from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
```
```diff
@@ -57,7 +58,7 @@ def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Ten
 
 
 class CodeGenAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_idx=None):
         super().__init__()
 
         max_positions = config.max_position_embeddings
```
```diff
@@ -71,6 +72,13 @@ def __init__(self, config):
 
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
 
         self.embed_dim = config.hidden_size
         self.num_attention_heads = config.num_attention_heads
```
```diff
@@ -150,7 +158,7 @@ def _attn(
     def forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        layer_past: Optional[Cache] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
```
```diff
@@ -200,18 +208,11 @@ def forward(
         key = key.permute(0, 2, 1, 3)
         query = query.permute(0, 2, 1, 3)
 
+        # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
+        # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
         if layer_past is not None:
-            past_key = layer_past[0]
-            past_value = layer_past[1]
-            key = torch.cat((past_key, key), dim=-2)
-            value = torch.cat((past_value, value), dim=-2)
-
-        if use_cache is True:
-            # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
-            # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
-            present = (key.to(hidden_states.dtype), value)
-        else:
-            present = None
+            cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_dim}
+            key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs)
 
         # compute self-attention: V x Softmax(QK^T)
         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
```
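For readers unfamiliar with the new API, here is a small standalone sketch (not taken from the PR) of what `layer_past.update(...)` does when `layer_past` is a `DynamicCache`: it appends the new key/value states for the given `layer_idx` and returns the full cached tensors, which is why the manual `torch.cat` bookkeeping above can be dropped. The `cache_kwargs` (`sin`, `cos`, `partial_rotation_size`) are ignored by `DynamicCache` and only consumed by cache classes that need them, such as `SinkCache`.

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
batch, heads, head_dim = 1, 4, 16

# Prefill: 5 tokens for layer 0. DynamicCache simply concatenates along the sequence dim.
key, value = cache.update(
    torch.randn(batch, heads, 5, head_dim),
    torch.randn(batch, heads, 5, head_dim),
    layer_idx=0,
)
print(key.shape, cache.get_seq_length())  # torch.Size([1, 4, 5, 16]) 5

# One decoding step: the returned tensors already contain past + current states.
key, value = cache.update(
    torch.randn(batch, heads, 1, head_dim),
    torch.randn(batch, heads, 1, head_dim),
    layer_idx=0,
)
print(key.shape, cache.get_seq_length())  # torch.Size([1, 4, 6, 16]) 6
```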
```diff
@@ -220,7 +221,7 @@ def forward(
         attn_output = self.out_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
 
-        outputs = (attn_output, present)
+        outputs = (attn_output, layer_past)
         if output_attentions:
             outputs += (attn_weights,)
 
```
```diff
@@ -250,17 +251,17 @@ def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTens
 # Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
 class CodeGenBlock(nn.Module):
     # Ignore copy
-    def __init__(self, config):
+    def __init__(self, config, layer_idx=None):
         super().__init__()
         inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
         self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
-        self.attn = CodeGenAttention(config)
+        self.attn = CodeGenAttention(config, layer_idx)
         self.mlp = CodeGenMLP(inner_dim, config)
 
     def forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        layer_past: Optional[Cache] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
```
```diff
@@ -303,6 +304,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["CodeGenBlock"]
     _skip_keys_device_placement = "past_key_values"
+    _supports_cache_class = True
```
Contributor: Don't forget `_supports_quantized_cache = True` and `_supports_static_cache = True` (if appropriate, on this and other models).

Member (author): Oh yes, forgot about those. In any case I'll make a follow-up PR to check fullgraph compile and, if it works as-is, add the tests in each modeling file.

Contributor: 👍 We should add a mixin test for fullgraph compilation when […]. Moreover, after this PR, I think […]

Member (author): Yes, I will test and add a mixin test (one lightweight and one slow, maybe) after this PR is merged. For now I added the flags and tested by running the generation tests.

Btw, GIT will be an exception: it supports the cache class but not the static cache, as it has some special attention mask preparation steps.

Collaborator: BTW we could also just add the class to the set […]
```diff
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
```
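On the flags discussed above, a hedged sketch of what `_supports_cache_class = True` already enables for this model, assuming the public `Salesforce/codegen-350M-mono` checkpoint: an explicit `Cache` instance can be handed to `generate()`. Static and quantized cache support is deliberately left to the follow-up PR mentioned in the thread, so only `DynamicCache` is shown; treat this as an illustration, not part of the PR's test suite.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

tok = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
inputs = tok("def fib(n):", return_tensors="pt")

# An explicit Cache instance is filled step by step instead of rebuilding legacy tuples.
out = model.generate(**inputs, max_new_tokens=16, past_key_values=DynamicCache())
print(tok.decode(out[0], skip_special_tokens=True))
```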
```diff
@@ -374,6 +376,23 @@ def _init_weights(self, module):
             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
             is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
             model's internal embedding lookup matrix.
+        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+            Two formats are allowed:
+            - a [`~cache_utils.Cache`] instance;
+            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+              shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+              cache format.
+
+            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+            legacy cache format will be returned.
+
+            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+            of shape `(batch_size, sequence_length)`.
         output_attentions (`bool`, *optional*):
             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
             tensors for more detail.
```
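As a usage illustration of the two formats documented above (a sketch assuming the `Salesforce/codegen-350M-mono` checkpoint): whichever format goes in comes back out, and the legacy tuples remain the default when nothing is passed.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

tok = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
inputs = tok("def hello():", return_tensors="pt")

# No cache passed -> legacy tuple-of-tuples comes back.
out = model(**inputs, use_cache=True)
print(type(out.past_key_values))        # tuple

# Cache instance passed -> the same Cache class comes back.
out = model(**inputs, past_key_values=DynamicCache(), use_cache=True)
print(type(out.past_key_values))        # DynamicCache
```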
```diff
@@ -397,7 +416,7 @@ def __init__(self, config):
         self.vocab_size = config.vocab_size
         self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
         self.drop = nn.Dropout(config.embd_pdrop)
-        self.h = nn.ModuleList([CodeGenBlock(config) for _ in range(config.n_layer)])
+        self.h = nn.ModuleList([CodeGenBlock(config, layer_idx=i) for i in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
         self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
 
```
```diff
@@ -421,7 +440,7 @@ def set_input_embeddings(self, new_embeddings):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
```
```diff
@@ -457,11 +476,12 @@ def forward(
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
 
-        if past_key_values is None:
-            past_length = 0
-            past_key_values = tuple([None] * len(self.h))
-        else:
-            past_length = past_key_values[0][0].size(-2)
+        past_length = 0
+        if use_cache:
+            use_legacy_cache = not isinstance(past_key_values, Cache)
+            if use_legacy_cache:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            past_length = past_key_values.get_seq_length()
 
         if position_ids is None:
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
```
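The conversion helpers used here are part of `cache_utils`; below is a small sketch (hypothetical shapes) of the round trip between the legacy format and `DynamicCache`.

```python
import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) tuple per layer, shape (batch, heads, seq_len, head_dim).
legacy = tuple(
    (torch.randn(1, 4, 7, 16), torch.randn(1, 4, 7, 16))
    for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_seq_length())                       # 7, read from the cached keys

roundtrip = cache.to_legacy_cache()
print(torch.equal(roundtrip[0][0], legacy[0][0]))   # True
```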
```diff
@@ -514,10 +534,10 @@ def forward(
                 )
                 use_cache = False
 
-        presents = () if use_cache else None
+        next_decoder_cache = None
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
-        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+        for i, block in enumerate(self.h):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
```
```diff
@@ -535,7 +555,7 @@ def forward(
             else:
                 outputs = block(
                     hidden_states=hidden_states,
-                    layer_past=layer_past,
+                    layer_past=past_key_values,
                     attention_mask=attention_mask,
                     position_ids=position_ids,
                     head_mask=head_mask[i],
```
```diff
@@ -545,7 +565,7 @@ def forward(
 
             hidden_states = outputs[0]
             if use_cache is True:
-                presents = presents + (outputs[1],)
+                next_decoder_cache = outputs[1]
 
             if output_attentions:
                 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
```
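Why the per-layer tuple accumulation disappears: each block is now handed the same cache object and mutates it in place, so the value returned by the last block already contains every layer's states. A small sketch (not the model code) of that invariant:

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
for layer_idx in range(3):                 # stand-in for the loop over self.h
    k = torch.randn(1, 4, 5, 16)
    v = torch.randn(1, 4, 5, 16)
    cache.update(k, v, layer_idx)
    next_decoder_cache = cache             # mirrors `next_decoder_cache = outputs[1]`

assert next_decoder_cache is cache         # same object, updated in place
print(len(next_decoder_cache.key_cache))   # 3 -> one entry per layer
```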
```diff
@@ -557,12 +577,18 @@ def forward(
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
+        next_cache = None
+        if use_cache:
+            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+
         if not return_dict:
-            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+            return tuple(
+                v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
+            )
 
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=presents,
+            past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attentions,
         )
```
```diff
@@ -593,9 +619,12 @@ def set_output_embeddings(self, new_embeddings):
 
     def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
+        attention_mask = kwargs.get("attention_mask", None)
+        past_length = 0
         # Omit tokens covered by past_key_values
         if past_key_values:
-            past_length = past_key_values[0][0].shape[2]
+            past_length = cache_length = past_key_values.get_seq_length()
+            max_cache_length = past_key_values.get_max_length()
 
             # Some generation methods already pass only the last input ID
             if input_ids.shape[1] > past_length:
```
```diff
@@ -608,7 +637,14 @@ def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_
             if token_type_ids is not None:
                 token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
 
-        attention_mask = kwargs.get("attention_mask", None)
+        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+        if (
+            max_cache_length is not None
+            and attention_mask is not None
+            and cache_length + input_ids.shape[1] > max_cache_length
+        ):
+            attention_mask = attention_mask[:, -max_cache_length:]
+
         position_ids = kwargs.get("position_ids", None)
 
         if attention_mask is not None and position_ids is None:
```
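The cropping branch only triggers for cache classes that report a bounded `get_max_length()` (it is `None` for `DynamicCache`). A standalone illustration with hypothetical numbers:

```python
import torch

cache_length = 30        # tokens already stored in the cache
max_cache_length = 32    # bound reported by e.g. a sliding-window cache
input_len = 4            # new tokens fed in this step
attention_mask = torch.ones(1, cache_length + input_len)

# Mirror of the condition above: keep only the positions the bounded cache can still see.
if max_cache_length is not None and cache_length + input_len > max_cache_length:
    attention_mask = attention_mask[:, -max_cache_length:]

print(attention_mask.shape)  # torch.Size([1, 32])
```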
```diff
@@ -619,7 +655,7 @@ def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_
             position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}
```
```diff
@@ -644,7 +680,7 @@ def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_
     def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
-       past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+       past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
```
A second file receives a minor docstring formatting fix in `_flash_attention_forward`:

```diff
@@ -622,6 +622,7 @@ def _flash_attention_forward(
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
         first unpad the input, then computes the attention scores and pad the final attention scores.
+
         Args:
             query_states (`torch.Tensor`):
                 Input query states to be passed to Flash Attention API
```