Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/transformers/models/lfm2/modeling_lfm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,14 +459,22 @@ def cuda_kernels_forward(
past_key_values: Lfm2HybridConvCache | None = None,
attention_mask: torch.Tensor | None = None,
):
seqlen = x.shape[1]
x = apply_mask_to_padding_states(x, attention_mask)
BCx = self.in_proj(x).transpose(-1, -2)
B, C, x = BCx.chunk(3, dim=-2)

Bx = B * x

# Note: we may or may not have to subtract the current seq_len here as the cache may or may not be already updated
# by the current layer
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
# In this case, the cache was already updated and we need to subtract seq_len to get the correct past length
if "full_attention" in self.config.layer_types[: self.layer_idx]:
past_seen_tokens = past_seen_tokens - seqlen

conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
if past_key_values is not None and past_key_values.get_seq_length() > 0:
if past_seen_tokens > 0:
conv_out = causal_conv1d_update(
Bx.squeeze(-1),
past_key_values.conv_cache[self.layer_idx],
Expand Down Expand Up @@ -507,7 +515,7 @@ def slow_forward(
if "full_attention" in self.config.layer_types[: self.layer_idx]:
past_seen_tokens = past_seen_tokens - seqlen

if past_key_values is not None and past_seen_tokens > 0:
if past_seen_tokens > 0:
conv_state = past_key_values.conv_cache[self.layer_idx]
cache_position = torch.arange(seqlen, device=conv_state.device) + past_seen_tokens
cache_position = cache_position.clamp(0, self.L_cache - 1)
Expand Down
12 changes: 10 additions & 2 deletions src/transformers/models/lfm2/modular_lfm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,14 +294,22 @@ def cuda_kernels_forward(
past_key_values: Lfm2HybridConvCache | None = None,
attention_mask: torch.Tensor | None = None,
):
seqlen = x.shape[1]
x = apply_mask_to_padding_states(x, attention_mask)
BCx = self.in_proj(x).transpose(-1, -2)
B, C, x = BCx.chunk(3, dim=-2)

Bx = B * x

# Note: we may or may not have to subtract the current seq_len here as the cache may or may not be already updated
# by the current layer
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
# In this case, the cache was already updated and we need to subtract seq_len to get the correct past length
if "full_attention" in self.config.layer_types[: self.layer_idx]:
past_seen_tokens = past_seen_tokens - seqlen

conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
if past_key_values is not None and past_key_values.get_seq_length() > 0:
if past_seen_tokens > 0:
conv_out = causal_conv1d_update(
Bx.squeeze(-1),
past_key_values.conv_cache[self.layer_idx],
Expand Down Expand Up @@ -342,7 +350,7 @@ def slow_forward(
if "full_attention" in self.config.layer_types[: self.layer_idx]:
past_seen_tokens = past_seen_tokens - seqlen

if past_key_values is not None and past_seen_tokens > 0:
if past_seen_tokens > 0:
conv_state = past_key_values.conv_cache[self.layer_idx]
cache_position = torch.arange(seqlen, device=conv_state.device) + past_seen_tokens
cache_position = cache_position.clamp(0, self.L_cache - 1)
Expand Down
12 changes: 10 additions & 2 deletions src/transformers/models/lfm2_moe/modeling_lfm2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,14 +535,22 @@ def cuda_kernels_forward(
past_key_values: Lfm2MoeHybridConvCache | None = None,
attention_mask: torch.Tensor | None = None,
):
seqlen = x.shape[1]
x = apply_mask_to_padding_states(x, attention_mask)
BCx = self.in_proj(x).transpose(-1, -2)
B, C, x = BCx.chunk(3, dim=-2)

Bx = B * x

# Note: we may or may not have to subtract the current seq_len here as the cache may or may not be already updated
# by the current layer
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
# In this case, the cache was already updated and we need to subtract seq_len to get the correct past length
if "full_attention" in self.config.layer_types[: self.layer_idx]:
past_seen_tokens = past_seen_tokens - seqlen

conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
if past_key_values is not None and past_key_values.get_seq_length() > 0:
if past_seen_tokens > 0:
conv_out = causal_conv1d_update(
Bx.squeeze(-1),
past_key_values.conv_cache[self.layer_idx],
Expand Down Expand Up @@ -583,7 +591,7 @@ def slow_forward(
if "full_attention" in self.config.layer_types[: self.layer_idx]:
past_seen_tokens = past_seen_tokens - seqlen

if past_key_values is not None and past_seen_tokens > 0:
if past_seen_tokens > 0:
conv_state = past_key_values.conv_cache[self.layer_idx]
cache_position = torch.arange(seqlen, device=conv_state.device) + past_seen_tokens
cache_position = cache_position.clamp(0, self.L_cache - 1)
Expand Down
Loading