
Commit 17cd7a9

Fix torch.fx symbolic tracing for LLama (#30047)
* [WIP] fix fx
* [WIP] fix fx
* [WIP] fix fx
* [WIP] fix fx
* [WIP] fix fx
* Apply changes to other models
1 parent 4879531 commit 17cd7a9

File tree: 6 files changed, +23 -18 lines

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 5 additions & 3 deletions
@@ -908,7 +908,9 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
+        )

         # embed positions
         hidden_states = inputs_embeds

@@ -976,7 +978,7 @@ def forward(
     # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
     # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
     # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask

@@ -989,7 +991,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
-                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
+                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
             )

         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)

src/transformers/models/gemma/modeling_gemma.py

Lines changed: 5 additions & 3 deletions
@@ -888,7 +888,9 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
+        )

         # embed positions
         hidden_states = inputs_embeds

@@ -962,7 +964,7 @@ def forward(
     # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
     # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
     # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask

@@ -975,7 +977,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
-                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
+                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
             )

         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)

src/transformers/models/llama/modeling_llama.py

Lines changed: 5 additions & 3 deletions
@@ -987,7 +987,9 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
+        )

         # embed positions
         hidden_states = inputs_embeds

@@ -1055,7 +1057,7 @@ def forward(
     # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
     # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
     # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask

@@ -1068,7 +1070,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
-                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
+                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
             )

         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
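The same change is applied to Cohere, Gemma, and Llama above: the width of the causal mask is now derived from `past_seen_tokens + inputs_embeds.shape[1]`, passed in as `current_length`, rather than from `cache_position[-1] + 1`. Indexing into the `cache_position` tensor produces a data-dependent value, which appears to be what tripped up torch.fx symbolic tracing here. Below is a minimal, standalone sketch (illustrative only, not the library code; `build_causal_mask` is a made-up helper) of how a mask of width `current_length` lines up with cached and current token positions.

import torch

def build_causal_mask(sequence_length, current_length, dtype=torch.float32):
    # current_length = tokens already in the KV cache + tokens in this forward pass.
    min_dtype = torch.finfo(dtype).min
    past_seen_tokens = current_length - sequence_length
    # Absolute positions of the queries processed in this step.
    cache_position = torch.arange(past_seen_tokens, current_length)
    mask = torch.full((sequence_length, current_length), fill_value=min_dtype, dtype=dtype)
    # Key position j is visible to the query at position cache_position[i] when j <= cache_position[i];
    # keep min_dtype only for future positions.
    return mask * (torch.arange(current_length) > cache_position.reshape(-1, 1))

mask = build_causal_mask(sequence_length=3, current_length=7)
print((mask < 0).int())
# tensor([[0, 0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0, 0, 0]])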

src/transformers/utils/fx.py

Lines changed: 6 additions & 3 deletions
@@ -260,11 +260,14 @@ def torch_arange(*args, **kwargs):

 def torch_full(*args, **kwargs):
     args = list(args)
-    if isinstance(args[1], torch.Tensor) and args[1].device == torch.device("meta"):
-        args[1] = 1  # Any value.
+    # We set the fill value to 1 as its value is not important as long as it's not a tensor on the `meta` device.
+    if len(args) > 1:
+        args[1] = 1
+    else:
+        kwargs["fill_value"] = 1
     kwargs_without_device = dict(kwargs)
     kwargs_without_device.pop("device", None)
-    return torch.full(*args, **kwargs_without_device)
+    return torch.full(*args, **kwargs_without_device, device="meta")


 def torch_cat(tensors, dim=None, axis=None, *, out=None):
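`torch_full` is one of the internal meta-device overrides used by the HF fx tracer, so it is not public API, but a quick illustrative call (not from the repo) shows the new behaviour: the fill value is always replaced with 1, whether it was passed positionally or as `fill_value=`, and the result is forced onto the `meta` device.

import torch
from transformers.utils.fx import torch_full

# Positional fill value: overridden with 1, result lands on the meta device.
t = torch_full((2, 3), torch.finfo(torch.float32).min, dtype=torch.float32)
print(t.shape, t.device)  # torch.Size([2, 3]) meta

# Keyword fill value: handled through kwargs["fill_value"] instead.
t = torch_full((4,), fill_value=0.0)
print(t.shape, t.device)  # torch.Size([4]) meta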

tests/models/cohere/test_modeling_cohere.py

Lines changed: 1 addition & 3 deletions
@@ -283,9 +283,7 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     )
     test_headmasking = False
     test_pruning = False
-    fx_compatible = (
-        False  # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
-    )
+    fx_compatible = True

     # Need to use `0.8` instead of `0.9` for `test_cpu_offload`
     # This is because we are hitting edge cases with the causal_mask buffer

tests/models/llama/test_modeling_llama.py

Lines changed: 1 addition & 3 deletions
@@ -305,9 +305,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     )
     test_headmasking = False
     test_pruning = False
-    fx_compatible = (
-        False  # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
-    )
+    fx_compatible = True

     # Need to use `0.8` instead of `0.9` for `test_cpu_offload`
     # This is because we are hitting edge cases with the causal_mask buffer
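With `fx_compatible = True` restored, the common fx tests (`test_torch_fx` and `test_torch_fx_output_loss` in the shared model tester) run again for Cohere and Llama. The snippet below is only a rough sketch of the kind of check they perform, using a tiny hand-picked Llama config (the config values and input shapes are illustrative, not taken from the test suite).

import torch
from transformers import LlamaConfig, LlamaForCausalLM
from transformers.utils.fx import symbolic_trace

config = LlamaConfig(
    vocab_size=128,
    hidden_size=16,
    intermediate_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)
model = LlamaForCausalLM(config).eval()

# Symbolic tracing turns the forward pass into a torch.fx GraphModule.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

input_ids = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    eager_out = model(input_ids=input_ids, attention_mask=attention_mask)
    traced_out = traced(input_ids=input_ids, attention_mask=attention_mask)

# Depending on the transformers version the traced output is a dict-like
# container; both it and the eager ModelOutput support key indexing.
torch.testing.assert_close(traced_out["logits"], eager_out["logits"])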
