@@ -987,7 +987,9 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
+        )

         # embed positions
         hidden_states = inputs_embeds
@@ -1055,7 +1057,7 @@ def forward(
     # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
     # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
     # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -1068,7 +1070,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
-                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
+                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
             )

         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
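
A minimal sketch (not part of this diff) of where the dynamic-cache `target_length` now comes from: the old path reads the last element of `cache_position`, a 0-d tensor whose value changes every decode step, while the new path receives `current_length` as a plain Python int (`past_seen_tokens + inputs_embeds.shape[1]`) computed in `forward`. The function names below are hypothetical and only the two return expressions mirror the changed line.

import torch

def target_length_old(cache_position: torch.Tensor):
    # Old path: index the last cache position. The result is a 0-d tensor whose value grows
    # every decode step, i.e. the kind of data-dependent size the comment above associates
    # with cudagraph recapture under torch.compile.
    return cache_position[-1] + 1

def target_length_new(current_length: int):
    # New path: `current_length` is past_seen_tokens + inputs_embeds.shape[1], an ordinary
    # Python int computed in `forward`, so no tensor element needs to be read here.
    return current_length + 1

cache_position = torch.arange(5, 9)       # positions 5..8 processed in this step
print(target_length_old(cache_position))  # tensor(9)
print(target_length_new(5 + 4))           # 10

Note that the two values differ by one here; that reflects the changed expression itself (the new branch adds one on top of the full current length), not an error in the sketch.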