diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index f9ae8deeb865..ef79261364ba 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -459,14 +459,22 @@ def cuda_kernels_forward( past_key_values: Lfm2HybridConvCache | None = None, attention_mask: torch.Tensor | None = None, ): + seqlen = x.shape[1] x = apply_mask_to_padding_states(x, attention_mask) BCx = self.in_proj(x).transpose(-1, -2) B, C, x = BCx.chunk(3, dim=-2) Bx = B * x + # Note: we may or may not have to subtract the current seq_len here as the cache may or may not be already updated + # by the current layer + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + # In this case, the cache was already updated and we need to subtract seq_len to get the correct past length + if "full_attention" in self.config.layer_types[: self.layer_idx]: + past_seen_tokens = past_seen_tokens - seqlen + conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2)) - if past_key_values is not None and past_key_values.get_seq_length() > 0: + if past_seen_tokens > 0: conv_out = causal_conv1d_update( Bx.squeeze(-1), past_key_values.conv_cache[self.layer_idx], @@ -507,7 +515,7 @@ def slow_forward( if "full_attention" in self.config.layer_types[: self.layer_idx]: past_seen_tokens = past_seen_tokens - seqlen - if past_key_values is not None and past_seen_tokens > 0: + if past_seen_tokens > 0: conv_state = past_key_values.conv_cache[self.layer_idx] cache_position = torch.arange(seqlen, device=conv_state.device) + past_seen_tokens cache_position = cache_position.clamp(0, self.L_cache - 1) diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index eab687df803a..65119a287abd 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -294,14 +294,22 
@@ def cuda_kernels_forward( past_key_values: Lfm2HybridConvCache | None = None, attention_mask: torch.Tensor | None = None, ): + seqlen = x.shape[1] x = apply_mask_to_padding_states(x, attention_mask) BCx = self.in_proj(x).transpose(-1, -2) B, C, x = BCx.chunk(3, dim=-2) Bx = B * x + # Note: we may or may not have to subtract the current seq_len here as the cache may or may not be already updated + # by the current layer + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + # In this case, the cache was already updated and we need to subtract seq_len to get the correct past length + if "full_attention" in self.config.layer_types[: self.layer_idx]: + past_seen_tokens = past_seen_tokens - seqlen + conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2)) - if past_key_values is not None and past_key_values.get_seq_length() > 0: + if past_seen_tokens > 0: conv_out = causal_conv1d_update( Bx.squeeze(-1), past_key_values.conv_cache[self.layer_idx], @@ -342,7 +350,7 @@ def slow_forward( if "full_attention" in self.config.layer_types[: self.layer_idx]: past_seen_tokens = past_seen_tokens - seqlen - if past_key_values is not None and past_seen_tokens > 0: + if past_seen_tokens > 0: conv_state = past_key_values.conv_cache[self.layer_idx] cache_position = torch.arange(seqlen, device=conv_state.device) + past_seen_tokens cache_position = cache_position.clamp(0, self.L_cache - 1) diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py index d6b0401e4658..03bdbbdc95f8 100644 --- a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py @@ -535,14 +535,22 @@ def cuda_kernels_forward( past_key_values: Lfm2MoeHybridConvCache | None = None, attention_mask: torch.Tensor | None = None, ): + seqlen = x.shape[1] x = apply_mask_to_padding_states(x, attention_mask) BCx = 
self.in_proj(x).transpose(-1, -2) B, C, x = BCx.chunk(3, dim=-2) Bx = B * x + # Note: we may or may not have to subtract the current seq_len here as the cache may or may not be already updated + # by the current layer + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + # In this case, the cache was already updated and we need to subtract seq_len to get the correct past length + if "full_attention" in self.config.layer_types[: self.layer_idx]: + past_seen_tokens = past_seen_tokens - seqlen + conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2)) - if past_key_values is not None and past_key_values.get_seq_length() > 0: + if past_seen_tokens > 0: conv_out = causal_conv1d_update( Bx.squeeze(-1), past_key_values.conv_cache[self.layer_idx], @@ -583,7 +591,7 @@ def slow_forward( if "full_attention" in self.config.layer_types[: self.layer_idx]: past_seen_tokens = past_seen_tokens - seqlen - if past_key_values is not None and past_seen_tokens > 0: + if past_seen_tokens > 0: conv_state = past_key_values.conv_cache[self.layer_idx] cache_position = torch.arange(seqlen, device=conv_state.device) + past_seen_tokens cache_position = cache_position.clamp(0, self.L_cache - 1)