[RFC] Add framework argument to ONNX export #15620
Conversation
```diff
  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

- if is_torch_available():
+ if is_torch_available() and not is_tf_available():
```
This extra condition is used to check if we're in a pure torch environment
```diff
  class FeaturesManager:
-     if is_torch_available():
+     if is_torch_available() and not is_tf_available():
```
There's a bit of duplicate logic in this module - perhaps the autoclass imports above should be moved directly within `FeaturesManager`?
sgugger left a comment
I personally think solution 2 would be better for the user, as it "just works". We can investigate the provenance of the error log and try to remove it if it's an issue, but it would be better than adding a new arg :-)
```diff
      AutoModelForTokenClassification,
  )
- elif is_tf_available():
+ elif is_tf_available() and not is_torch_available():
```
I think the whole logic of having three tests can be simplified if you just change that elif to a simple if.
```diff
  class FeaturesManager:
-     if is_torch_available():
+     if is_torch_available() and not is_tf_available():
```
Same here: instead of having three tests, why not always have `_TASKS_TO_AUTOMODELS` be a nested dict keyed by framework, and then fill in each framework when it is available?
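For illustration, a minimal sketch of what that nested mapping could look like (the `"pt"`/`"tf"` keys and the single `"default"` task are assumptions; the real table covers many more tasks):

```python
from transformers import AutoModel, TFAutoModel, is_tf_available, is_torch_available

# Keyed first by framework, filled in only for the frameworks that are installed.
_TASKS_TO_AUTOMODELS = {}
if is_torch_available():
    _TASKS_TO_AUTOMODELS["pt"] = {"default": AutoModel}    # plus the other tasks
if is_tf_available():
    _TASKS_TO_AUTOMODELS["tf"] = {"default": TFAutoModel}  # plus the other tasks
```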
That's a nice idea - thanks! In the end we may not need this if we adopt solution 2 :)
Thanks for the feedback @sgugger ❤️ ! Having thought about it a bit more, I agree that solution 2 is the simplest and least error-prone: I've opened a PR for this here #15625
What does this PR do?
This PR addresses an edge case introduced by #13831 where the ONNX export fails if:
- `torch` and `tensorflow` are installed in the same environment

Here is an example that fails to export on the `master` branch (see the sketch below):

Traceback
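A sketch of the failing invocation, assuming the TensorFlow-only `keras-io/transformers-qa` checkpoint referenced later in this thread:

```bash
# keras-io/transformers-qa only ships TensorFlow weights, but on master the
# export resolves to the PyTorch AutoModel class and fails to find them.
python -m transformers.onnx --model=keras-io/transformers-qa onnx/
```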
The reason this fails is that the `FeaturesManager.get_model_class_for_feature()` method uses the `_TASKS_TO_AUTOMODELS` mapping to determine which autoclass (e.g. `AutoModel` vs `TFAutoModel`) to return for a given task. This mapping relies on the following branching logic:
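A sketch of that branching, based on the diff hunks reviewed above (only one task shown; the real mapping covers every supported task):

```python
from transformers import AutoModel, TFAutoModel, is_tf_available, is_torch_available

# Only one branch ever runs, so when both frameworks are installed the
# PyTorch autoclasses always win.
if is_torch_available():
    _TASKS_TO_AUTOMODELS = {"default": AutoModel}    # plus the other tasks
elif is_tf_available():
    _TASKS_TO_AUTOMODELS = {"default": TFAutoModel}  # plus the other tasks
else:
    _TASKS_TO_AUTOMODELS = {}
```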
As a result, if a user has `torch` and `tensorflow` installed, we return an `AutoModel` class instead of the desired `TFAutoModel` class. In particular, Colab users cannot export pure TensorFlow models because `torch` is installed by default.

Proposal
To address this issue, I've introduced a new `framework` argument in the ONNX CLI and extended `_TASKS_TO_AUTOMODELS` to be a nested `dict` when both frameworks are installed. With this change, one can now export pure TensorFlow models, pure PyTorch models, and checkpoints with both sets of weights (see the sketch below).
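A minimal sketch of the intended usage, assuming the new argument is exposed as `--framework` with `pt`/`tf` values (the exact flag spelling and the `distilbert-base-uncased` checkpoint are illustrative assumptions; `keras-io/transformers-qa` is the TensorFlow-only checkpoint discussed in this thread):

```bash
# Pure TensorFlow checkpoint: force the TF autoclasses
python -m transformers.onnx --model=keras-io/transformers-qa --framework=tf onnx/

# Pure PyTorch checkpoint
python -m transformers.onnx --model=distilbert-base-uncased --framework=pt onnx/

# Checkpoint with both sets of weights: pick whichever framework you prefer
python -m transformers.onnx --model=distilbert-base-uncased --framework=tf onnx/
```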
Although the implementation works, I'm not entirely happy with it because `_TASKS_TO_AUTOMODELS` changes (flat vs nested) depending on the installation environment, and this feels hacky.

Alternative solution 1
Thanks to a tip from @stas00, one solution is to change nothing and get the user to specify which framework they're using as an environment variable, e.g.
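A sketch of the idea, assuming the `USE_TF`/`USE_TORCH` environment variables that `transformers` already consults when deciding framework availability (whether they interact with the ONNX export exactly like this is an assumption):

```bash
# Pretend torch is not installed so the TensorFlow autoclasses are selected
USE_TF=1 USE_TORCH=0 python -m transformers.onnx --model=keras-io/transformers-qa onnx/
```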
If we adopt this approach, we could provide a warning when both `torch` and `tensorflow` are installed and suggest an example like the one above.

Alternative solution 2
It occurred to me that we can solve this with a simple try/except in `FeaturesManager.get_model_from_feature()` as follows:
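A minimal sketch of the try/except idea (the exception type and the `from_tf=True` fallback are assumptions about how it could be wired up, not the final implementation):

```python
class FeaturesManager:
    # ... existing task/autoclass mappings ...

    @staticmethod
    def get_model_from_feature(feature: str, model: str):
        model_class = FeaturesManager.get_model_class_for_feature(feature)
        try:
            # Happy path: the checkpoint has weights for the resolved framework.
            return model_class.from_pretrained(model)
        except OSError:
            # e.g. no pytorch_model.bin in the repo: retry by loading the
            # TensorFlow weights into the PyTorch class.
            return model_class.from_pretrained(model, from_tf=True)
```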
The user will still see a 404 error in the logs:

```
python -m transformers.onnx --model=keras-io/transformers-qa onnx/
# 404 Client Error: Entry Not Found for url: https://huggingface.co/keras-io/transformers-qa/resolve/main/pytorch_model.bin
```

but the conversion to ONNX will work once the TensorFlow weights are loaded in the `AutoModel` instance.
Note: this solution seems to be similar to the one adopted in the `pipeline()` function, e.g. as sketched below.
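A rough sketch of the analogous `pipeline()` behaviour (the task, checkpoint, and exact fallback mechanics are assumptions for illustration):

```python
from transformers import pipeline

# With both frameworks installed, loading a TensorFlow-only checkpoint still
# works: the pipeline logs an error for the missing PyTorch weights and then
# falls back to the TensorFlow ones.
qa = pipeline("question-answering", model="keras-io/transformers-qa")
print(qa(question="Where do I live?", context="I live in Paris."))
```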
The advantage of this approach is that the user doesn't have to manually specify a `--framework` arg, i.e. it "just works". The only drawback I see is that there might be differences between the `torch.onnx` and `tf2onnx` packages used for the ONNX export, and by using `torch.onnx` as the default we may mislead users on where to debug their exports. However, this is probably a rare case and could be revisited if users report problems.

Feedback on which approach is preferred is much appreciated!