Skip to content

Commit

Permalink
Correctly inject required generate() fix
Browse files Browse the repository at this point in the history
Tested by calling endless_generate()
  • Loading branch information
tomaarsen committed Oct 14, 2023
1 parent e0ab568 commit 48bb293
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 43 deletions.
75 changes: 34 additions & 41 deletions attention_sinks/generation/utils.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,44 @@
from typing import Any, Dict

import torch
from transformers.generation.utils import GenerationMixin as TGenerationMixin
from transformers.utils import ModelOutput


class GenerationMixin(TGenerationMixin):
"""
This GenerationMixin must be overridden to prevent the `attention_mask`
from extending beyond the window size.
"""
def _update_model_kwargs_for_generation(
    self,
    outputs: "ModelOutput",
    model_kwargs: Dict[str, Any],
    is_encoder_decoder: bool = False,
    standardize_cache_format: bool = False,
) -> Dict[str, Any]:
    """Update ``model_kwargs`` after a generation step without over-extending the attention mask.

    This is a module-level function taking ``self`` so that it can be bound onto
    a model instance (e.g. with ``types.MethodType``) in place of the model's
    own ``_update_model_kwargs_for_generation``.

    The only intended difference is the condition guarding the
    ``attention_mask`` update: the mask is extended by one position only while
    its length equals the cached key length, which stops the mask from
    extending beyond the window size.

    Args:
        self: The model; must provide ``_extract_past_from_model_output``.
        outputs: Output of the latest forward pass.
        model_kwargs: Keyword arguments for the next forward pass. Mutated in
            place and also returned.
        is_encoder_decoder: If True, ``decoder_attention_mask`` is extended
            instead of ``attention_mask``.
        standardize_cache_format: Forwarded to
            ``self._extract_past_from_model_output``.

    Returns:
        The updated ``model_kwargs`` dictionary (the same object passed in).
    """
    # update past_key_values
    model_kwargs["past_key_values"] = self._extract_past_from_model_output(
        outputs, standardize_cache_format=standardize_cache_format
    )
    if getattr(outputs, "state", None) is not None:
        model_kwargs["state"] = outputs.state

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

    if not is_encoder_decoder:
        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            past_key_values = model_kwargs["past_key_values"]
            # Only this `if`-statement differs from the default implementation: it
            # stops the attention_mask from extending itself too far. The mask grows
            # only while it is exactly as long as the cached keys. Without a cache
            # there is nothing to compare against, so the mask is always extended
            # rather than failing on `None[0][0]`.
            if past_key_values is None or attention_mask.size(-1) == past_key_values[0][0].size(2):
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )
    else:
        # update decoder attention mask
        if "decoder_attention_mask" in model_kwargs:
            decoder_attention_mask = model_kwargs["decoder_attention_mask"]
            model_kwargs["decoder_attention_mask"] = torch.cat(
                [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
                dim=-1,
            )

    return model_kwargs
9 changes: 7 additions & 2 deletions attention_sinks/inject_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from transformers.utils import logging

from attention_sinks.attention_sink_kv_cache import AttentionSinkKVCache
from attention_sinks.generation.utils import GenerationMixin
from attention_sinks.generation.utils import _update_model_kwargs_for_generation

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -35,7 +35,7 @@
}


class InjectAttentionSinksMixin(GenerationMixin):
class InjectAttentionSinksMixin:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Separate Attention Sink kwargs from regular kwargs
Expand Down Expand Up @@ -67,6 +67,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
f"[Attention Sinks] Injected Attention Sink KV Cache into {call_count} model class{'es' if call_count != 1 else ''}."
)

# Replace the model's `_update_model_kwargs_for_generation` method; this prevents an
# indexing error when generating. The default implementation expects the sequence
# length to keep growing as generation proceeds, but that isn't the case here.
model._update_model_kwargs_for_generation = types.MethodType(_update_model_kwargs_for_generation, model)

return model

@classmethod
Expand Down

0 comments on commit 48bb293

Please sign in to comment.