diff --git a/src/transformers/onnx/__main__.py b/src/transformers/onnx/__main__.py
index bb547172894b..dd1276c0042b 100644
--- a/src/transformers/onnx/__main__.py
+++ b/src/transformers/onnx/__main__.py
@@ -37,6 +37,12 @@ def main():
     parser.add_argument(
         "--atol", type=float, default=None, help="Absolute difference tolerence when validating the model."
     )
+    parser.add_argument(
+        "--framework",
+        type=str,
+        default="pt",
+        help="The framework of the model weights, either `pt` for PyTorch or `tf` for TensorFlow. The specified framework must be installed.",
+    )
     parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")
 
     # Retrieve CLI arguments
@@ -48,7 +54,7 @@ def main():
 
     # Allocate the model
     tokenizer = AutoTokenizer.from_pretrained(args.model)
-    model = FeaturesManager.get_model_from_feature(args.feature, args.model)
+    model = FeaturesManager.get_model_from_feature(args.feature, args.model, args.framework)
     model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
     onnx_config = model_onnx_config(model.config)
 
diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py
index d21d1d3072fb..2a67f10e7d93 100644
--- a/src/transformers/onnx/features.py
+++ b/src/transformers/onnx/features.py
@@ -23,7 +23,7 @@
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
-if is_torch_available():
+if is_torch_available() and not is_tf_available():
     from transformers.models.auto import (
         AutoModel,
         AutoModelForCausalLM,
@@ -34,7 +34,7 @@
         AutoModelForSequenceClassification,
         AutoModelForTokenClassification,
     )
-elif is_tf_available():
+elif is_tf_available() and not is_torch_available():
     from transformers.models.auto import (
         TFAutoModel,
         TFAutoModelForCausalLM,
@@ -45,6 +45,25 @@
         TFAutoModelForSequenceClassification,
         TFAutoModelForTokenClassification,
     )
+elif is_torch_available() and is_tf_available():
+    from transformers.models.auto import (
+        AutoModel,
+        AutoModelForCausalLM,
+        AutoModelForMaskedLM,
+        AutoModelForMultipleChoice,
+        AutoModelForQuestionAnswering,
+        AutoModelForSeq2SeqLM,
+        AutoModelForSequenceClassification,
+        AutoModelForTokenClassification,
+        TFAutoModel,
+        TFAutoModelForCausalLM,
+        TFAutoModelForMaskedLM,
+        TFAutoModelForMultipleChoice,
+        TFAutoModelForQuestionAnswering,
+        TFAutoModelForSeq2SeqLM,
+        TFAutoModelForSequenceClassification,
+        TFAutoModelForTokenClassification,
+    )
 else:
     logger.warning(
         "The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models without one of these libraries installed."
     )
@@ -79,7 +98,7 @@ def supported_features_mapping(
 
 
 class FeaturesManager:
-    if is_torch_available():
+    if is_torch_available() and not is_tf_available():
         _TASKS_TO_AUTOMODELS = {
             "default": AutoModel,
             "masked-lm": AutoModelForMaskedLM,
@@ -90,7 +109,7 @@ class FeaturesManager:
             "multiple-choice": AutoModelForMultipleChoice,
             "question-answering": AutoModelForQuestionAnswering,
         }
-    elif is_tf_available():
+    elif is_tf_available() and not is_torch_available():
         _TASKS_TO_AUTOMODELS = {
             "default": TFAutoModel,
             "masked-lm": TFAutoModelForMaskedLM,
@@ -101,6 +120,20 @@ class FeaturesManager:
             "multiple-choice": TFAutoModelForMultipleChoice,
             "question-answering": TFAutoModelForQuestionAnswering,
         }
+    elif is_tf_available() and is_torch_available():
+        _TASKS_TO_AUTOMODELS = {
+            "default": {"tf": TFAutoModel, "pt": AutoModel},
+            "masked-lm": {"tf": TFAutoModelForMaskedLM, "pt": AutoModelForMaskedLM},
+            "causal-lm": {"tf": TFAutoModelForCausalLM, "pt": AutoModelForCausalLM},
+            "seq2seq-lm": {"tf": TFAutoModelForSeq2SeqLM, "pt": AutoModelForSeq2SeqLM},
+            "sequence-classification": {
+                "tf": TFAutoModelForSequenceClassification,
+                "pt": AutoModelForSequenceClassification,
+            },
+            "token-classification": {"tf": TFAutoModelForTokenClassification, "pt": AutoModelForTokenClassification},
+            "multiple-choice": {"tf": TFAutoModelForMultipleChoice, "pt": AutoModelForMultipleChoice},
+            "question-answering": {"tf": TFAutoModelForQuestionAnswering, "pt": AutoModelForQuestionAnswering},
+        }
     else:
         _TASKS_TO_AUTOMODELS = {}
 
@@ -273,7 +306,7 @@ def feature_to_task(feature: str) -> str:
         return feature.replace("-with-past", "")
 
     @staticmethod
-    def get_model_class_for_feature(feature: str) -> Type:
+    def get_model_class_for_feature(feature: str, framework: str = None) -> Type:
         """
         Attempt to retrieve an AutoModel class from a feature name.
 
@@ -289,9 +322,14 @@ def get_model_class_for_feature(feature: str) -> Type:
                 f"Unknown task: {feature}. "
                 f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
             )
-        return FeaturesManager._TASKS_TO_AUTOMODELS[task]
+        if framework:
+            return FeaturesManager._TASKS_TO_AUTOMODELS[task][framework]
+        else:
+            return FeaturesManager._TASKS_TO_AUTOMODELS[task]
 
-    def get_model_from_feature(feature: str, model: str) -> Union[PreTrainedModel, TFPreTrainedModel]:
+    def get_model_from_feature(
+        feature: str, model: str, framework: str = None
+    ) -> Union[PreTrainedModel, TFPreTrainedModel]:
         """
         Attempt to retrieve a model from a model's name and the feature to be enabled.
 
@@ -303,7 +341,7 @@ def get_model_from_feature(feature: str, model: str) -> Union[PreTrainedModel, T
             The instance of the model.
 
         """
-        model_class = FeaturesManager.get_model_class_for_feature(feature)
+        model_class = FeaturesManager.get_model_class_for_feature(feature, framework)
        return model_class.from_pretrained(model)
 
     @staticmethod