8 changes: 5 additions & 3 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -908,7 +908,9 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
+        )
 
         # embed positions
         hidden_states = inputs_embeds
@@ -976,7 +978,7 @@ def forward(
     # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
     # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
     # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -989,7 +991,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
-                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
+                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
             )
 
         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
8 changes: 5 additions & 3 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -888,7 +888,9 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
+        )
 
         # embed positions
         hidden_states = inputs_embeds
@@ -962,7 +964,7 @@ def forward(
     # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
     # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
     # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -975,7 +977,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
-                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
+                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
             )
 
         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
8 changes: 5 additions & 3 deletions src/transformers/models/llama/modeling_llama.py
@@ -987,7 +987,9 @@ def forward(
         if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
+        )
 
         # embed positions
         hidden_states = inputs_embeds

Review comment (Collaborator), on the added `_update_causal_mask` call above:
I guess passing cache_position[-1] doesn't work? The only issue is that past_seen_tokens is going to be deprecated in favor of cache positions, but that works, since with the static cache we use the max position embeddings.
@@ -1055,7 +1057,7 @@ def forward(
     # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
     # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
     # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
@@ -1068,7 +1070,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
-                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
+                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
             )
 
         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
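For intuition on why the new `current_length` argument is threaded through (this note and the snippet below are not part of the PR): in the dynamic-cache branch, `cache_position[-1] + 1` is a 0-d tensor, whereas `past_seen_tokens + inputs_embeds.shape[1]` is a plain Python int, which is presumably friendlier to symbolic tracing — the FIXME removed in the test files below points at fx compatibility. A minimal sketch with toy shapes (batch of one, hidden size 16, all values chosen for illustration):

```python
# Minimal, hypothetical sketch: contrast the tensor-valued length used previously
# with the integer `current_length` now passed into `_update_causal_mask`.
import torch

sequence_length = 4        # tokens in the current forward pass
past_seen_tokens = 8       # tokens already stored in the KV cache
inputs_embeds = torch.randn(1, sequence_length, 16)

cache_position = torch.arange(past_seen_tokens, past_seen_tokens + sequence_length)

# Old dynamic-cache fallback: a 0-d tensor.
target_length_tensor = cache_position[-1] + 1
# New argument: a plain Python int.
current_length = past_seen_tokens + inputs_embeds.shape[1]

# In eager mode both produce the same mask shape.
min_dtype = torch.finfo(torch.float32).min
mask_old = torch.full((sequence_length, int(target_length_tensor)), fill_value=min_dtype)
mask_new = torch.full((sequence_length, current_length), fill_value=min_dtype)
assert mask_old.shape == mask_new.shape == (4, 12)
```

In eager mode the two are interchangeable; the difference presumably only matters to a tracer that cannot treat a tensor element as a static size.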
9 changes: 6 additions & 3 deletions src/transformers/utils/fx.py
@@ -260,11 +260,14 @@ def torch_arange(*args, **kwargs):
 
 def torch_full(*args, **kwargs):
     args = list(args)
-    if isinstance(args[1], torch.Tensor) and args[1].device == torch.device("meta"):
-        args[1] = 1  # Any value.
+    # We set the fill value to 1 as its value is not important as long as it's not a tensor on the `meta` device.
+    if len(args) > 1:
+        args[1] = 1
+    else:
+        kwargs["fill_value"] = 1
     kwargs_without_device = dict(kwargs)
     kwargs_without_device.pop("device", None)
-    return torch.full(*args, **kwargs_without_device)
+    return torch.full(*args, **kwargs_without_device, device="meta")
 
 
 def torch_cat(tensors, dim=None, axis=None, *, out=None):
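To see what the updated `torch_full` wrapper does in isolation, here is a standalone demo that copies the new helper from the hunk above; the example call at the end is an assumption chosen to resemble the causal-mask construction:

```python
# Standalone demo of the patched tracer helper from the diff above; the call at the
# bottom is illustrative, not taken from the PR.
import torch


def torch_full(*args, **kwargs):
    args = list(args)
    # We set the fill value to 1 as its value is not important as long as it's not a tensor on the `meta` device.
    if len(args) > 1:
        args[1] = 1
    else:
        kwargs["fill_value"] = 1
    kwargs_without_device = dict(kwargs)
    kwargs_without_device.pop("device", None)
    return torch.full(*args, **kwargs_without_device, device="meta")


# A call resembling the causal-mask construction: the real fill value and the caller's
# device are discarded, only shape and dtype survive, on the meta device.
out = torch_full((4, 12), torch.finfo(torch.float32).min, dtype=torch.float32, device="cpu")
print(out.shape, out.device)  # torch.Size([4, 12]) meta
```

As the in-code comment says, the fill value is irrelevant during tracing, and forcing `device="meta"` keeps the tensor free of real data.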
4 changes: 1 addition & 3 deletions tests/models/cohere/test_modeling_cohere.py
@@ -283,9 +283,7 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     )
     test_headmasking = False
     test_pruning = False
-    fx_compatible = (
-        False  # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
-    )
+    fx_compatible = True
 
     # Need to use `0.8` instead of `0.9` for `test_cpu_offload`
     # This is because we are hitting edge cases with the causal_mask buffer
4 changes: 1 addition & 3 deletions tests/models/llama/test_modeling_llama.py
@@ -305,9 +305,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     )
     test_headmasking = False
     test_pruning = False
-    fx_compatible = (
-        False  # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
-    )
+    fx_compatible = True
 
     # Need to use `0.8` instead of `0.9` for `test_cpu_offload`
     # This is because we are hitting edge cases with the causal_mask buffer
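With the FIXME removed, `fx_compatible = True` re-enables the common torch.fx tests for these models. A rough sketch of the kind of tracing those tests exercise, assuming a tiny randomly initialized Llama config (the config values and input names below are illustrative, not taken from the test suite):

```python
# Hypothetical sketch of the tracing re-enabled by `fx_compatible = True`; the tiny
# config is an assumption for a quick local check, not one of the tested checkpoints.
import torch
from transformers import LlamaConfig, LlamaForCausalLM
from transformers.utils.fx import symbolic_trace

config = LlamaConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)
model = LlamaForCausalLM(config)

# Symbolically trace the model on the given input names; this is the machinery that the
# `_update_causal_mask` and `torch_full` changes above unblock.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

# The traced GraphModule can then be run like the original model.
input_ids = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones_like(input_ids)
outputs = traced(input_ids=input_ids, attention_mask=attention_mask)
```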