
Commit 23db187

Generate: handle cache_position update in generate (#29467)
1 parent 7b87ecb commit 23db187

5 files changed: +155 -78 lines changed

src/transformers/cache_utils.py

+24 -14
@@ -4,6 +4,10 @@
 import torch

 from .configuration_utils import PretrainedConfig
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)


 @dataclass
@@ -57,6 +61,17 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -
             return max_length - new_seq_length
         return previous_seq_length

+    @property
+    def seen_tokens(self):
+        logger.warning_once(
+            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
+            "model input instead."
+        )
+        if hasattr(self, "_seen_tokens"):
+            return self._seen_tokens
+        else:
+            return None
+

 class DynamicCache(Cache):
     """
@@ -69,7 +84,7 @@ class DynamicCache(Cache):
     def __init__(self) -> None:
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        self.seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

     def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
         """
@@ -121,7 +136,7 @@ def update(
         """
         # Update the number of seen tokens
         if layer_idx == 0:
-            self.seen_tokens += key_states.shape[-2]
+            self._seen_tokens += key_states.shape[-2]

         # Update the cache
         if len(self.key_cache) <= layer_idx:
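As a quick illustration of the tally above (a sketch, not part of the diff): `_seen_tokens` is only incremented for `layer_idx == 0`, so writing the same batch of tokens for every layer of a model counts each token once.

import torch
from transformers import DynamicCache

cache = DynamicCache()
keys = values = torch.randn(1, 2, 4, 8)   # [bsz, num_heads, seq_len, head_dim]: 4 new tokens
for layer_idx in range(3):                # the same 4 tokens are written for 3 layers
    cache.update(keys, values, layer_idx)
print(cache.get_seq_length())             # 4, not 12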
@@ -191,7 +206,7 @@ def __init__(self, window_length: int, num_sink_tokens: int) -> None:
         self.window_length = window_length
         self.num_sink_tokens = num_sink_tokens
         self.cos_sin_cache = {}
-        self.seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

     @staticmethod
     def _rotate_half(x):
@@ -272,7 +287,7 @@ def update(

         # Update the number of seen tokens
         if layer_idx == 0:
-            self.seen_tokens += key_states.shape[-2]
+            self._seen_tokens += key_states.shape[-2]

         # [bsz, num_heads, seq_len, head_dim]
         if len(self.key_cache) <= layer_idx:
@@ -398,16 +413,11 @@ def update(

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
-        # TODO: Fix once the stateful `int` bug in PyTorch is fixed.
-        raise ValueError(
-            "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
-        )
-
-    def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
-        # TODO: Fix once the stateful `int` bug in PyTorch is fixed.
-        raise ValueError(
-            "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
-        )
+        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+        # limit the check to the first batch member and head dimension.
+        # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
+        # https://github.com/pytorch/pytorch/issues/120248 is fixed
+        return (self.key_cache[0, 0].any(dim=-1)).sum()

     def get_max_length(self) -> Optional[int]:
         """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""

src/transformers/generation/utils.py

+63 -16
@@ -633,7 +633,6 @@ def _update_model_kwargs_for_generation(
         model_kwargs: Dict[str, Any],
         is_encoder_decoder: bool = False,
         standardize_cache_format: bool = False,
-        model_inputs: Optional[Dict[str, Any]] = None,
     ) -> Dict[str, Any]:
         # update past_key_values
         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
@@ -663,7 +662,8 @@
                 dim=-1,
             )

-        model_kwargs["cache_position"] = model_inputs.get("cache_position", None)
+        if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1

         return model_kwargs
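A minimal sketch (not part of the diff) of what the new update rule does across decoding steps, assuming a 5-token prompt:

import torch

cache_position = torch.arange(5)       # prefill: positions [0, 1, 2, 3, 4]
for _ in range(3):                     # three autoregressive steps
    cache_position = cache_position[-1:] + 1
print(cache_position)                  # tensor([7]): a single slot per newly generated token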

@@ -1931,10 +1931,15 @@ def _contrastive_search(
         )

         # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         this_peer_finished = False  # used by synced_gpus only
-        batch_size = input_ids.shape[0]

         while True:
             if synced_gpus:
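A minimal sketch (not part of the diff) of the initialization above; the ids and mask are made-up values for a left-padded batch. One likely reason the mask takes precedence over `input_ids` is that, with a pre-filled cache, `input_ids` may hold only the newly supplied tokens while the attention mask spans the full sequence.

import torch

input_ids = torch.tensor([[0, 5, 6, 7], [8, 9, 10, 11]])       # hypothetical prompt batch
attention_mask = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]])    # hypothetical padding mask

batch_size, cur_len = attention_mask.shape if attention_mask is not None else input_ids.shape
unfinished_sequences = torch.ones(batch_size, dtype=torch.long)
cache_position = torch.arange(cur_len)   # tensor([0, 1, 2, 3]): one entry per prefill position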
@@ -1975,7 +1980,6 @@
                     model_kwargs,
                     is_encoder_decoder=self.config.is_encoder_decoder,
                     standardize_cache_format=True,
-                    model_inputs=model_inputs,
                 )
                 if not sequential:
                     # Expands model inputs top_k times, for batched forward passes (akin to beam search).
@@ -2170,7 +2174,9 @@
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )

             # if eos_token was found in one sentence, set sentence to finished
@@ -2389,7 +2395,13 @@ def _greedy_search(
         )

         # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         this_peer_finished = False  # used by synced_gpus only
         while True:
@@ -2459,7 +2471,6 @@
                 outputs,
                 model_kwargs,
                 is_encoder_decoder=self.config.is_encoder_decoder,
-                model_inputs=model_inputs,
             )

             # if eos_token was found in one sentence, set sentence to finished
@@ -2688,7 +2699,13 @@ def _sample(
         )

         # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         this_peer_finished = False  # used by synced_gpus only
         # auto-regressive generation
@@ -2758,7 +2775,9 @@
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )

             # if eos_token was found in one sentence, set sentence to finished
@@ -3003,6 +3022,7 @@ def _beam_search(
         num_beams = beam_scorer.num_beams

         batch_beam_size, cur_len = input_ids.shape
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         if num_beams * batch_size != batch_beam_size:
             raise ValueError(
@@ -3156,7 +3176,9 @@
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )
             if model_kwargs.get("past_key_values", None) is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3397,6 +3419,7 @@ def _beam_sample(
         num_beams = beam_scorer.num_beams

         batch_beam_size, cur_len = input_ids.shape
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
@@ -3510,7 +3533,9 @@
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )
             if model_kwargs.get("past_key_values", None) is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3747,6 +3772,7 @@ def _group_beam_search(
         device = input_ids.device

         batch_beam_size, cur_len = input_ids.shape
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         if return_dict_in_generate and output_scores:
            beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
@@ -3916,7 +3942,9 @@
             input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )
             if model_kwargs.get("past_key_values", None) is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4155,6 +4183,7 @@ def _constrained_beam_search(
         num_beams = constrained_beam_scorer.num_beams

         batch_beam_size, cur_len = input_ids.shape
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         if num_beams * batch_size != batch_beam_size:
             raise ValueError(
@@ -4275,7 +4304,9 @@

             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )
             if model_kwargs.get("past_key_values", None) is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4511,7 +4542,13 @@ def _assisted_decoding(
         )

         # keep track of which sequences are already finished
-        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        batch_size, cur_len = batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

         # other auxiliary variables
         max_len = stopping_criteria[0].max_length
@@ -4555,6 +4592,14 @@
                 candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
             )
             candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
+            if "cache_position" in candidate_kwargs:
+                candidate_kwargs["cache_position"] = torch.cat(
+                    (
+                        candidate_kwargs["cache_position"],
+                        torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
+                    ),
+                    dim=0,
+                )

             model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
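A minimal sketch (not part of the diff) of the extension above; the starting values are invented for illustration. One extra position slot is appended per candidate (draft) token so the whole draft can be scored in a single verification forward pass.

import torch

cur_len, candidate_length = 6, 3                       # hypothetical: 6 tokens so far, 3 draft tokens
cache_position = torch.tensor([5], dtype=torch.long)   # hypothetical current slot before extension

cache_position = torch.cat(
    (cache_position, torch.arange(cur_len, cur_len + candidate_length, dtype=torch.long)),
    dim=0,
)
print(cache_position)   # tensor([5, 6, 7, 8]): one position per token scored in the verification pass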

@@ -4673,7 +4718,9 @@
             )

             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
             )

             # if eos_token was found in one sentence, set sentence to finished
