
Llama model cannot be fx traced #29923

@xuzifei-dmatrix

Description

System Info

  • transformers version: 4.39.1
  • Platform: Linux-5.15.0-1057-azure-x86_64-with-glibc2.35
  • Python version: 3.10.13
  • Huggingface_hub version: 0.22.1
  • Safetensors version: 0.4.2
  • Accelerate version: 0.28.0
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    - distributed_type: MULTI_GPU
    - mixed_precision: no
    - use_cpu: False
    - debug: False
    - num_processes: 4
    - machine_rank: 0
    - num_machines: 1
    - gpu_ids: all
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
    - dynamo_config: {'dynamo_backend': 'INDUCTOR'}
  • PyTorch version (GPU?): 2.2.0 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers.utils.fx import symbolic_trace
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
gm = symbolic_trace(model)

Running this raises the following (abridged) traceback:
File ~/miniconda3/envs/mltools/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:665, in LlamaSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    661     causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
    663 # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
    664 # Reference: https://github.com/pytorch/pytorch/issues/112577.
--> 665 if query_states.device.type == "cuda" and causal_mask is not None:
    666     query_states = query_states.contiguous()
    667     key_states = key_states.contiguous()

File ~/miniconda3/envs/mltools/lib/python3.10/site-packages/transformers/utils/fx.py:653, in HFProxy.__bool__(self)
    651 if hasattr(self, "_metadata") and self._metadata is not None:
    652     return self._metadata
--> 653 return super().__bool__()

File ~/miniconda3/envs/mltools/lib/python3.10/site-packages/torch/fx/proxy.py:441, in Proxy.__bool__(self)
    438             self.tracer.create_proxy('call_function', assert_fn, (self,), {})
    439             return True
--> 441 return self.tracer.to_bool(self)

File ~/miniconda3/envs/mltools/lib/python3.10/site-packages/torch/fx/proxy.py:301, in TracerBase.to_bool(self, obj)
    294 @compatibility(is_backward_compatible=True)
    295 def to_bool(self, obj: 'Proxy') -> bool:
    296     """Called when a proxy object is being converted to a boolean, such as
    297     when used in control flow.  Normally we don't know what to do because
    298     we don't know the value of the proxy, but a custom tracer can attach more
    299     information to the graph node using create_node and can choose to return a value.
    300     """
--> 301     raise TraceError('symbolically traced variables cannot be used as inputs to control flow')

TraceError: symbolically traced variables cannot be used as inputs to control flow
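
The last frame shows the generic torch.fx restriction: a traced Proxy cannot be coerced to bool, so any data-dependent `if` inside a model's forward breaks tracing. A minimal standalone reproduction of the same TraceError, independent of transformers (a hypothetical toy function, not from the issue):

import torch
import torch.fx

def f(x):
    # During symbolic tracing, `x.sum() > 0` is a Proxy; the `if` calls
    # bool() on it, which raises torch.fx.proxy.TraceError:
    # "symbolically traced variables cannot be used as inputs to control flow"
    if x.sum() > 0:
        return x + 1
    return x - 1

torch.fx.symbolic_trace(f)  # raises TraceError

In the Llama case, `HFProxy.__bool__` would normally return recorded metadata, but the traceback shows the metadata is absent for the `query_states.device.type == "cuda" and causal_mask is not None` check, so it falls through to the failing `Proxy.__bool__`.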

Expected behavior

Tracing should succeed without error, as it does with transformers v4.38.2.
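
A possible interim workaround (an untested sketch; the assumption is that only the LlamaSdpaAttention CUDA branch trips the tracer) is to load the model with the eager attention implementation via the standard `attn_implementation` argument, so the SDPA path is never entered:

from transformers import AutoModelForCausalLM
from transformers.utils.fx import symbolic_trace

# Assumption: the eager attention path has no device-dependent `if` on a
# proxy, so symbolic tracing should not hit Proxy.__bool__.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="eager",
)
gm = symbolic_trace(model)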
