HFTracer.trace can now take callables and torch.nn.Module #18457
michaelbenayoun merged 4 commits into huggingface:main from
Conversation
… and torch.nn.Module in general
thomasw21
left a comment
Nice! Some small comments!
src/transformers/utils/fx.py
Outdated
        The dummy inputs needed to handle data-dependent control-flow if `root` is not a
        [`~transformers.PreTrainedModel`]. It can also be used when `root` is a
        [`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
    infer_concrete_args_from_dummy_inputs (`bool`, *optional*, defaults to `True`):
- infer_concrete_args_from_dummy_inputs (`bool`, *optional*, defaults to `True`):
+ infer_concrete_args_from_dummy_inputs (`bool`, defaults to `True`):
No, the argument is optional since it has a default (the user does not have to provide it). Please read the writing documentation guide.
I was unaware of this. I mistakenly thought that `bool, *optional*` meant `Optional[bool]` in typing nomenclature. This seems weird to me, as all arguments with defaults now become optional? Anyway, perhaps a chat we can have someplace else.
All arguments that have a default are optional, yes. That is the definition of an optional argument: an argument you do not need to provide. The fact that the typing module decided to (badly) reuse that word for something else does not change that.
sig = inspect.signature(root.forward)
if dummy_inputs is not None and infer_concrete_args_from_dummy_inputs:
    concrete_args.update({p.name: p.default for p in sig.parameters.values() if p.name not in dummy_inputs})
What happens if the parameter doesn't have a default?
Then tracing will most likely fail afterwards; I added a check to fail early, as you suggested.
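For reference, a minimal sketch of what such an early check could look like, extending the snippet quoted above (the surrounding names come from that snippet; the exact error message in the actual PR may differ):

```python
import inspect

sig = inspect.signature(root.forward)
if dummy_inputs is not None and infer_concrete_args_from_dummy_inputs:
    # Any forward parameter not covered by dummy_inputs must have a default,
    # otherwise it cannot be turned into a concrete arg: fail early here
    # instead of letting tracing fail later with an obscure error.
    missing = [
        p.name
        for p in sig.parameters.values()
        if p.name not in dummy_inputs and p.default is inspect.Parameter.empty
    ]
    if missing:
        raise ValueError(
            f"Cannot infer concrete args: no dummy input and no default value for {missing}"
        )
    concrete_args.update(
        {p.name: p.default for p in sig.parameters.values() if p.name not in dummy_inputs}
    )
```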
sgugger
left a comment
Thanks for adding this!
What does this PR do?
This PR enables using the `HFTracer` "meta-tracing" features to trace any Python callable / `torch.nn.Module`.

For `transformers.PreTrainedModel`s, the method `HFTracer._generate_dummy_inputs` already takes care of creating the original dummy inputs needed to handle data-dependent control-flow in the forward pass. Now, the user can specify `dummy_inputs` directly to the `HFTracer.trace` method in order to be able to trace other things than `transformers.PreTrainedModel`s. This is useful for pattern matching, for instance.

This becomes possible:
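For instance, something along these lines (a minimal sketch; the toy `PatternModule` and its shapes are made up for illustration, and the call uses the new `dummy_inputs` argument described in this PR):

```python
import torch
from transformers.utils.fx import HFTracer


class PatternModule(torch.nn.Module):
    # A plain nn.Module, not a transformers.PreTrainedModel.
    def forward(self, x, y):
        # Shape-dependent control-flow: the tracer can resolve this branch
        # because the dummy inputs give `x` a concrete shape.
        if x.shape[-1] > 1:
            return x + y
        return x - y


module = PatternModule()
tracer = HFTracer()
graph = tracer.trace(
    module,
    dummy_inputs={"x": torch.ones(2, 4), "y": torch.ones(2, 4)},
)
traced = torch.fx.GraphModule(module, graph)
```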
By default, if `dummy_inputs` is specified, every argument to `root` that is not in `dummy_inputs` will be considered a concrete arg (and thus added to `concrete_args`). You can disable that by setting `infer_concrete_args_from_dummy_inputs` to `False`. This is useful if you want to provide custom dummy inputs for some of the inputs while still letting `HFTracer._generate_dummy_inputs` do the work for the other inputs (provided that `root` is a `transformers.PreTrainedModel`, since only that case supports automatic dummy input generation). A sketch of this second use case follows below.
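A sketch of that second use case, assuming a standard BERT checkpoint (the checkpoint name and shapes are illustrative, not taken from the PR):

```python
import torch
from transformers import BertModel
from transformers.utils.fx import HFTracer

model = BertModel.from_pretrained("bert-base-uncased")
tracer = HFTracer()
# Custom dummy input for `input_ids` only. With
# `infer_concrete_args_from_dummy_inputs=False`, the remaining forward
# arguments are not turned into concrete args, so the automatic generation in
# `HFTracer._generate_dummy_inputs` can still cover them (possible here
# because `root` is a PreTrainedModel).
graph = tracer.trace(
    model,
    dummy_inputs={"input_ids": torch.zeros(1, 8, dtype=torch.long)},
    infer_concrete_args_from_dummy_inputs=False,
)
```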