Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/transformers/models/lfm2/modeling_lfm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,14 +459,22 @@ def cuda_kernels_forward(
past_key_values: Lfm2HybridConvCache | None = None,
attention_mask: torch.Tensor | None = None,
):
seqlen = x.shape[1]
x = apply_mask_to_padding_states(x, attention_mask)
BCx = self.in_proj(x).transpose(-1, -2)
B, C, x = BCx.chunk(3, dim=-2)

Bx = B * x

# Note: we may or may not have to subtract the current seq_len here, as the cache may or may not have already been updated
# by the current layer
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
# In this case, the cache was already updated and we need to subtract seq_len to get the correct past length
if "full_attention" in self.config.layer_types[: self.layer_idx]:
past_seen_tokens = past_seen_tokens - seqlen

conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
if past_key_values is not None and past_key_values.get_seq_length() > 0:
if past_key_values is not None and past_seen_tokens > 0:
conv_out = causal_conv1d_update(
Bx.squeeze(-1),
past_key_values.conv_cache[self.layer_idx],
Expand Down
10 changes: 9 additions & 1 deletion src/transformers/models/lfm2/modular_lfm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,14 +294,22 @@ def cuda_kernels_forward(
past_key_values: Lfm2HybridConvCache | None = None,
attention_mask: torch.Tensor | None = None,
):
seqlen = x.shape[1]
x = apply_mask_to_padding_states(x, attention_mask)
BCx = self.in_proj(x).transpose(-1, -2)
B, C, x = BCx.chunk(3, dim=-2)

Bx = B * x

# Note: we may or may not have to subtract the current seq_len here, as the cache may or may not have already been updated
# by the current layer
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
# In this case, the cache was already updated and we need to subtract seq_len to get the correct past length
if "full_attention" in self.config.layer_types[: self.layer_idx]:
past_seen_tokens = past_seen_tokens - seqlen

conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
if past_key_values is not None and past_key_values.get_seq_length() > 0:
if past_key_values is not None and past_seen_tokens > 0:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`past_key_values is not None` will always be true if `past_seen_tokens > 0`.

Suggested change
if past_key_values is not None and past_seen_tokens > 0:
if past_seen_tokens > 0:

conv_out = causal_conv1d_update(
Bx.squeeze(-1),
past_key_values.conv_cache[self.layer_idx],
Expand Down
10 changes: 9 additions & 1 deletion src/transformers/models/lfm2_moe/modeling_lfm2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,14 +535,22 @@ def cuda_kernels_forward(
past_key_values: Lfm2MoeHybridConvCache | None = None,
attention_mask: torch.Tensor | None = None,
):
seqlen = x.shape[1]
x = apply_mask_to_padding_states(x, attention_mask)
BCx = self.in_proj(x).transpose(-1, -2)
B, C, x = BCx.chunk(3, dim=-2)

Bx = B * x

# Note: we may or may not have to subtract the current seq_len here, as the cache may or may not have already been updated
# by the current layer
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
# In this case, the cache was already updated and we need to subtract seq_len to get the correct past length
if "full_attention" in self.config.layer_types[: self.layer_idx]:
past_seen_tokens = past_seen_tokens - seqlen

conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
if past_key_values is not None and past_key_values.get_seq_length() > 0:
if past_key_values is not None and past_seen_tokens > 0:
conv_out = causal_conv1d_update(
Bx.squeeze(-1),
past_key_values.conv_cache[self.layer_idx],
Expand Down
Loading