Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/transformers/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def main():
parser.add_argument(
        "--atol", type=float, default=None, help="Absolute difference tolerance when validating the model."
)
parser.add_argument(
"--framework",
type=str,
default="pt",
help="The framework of the model weights, either `pt` for PyTorch or `tf` for TensorFlow. The specified framework must be installed.",
)
parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")

# Retrieve CLI arguments
Expand All @@ -48,7 +54,7 @@ def main():

# Allocate the model
tokenizer = AutoTokenizer.from_pretrained(args.model)
model = FeaturesManager.get_model_from_feature(args.feature, args.model)
model = FeaturesManager.get_model_from_feature(args.feature, args.model, args.framework)
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
onnx_config = model_onnx_config(model.config)

Expand Down
54 changes: 46 additions & 8 deletions src/transformers/onnx/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

if is_torch_available():
if is_torch_available() and not is_tf_available():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This extra condition is used to check if we're in a pure torch environment

from transformers.models.auto import (
AutoModel,
AutoModelForCausalLM,
Expand All @@ -34,7 +34,7 @@
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
)
elif is_tf_available():
elif is_tf_available() and not is_torch_available():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the whole logic of having three tests can be simplified if you just change that elif to a simple if.

from transformers.models.auto import (
TFAutoModel,
TFAutoModelForCausalLM,
Expand All @@ -45,6 +45,25 @@
TFAutoModelForSequenceClassification,
TFAutoModelForTokenClassification,
)
elif is_torch_available() and is_tf_available():
from transformers.models.auto import (
AutoModel,
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
TFAutoModel,
TFAutoModelForCausalLM,
TFAutoModelForMaskedLM,
TFAutoModelForMultipleChoice,
TFAutoModelForQuestionAnswering,
TFAutoModelForSeq2SeqLM,
TFAutoModelForSequenceClassification,
TFAutoModelForTokenClassification,
)
else:
logger.warning(
"The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models without one of these libraries installed."
Expand Down Expand Up @@ -79,7 +98,7 @@ def supported_features_mapping(


class FeaturesManager:
if is_torch_available():
if is_torch_available() and not is_tf_available():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a bit of duplicate logic in this module - perhaps the autoclass imports above should moved directly within FeaturesManager?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, instead of having three tests, why not always have _TASKS_TO_AUTOMODELS be a nested dict with frameworks, and you then fill the frameworks when each framework if available?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a nice idea - thanks! In the end we may not need this if we adopt solution 2 :)

_TASKS_TO_AUTOMODELS = {
"default": AutoModel,
"masked-lm": AutoModelForMaskedLM,
Expand All @@ -90,7 +109,7 @@ class FeaturesManager:
"multiple-choice": AutoModelForMultipleChoice,
"question-answering": AutoModelForQuestionAnswering,
}
elif is_tf_available():
elif is_tf_available() and not is_torch_available():
_TASKS_TO_AUTOMODELS = {
"default": TFAutoModel,
"masked-lm": TFAutoModelForMaskedLM,
Expand All @@ -101,6 +120,20 @@ class FeaturesManager:
"multiple-choice": TFAutoModelForMultipleChoice,
"question-answering": TFAutoModelForQuestionAnswering,
}
elif is_tf_available() and is_torch_available():
_TASKS_TO_AUTOMODELS = {
"default": {"tf": TFAutoModel, "pt": AutoModel},
"masked-lm": {"tf": TFAutoModelForMaskedLM, "pt": AutoModelForMaskedLM},
"causal-lm": {"tf": TFAutoModelForCausalLM, "pt": AutoModelForCausalLM},
"seq2seq-lm": {"tf": TFAutoModelForSeq2SeqLM, "pt": AutoModelForSeq2SeqLM},
"sequence-classification": {
"tf": TFAutoModelForSequenceClassification,
"pt": AutoModelForSequenceClassification,
},
"token-classification": {"tf": TFAutoModelForTokenClassification, "pt": AutoModelForTokenClassification},
"multiple-choice": {"tf": TFAutoModelForMultipleChoice, "pt": AutoModelForMultipleChoice},
"question-answering": {"tf": TFAutoModelForQuestionAnswering, "pt": AutoModelForQuestionAnswering},
}
else:
_TASKS_TO_AUTOMODELS = {}

Expand Down Expand Up @@ -273,7 +306,7 @@ def feature_to_task(feature: str) -> str:
return feature.replace("-with-past", "")

@staticmethod
def get_model_class_for_feature(feature: str) -> Type:
def get_model_class_for_feature(feature: str, framework: str = None) -> Type:
"""
Attempt to retrieve an AutoModel class from a feature name.

Expand All @@ -289,9 +322,14 @@ def get_model_class_for_feature(feature: str) -> Type:
f"Unknown task: {feature}. "
f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
)
return FeaturesManager._TASKS_TO_AUTOMODELS[task]
if framework:
return FeaturesManager._TASKS_TO_AUTOMODELS[task][framework]
else:
return FeaturesManager._TASKS_TO_AUTOMODELS[task]

def get_model_from_feature(feature: str, model: str) -> Union[PreTrainedModel, TFPreTrainedModel]:
def get_model_from_feature(
feature: str, model: str, framework: str = None
) -> Union[PreTrainedModel, TFPreTrainedModel]:
"""
Attempt to retrieve a model from a model's name and the feature to be enabled.

Expand All @@ -303,7 +341,7 @@ def get_model_from_feature(feature: str, model: str) -> Union[PreTrainedModel, T
The instance of the model.

"""
model_class = FeaturesManager.get_model_class_for_feature(feature)
model_class = FeaturesManager.get_model_class_for_feature(feature, framework)
return model_class.from_pretrained(model)

@staticmethod
Expand Down