Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 58 additions & 5 deletions src/transformers/utils/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,11 +882,51 @@ def call_module(self, m, forward, args, kwargs):
def proxy(self, node):
return HFProxy(node, self)

def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
def trace(
self,
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
dummy_inputs: Optional[Dict[str, Any]] = None,
complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
) -> Graph:
"""
Traces `root` and returns the corresponding FX `torch.fx.Graph` representation. `root` can either be a
`torch.nn.Module` instance or a Python callable. Note that after this call, `self.root` may be different from
the `root` passed in here. For example, when a free function is passed to `trace()`, we will create a
`torch.nn.Module` instance to use as the root and add embedded constants to.

Args:
root (`torch.nn.Module` or `Callable`):
Either a `torch.nn.Module`` or a function to be traced through. If root is not a
[`~transformers.PreTrainedModel`], then `dummy_inputs` must be passed, otherwise tracing will fail.
concrete_args (`Dict[str, Any], *optional*):
Concrete arguments that should not be treated as Proxies
dummy_inputs (`Dict[str, Any]`, *optional*):
The dummy inputs needed to handle data-dependent control-flow if `root` is not a
[`~transformers.PreTrainedModel`]. It can also be used when `root` is a
[`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
complete_concrete_args_with_inputs_not_in_dummy_inputs (`bool`, *optional*, defaults to `True`):
If `True`, and `dummy_inputs` is specified, every argument that `root` can take that is not in
`dummy_inputs` and not in `concrete_args` will be added to `concrete_args`, otherwise does nothing.

Returns:
`torch.fx.Graph`:
A FX `torch.fx.Graph` representing the semantics of the passed-in `root`.

"""
sig = inspect.signature(root.forward if isinstance(root, torch.nn.Module) else root)

if concrete_args is None:
concrete_args = {}

sig = inspect.signature(root.forward)
if dummy_inputs is not None and complete_concrete_args_with_inputs_not_in_dummy_inputs:
for param in sig.parameters.values():
if param.name in dummy_inputs:
continue
if param.default is inspect.Parameter.empty:
raise ValueError(f"You need to specify a default value for the parameter {param.name}.")
concrete_args.update({p.name: p.default for p in sig.parameters.values() if p.name not in dummy_inputs})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if the parameter doesn't have default?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then tracing will most likely fail afterwards, added a check to fail early, as you suggested.


input_names = sig.parameters.keys() - concrete_args.keys()

# Creating a random input shape to generate dummy inputs.
Expand All @@ -898,11 +938,24 @@ def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] =
num_choices = _generate_random_int(low=2, high=5)
shape.insert(1, num_choices)

inputs = {}
inputs = dict(dummy_inputs) if dummy_inputs is not None else {}
for input_name in input_names:
inputs.update(self._generate_dummy_input(root, input_name, shape))
if input_name in inputs:
continue
# We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
# be able to use HFTracer._generate_dummy_input.
if isinstance(root, PreTrainedModel) or type(root).__qualname__.startswith("_deserialize_graph_module"):
inputs.update(self._generate_dummy_input(root, input_name, shape))
else:
raise RuntimeError(
f"Could not generate input named {input_name} for because root is not a"
" transformers.PreTrainedModel."
)

concrete_metas = {input_name: input_.to("meta") for input_name, input_ in inputs.items()}
concrete_metas = {
input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_
for input_name, input_ in inputs.items()
}
for param in sig.parameters.values():
if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
concrete_metas[f"**{param.name}"] = {}
Expand Down