-
Notifications
You must be signed in to change notification settings - Fork 31.6k
[RFC] Add framework argument to ONNX export #15620
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,7 @@ | |
|
|
||
| logger = logging.get_logger(__name__) # pylint: disable=invalid-name | ||
|
|
||
| if is_torch_available(): | ||
| if is_torch_available() and not is_tf_available(): | ||
| from transformers.models.auto import ( | ||
| AutoModel, | ||
| AutoModelForCausalLM, | ||
|
|
@@ -34,7 +34,7 @@ | |
| AutoModelForSequenceClassification, | ||
| AutoModelForTokenClassification, | ||
| ) | ||
| elif is_tf_available(): | ||
| elif is_tf_available() and not is_torch_available(): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think the whole logic of having three tests can be simplified if you just change that |
||
| from transformers.models.auto import ( | ||
| TFAutoModel, | ||
| TFAutoModelForCausalLM, | ||
|
|
@@ -45,6 +45,25 @@ | |
| TFAutoModelForSequenceClassification, | ||
| TFAutoModelForTokenClassification, | ||
| ) | ||
| elif is_torch_available() and is_tf_available(): | ||
| from transformers.models.auto import ( | ||
| AutoModel, | ||
| AutoModelForCausalLM, | ||
| AutoModelForMaskedLM, | ||
| AutoModelForMultipleChoice, | ||
| AutoModelForQuestionAnswering, | ||
| AutoModelForSeq2SeqLM, | ||
| AutoModelForSequenceClassification, | ||
| AutoModelForTokenClassification, | ||
| TFAutoModel, | ||
| TFAutoModelForCausalLM, | ||
| TFAutoModelForMaskedLM, | ||
| TFAutoModelForMultipleChoice, | ||
| TFAutoModelForQuestionAnswering, | ||
| TFAutoModelForSeq2SeqLM, | ||
| TFAutoModelForSequenceClassification, | ||
| TFAutoModelForTokenClassification, | ||
| ) | ||
| else: | ||
| logger.warning( | ||
| "The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models without one of these libraries installed." | ||
|
|
@@ -79,7 +98,7 @@ def supported_features_mapping( | |
|
|
||
|
|
||
| class FeaturesManager: | ||
| if is_torch_available(): | ||
| if is_torch_available() and not is_tf_available(): | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There's a bit of duplicate logic in this module - perhaps the autoclass imports above should be moved directly within
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here, instead of having three tests, why not always have
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's a nice idea - thanks! In the end we may not need this if we adopt solution 2 :) |
||
| _TASKS_TO_AUTOMODELS = { | ||
| "default": AutoModel, | ||
| "masked-lm": AutoModelForMaskedLM, | ||
|
|
@@ -90,7 +109,7 @@ class FeaturesManager: | |
| "multiple-choice": AutoModelForMultipleChoice, | ||
| "question-answering": AutoModelForQuestionAnswering, | ||
| } | ||
| elif is_tf_available(): | ||
| elif is_tf_available() and not is_torch_available(): | ||
| _TASKS_TO_AUTOMODELS = { | ||
| "default": TFAutoModel, | ||
| "masked-lm": TFAutoModelForMaskedLM, | ||
|
|
@@ -101,6 +120,20 @@ class FeaturesManager: | |
| "multiple-choice": TFAutoModelForMultipleChoice, | ||
| "question-answering": TFAutoModelForQuestionAnswering, | ||
| } | ||
| elif is_tf_available() and is_torch_available(): | ||
| _TASKS_TO_AUTOMODELS = { | ||
| "default": {"tf": TFAutoModel, "pt": AutoModel}, | ||
| "masked-lm": {"tf": TFAutoModelForMaskedLM, "pt": AutoModelForMaskedLM}, | ||
| "causal-lm": {"tf": TFAutoModelForCausalLM, "pt": AutoModelForCausalLM}, | ||
| "seq2seq-lm": {"tf": TFAutoModelForSeq2SeqLM, "pt": AutoModelForSeq2SeqLM}, | ||
| "sequence-classification": { | ||
| "tf": TFAutoModelForSequenceClassification, | ||
| "pt": AutoModelForSequenceClassification, | ||
| }, | ||
| "token-classification": {"tf": TFAutoModelForTokenClassification, "pt": AutoModelForTokenClassification}, | ||
| "multiple-choice": {"tf": TFAutoModelForMultipleChoice, "pt": AutoModelForMultipleChoice}, | ||
| "question-answering": {"tf": TFAutoModelForQuestionAnswering, "pt": AutoModelForQuestionAnswering}, | ||
| } | ||
| else: | ||
| _TASKS_TO_AUTOMODELS = {} | ||
|
|
||
|
|
@@ -273,7 +306,7 @@ def feature_to_task(feature: str) -> str: | |
| return feature.replace("-with-past", "") | ||
|
|
||
| @staticmethod | ||
| def get_model_class_for_feature(feature: str) -> Type: | ||
| def get_model_class_for_feature(feature: str, framework: str = None) -> Type: | ||
| """ | ||
| Attempt to retrieve an AutoModel class from a feature name. | ||
|
|
||
|
|
@@ -289,9 +322,14 @@ def get_model_class_for_feature(feature: str) -> Type: | |
| f"Unknown task: {feature}. " | ||
| f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}" | ||
| ) | ||
| return FeaturesManager._TASKS_TO_AUTOMODELS[task] | ||
| if framework: | ||
| return FeaturesManager._TASKS_TO_AUTOMODELS[task][framework] | ||
| else: | ||
| return FeaturesManager._TASKS_TO_AUTOMODELS[task] | ||
|
|
||
| def get_model_from_feature(feature: str, model: str) -> Union[PreTrainedModel, TFPreTrainedModel]: | ||
| def get_model_from_feature( | ||
| feature: str, model: str, framework: str = None | ||
| ) -> Union[PreTrainedModel, TFPreTrainedModel]: | ||
| """ | ||
| Attempt to retrieve a model from a model's name and the feature to be enabled. | ||
|
|
||
|
|
@@ -303,7 +341,7 @@ def get_model_from_feature(feature: str, model: str) -> Union[PreTrainedModel, T | |
| The instance of the model. | ||
|
|
||
| """ | ||
| model_class = FeaturesManager.get_model_class_for_feature(feature) | ||
| model_class = FeaturesManager.get_model_class_for_feature(feature, framework) | ||
| return model_class.from_pretrained(model) | ||
|
|
||
| @staticmethod | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This extra condition is used to check if we're in a pure `torch` environment.