diff --git a/docs/source/serialization.mdx b/docs/source/serialization.mdx
index b985fabaa2c2..8e540ca024ea 100644
--- a/docs/source/serialization.mdx
+++ b/docs/source/serialization.mdx
@@ -61,10 +61,6 @@ Ready-made configurations include the following architectures:
 - XLM-RoBERTa
 - XLM-RoBERTa-XL
 
-The ONNX conversion is supported for the PyTorch versions of the models. If you
-would like to be able to convert a TensorFlow model, please let us know by
-opening an issue.
-
 In the next two sections, we'll show you how to:
 
 * Export a supported model using the `transformers.onnx` package.
@@ -149,6 +145,8 @@ DistilBERT we have:
 ["last_hidden_state"]
 ```
 
+The approach is similar for TensorFlow models.
+
 ### Selecting features for different model topologies
 
 Each ready-made configuration comes with a set of _features_ that enable you to
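The new documentation sentence above leans on the Python export API that the rest of this patch generalizes. A rough sketch of the TensorFlow flow it alludes to, assuming the checkpoint name, ONNX config class and output path below are only illustrative placeholders and not part of the patch:

```python
from pathlib import Path

from transformers import AutoTokenizer, TFAutoModel
from transformers.models.distilbert import DistilBertOnnxConfig
from transformers.onnx import export

# Placeholder checkpoint with TensorFlow weights, plus its tokenizer.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = TFAutoModel.from_pretrained("distilbert-base-uncased")

# Ready-made ONNX config for this architecture and a file to write the result to.
onnx_config = DistilBertOnnxConfig(model.config)
onnx_path = Path("model.onnx")

# export() now accepts TF models as well and routes them to export_tensorflow().
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
```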
+ """ + if issubclass(type(model), PreTrainedModel): + import torch + from torch.onnx import export as onnx_export + + logger.info(f"Using framework PyTorch: {torch.__version__}") + with torch.no_grad(): + model.config.return_dict = True + model.eval() + + # Check if we need to override certain configuration item + if config.values_override is not None: + logger.info(f"Overriding {len(config.values_override)} configuration item(s)") + for override_config_key, override_config_value in config.values_override.items(): + logger.info(f"\t- {override_config_key} -> {override_config_value}") + setattr(model.config, override_config_key, override_config_value) + + # Ensure inputs match + # TODO: Check when exporting QA we provide "is_pair=True" + model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH) + inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys()) + onnx_outputs = list(config.outputs.keys()) + + if not inputs_match: + raise ValueError("Model and config inputs doesn't match") + + config.patch_ops() + + # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11, + # so we check the torch version for backwards compatibility + if parse(torch.__version__) <= parse("1.10.99"): + # export can work with named args but the dict containing named args + # has to be the last element of the args tuple. + onnx_export( + model, + (model_inputs,), + f=output.as_posix(), + input_names=list(config.inputs.keys()), + output_names=onnx_outputs, + dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())}, + do_constant_folding=True, + use_external_data_format=config.use_external_data_format(model.num_parameters()), + enable_onnx_checker=True, + opset_version=opset, + ) + else: + onnx_export( + model, + (model_inputs,), + f=output.as_posix(), + input_names=list(config.inputs.keys()), + output_names=onnx_outputs, + dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())}, + do_constant_folding=True, + opset_version=opset, + ) + + config.restore_ops() + + return matched_inputs, onnx_outputs + +def export_tensorflow( + tokenizer: PreTrainedTokenizer, + model: TFPreTrainedModel, + config: OnnxConfig, + opset: int, + output: Path, +) -> Tuple[List[str], List[str]]: """ - if not is_torch_available(): - raise ImportError("Cannot convert because PyTorch is not installed. 
-        raise ImportError("Cannot convert because PyTorch is not installed. Please install torch first.")
-
-    import torch
-    from torch.onnx import export
-
-    from ..file_utils import torch_version
-
-    if not is_torch_onnx_dict_inputs_support_available():
-        raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
-
-    logger.info(f"Using framework PyTorch: {torch.__version__}")
-    with torch.no_grad():
-        model.config.return_dict = True
-        model.eval()
-
-        # Check if we need to override certain configuration item
-        if config.values_override is not None:
-            logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
-            for override_config_key, override_config_value in config.values_override.items():
-                logger.info(f"\t- {override_config_key} -> {override_config_value}")
-                setattr(model.config, override_config_key, override_config_value)
-
-        # Ensure inputs match
-        # TODO: Check when exporting QA we provide "is_pair=True"
-        model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
-        inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
-        onnx_outputs = list(config.outputs.keys())
-
-        if not inputs_match:
-            raise ValueError("Model and config inputs doesn't match")
-
-        config.patch_ops()
-
-        # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
-        # so we check the torch version for backwards compatibility
-        if parse(torch.__version__) <= parse("1.10.99"):
-            # export can work with named args but the dict containing named args
-            # has to be the last element of the args tuple.
-            export(
-                model,
-                (model_inputs,),
-                f=output.as_posix(),
-                input_names=list(config.inputs.keys()),
-                output_names=onnx_outputs,
-                dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
-                do_constant_folding=True,
-                use_external_data_format=config.use_external_data_format(model.num_parameters()),
-                enable_onnx_checker=True,
-                opset_version=opset,
-            )
-        else:
-            export(
-                model,
-                (model_inputs,),
-                f=output.as_posix(),
-                input_names=list(config.inputs.keys()),
-                output_names=onnx_outputs,
-                dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
-                do_constant_folding=True,
-                opset_version=opset,
-            )
+    Export a TensorFlow model to an ONNX Intermediate Representation (IR)
+
+    Args:
+        tokenizer ([`PreTrainedTokenizer`]):
+            The tokenizer used for encoding the data.
+        model ([`TFPreTrainedModel`]):
+            The model to export.
+        config ([`~onnx.config.OnnxConfig`]):
+            The ONNX configuration associated with the exported model.
+        opset (`int`):
+            The version of the ONNX operator set to use.
+        output (`Path`):
+            Directory to store the exported ONNX model.
+
+    Returns:
+        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named outputs from
+        the ONNX configuration.
+ """ + import tensorflow as tf + + import onnx + import tf2onnx + + model.config.return_dict = True + + # Check if we need to override certain configuration item + if config.values_override is not None: + logger.info(f"Overriding {len(config.values_override)} configuration item(s)") + for override_config_key, override_config_value in config.values_override.items(): + logger.info(f"\t- {override_config_key} -> {override_config_value}") + setattr(model.config, override_config_key, override_config_value) - config.restore_ops() + # Ensure inputs match + model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW) + inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys()) + onnx_outputs = list(config.outputs.keys()) + + input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in model_inputs.items()] + onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=opset) + onnx.save(onnx_model, output.as_posix()) + config.restore_ops() return matched_inputs, onnx_outputs +def export( + tokenizer: PreTrainedTokenizer, + model: Union[PreTrainedModel, TFPreTrainedModel], + config: OnnxConfig, + opset: int, + output: Path, +) -> Tuple[List[str], List[str]]: + """ + Export a Pytorch or TensorFlow model to an ONNX Intermediate Representation (IR) + + Args: + tokenizer ([`PreTrainedTokenizer`]): + The tokenizer used for encoding the data. + model ([`PreTrainedModel`] or [`TFPreTrainedModel`]): + The model to export. + config ([`~onnx.config.OnnxConfig`]): + The ONNX configuration associated with the exported model. + opset (`int`): + The version of the ONNX operator set to use. + output (`Path`): + Directory to store the exported ONNX model. + + Returns: + `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from + the ONNX configuration. + """ + if not (is_torch_available() or is_tf_available()): + raise ImportError( + "Cannot convert because neither PyTorch nor TensorFlow are not installed. " + "Please install torch or tensorflow first." + ) + + if is_torch_available(): + from transformers.file_utils import torch_version + + if not is_torch_onnx_dict_inputs_support_available(): + raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}") + + if is_torch_available() and issubclass(type(model), PreTrainedModel): + return export_pytorch(tokenizer, model, config, opset, output) + elif is_tf_available() and issubclass(type(model), TFPreTrainedModel): + return export_tensorflow(tokenizer, model, config, opset, output) + + def validate_model_outputs( config: OnnxConfig, tokenizer: PreTrainedTokenizer, @@ -160,7 +260,10 @@ def validate_model_outputs( # TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test # dynamic input shapes. 
@@ -160,7 +260,10 @@ def validate_model_outputs(
 
     # TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
     # dynamic input shapes.
-    reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
+    if issubclass(type(reference_model), PreTrainedModel):
+        reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
+    else:
+        reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)
 
     # Create ONNX Runtime session
     options = SessionOptions()
@@ -210,7 +313,10 @@ def validate_model_outputs(
 
     # Check the shape and values match
     for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
-        ref_value = ref_outputs_dict[name].detach().numpy()
+        if issubclass(type(reference_model), PreTrainedModel):
+            ref_value = ref_outputs_dict[name].detach().numpy()
+        else:
+            ref_value = ref_outputs_dict[name].numpy()
         logger.info(f'\t- Validating ONNX Model output "{name}":')
 
         # Shape
@@ -241,7 +347,10 @@ def ensure_model_and_config_inputs_match(
     :param model_inputs:
     :param config_inputs:
     :return:
     """
-    forward_parameters = signature(model.forward).parameters
+    if issubclass(type(model), PreTrainedModel):
+        forward_parameters = signature(model.forward).parameters
+    else:
+        forward_parameters = signature(model.call).parameters
     model_inputs_set = set(model_inputs)
     # We are fine if config_inputs has more keys than model_inputs
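With the hunks above, `validate_model_outputs` builds its reference inputs in whichever framework the reference model uses and only calls `.detach()` on PyTorch tensors, so the same validation call should also cover a TensorFlow export. Continuing the illustrative DistilBERT sketch from earlier (variable names carried over from that sketch):

```python
from transformers.onnx import validate_model_outputs

# Compare ONNX Runtime outputs against the reference TensorFlow model, within the
# tolerance suggested by the ONNX config.
validate_model_outputs(
    onnx_config,
    tokenizer,
    model,  # the TFAutoModel instance that was exported
    onnx_path,
    onnx_outputs,  # output names returned by export()
    onnx_config.atol_for_validation,
)
```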
 
diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py
index 1020387592a2..b2f4900b9d20 100644
--- a/src/transformers/onnx/features.py
+++ b/src/transformers/onnx/features.py
@@ -1,7 +1,7 @@
 from functools import partial, reduce
-from typing import Callable, Dict, Optional, Tuple, Type
+from typing import Callable, Dict, Optional, Tuple, Type, Union
 
-from .. import PretrainedConfig, is_torch_available
+from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available
 from ..models.albert import AlbertOnnxConfig
 from ..models.bart import BartOnnxConfig
 from ..models.bert import BertOnnxConfig
@@ -23,7 +23,6 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_torch_available():
-    from transformers import PreTrainedModel
     from transformers.models.auto import (
         AutoModel,
         AutoModelForCausalLM,
@@ -34,9 +33,20 @@
         AutoModelForSequenceClassification,
         AutoModelForTokenClassification,
     )
+elif is_tf_available():
+    from transformers.models.auto import (
+        TFAutoModel,
+        TFAutoModelForCausalLM,
+        TFAutoModelForMaskedLM,
+        TFAutoModelForMultipleChoice,
+        TFAutoModelForQuestionAnswering,
+        TFAutoModelForSeq2SeqLM,
+        TFAutoModelForSequenceClassification,
+        TFAutoModelForTokenClassification,
+    )
 else:
     logger.warning(
-        "The ONNX export features are only supported for PyTorch, you will not be able to export models without it."
+        "The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models without one of these libraries installed."
     )
 
 
@@ -79,6 +89,17 @@ class FeaturesManager:
             "multiple-choice": AutoModelForMultipleChoice,
             "question-answering": AutoModelForQuestionAnswering,
         }
+    elif is_tf_available():
+        _TASKS_TO_AUTOMODELS = {
+            "default": TFAutoModel,
+            "masked-lm": TFAutoModelForMaskedLM,
+            "causal-lm": TFAutoModelForCausalLM,
+            "seq2seq-lm": TFAutoModelForSeq2SeqLM,
+            "sequence-classification": TFAutoModelForSequenceClassification,
+            "token-classification": TFAutoModelForTokenClassification,
+            "multiple-choice": TFAutoModelForMultipleChoice,
+            "question-answering": TFAutoModelForQuestionAnswering,
+        }
     else:
         _TASKS_TO_AUTOMODELS = {}
 
@@ -260,7 +281,7 @@ def get_model_class_for_feature(feature: str) -> Type:
             )
         return FeaturesManager._TASKS_TO_AUTOMODELS[task]
 
-    def get_model_from_feature(feature: str, model: str) -> PreTrainedModel:
+    def get_model_from_feature(feature: str, model: str) -> Union[PreTrainedModel, TFPreTrainedModel]:
         """
         Attempt to retrieve a model from a model's name and the feature to be enabled.
 
@@ -276,7 +297,9 @@ def get_model_from_feature(feature: str, model: str) -> PreTrainedModel:
         return model_class.from_pretrained(model)
 
     @staticmethod
-    def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> Tuple[str, Callable]:
+    def check_supported_model_or_raise(
+        model: Union[PreTrainedModel, TFPreTrainedModel], feature: str = "default"
+    ) -> Tuple[str, Callable]:
         """
         Check whether or not the model has the requested features.
 
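Because `_TASKS_TO_AUTOMODELS` now falls back to the TF auto-classes when only TensorFlow is installed, the feature lookup used by the export code resolves to TensorFlow models without further changes. Roughly, under that assumption (the checkpoint name below is again a placeholder):

```python
from transformers.onnx.features import FeaturesManager

# With only TensorFlow installed, this resolves to TFAutoModelForSequenceClassification.
model_class = FeaturesManager.get_model_class_for_feature("sequence-classification")

# Load a checkpoint through that auto-class (placeholder model name).
model = FeaturesManager.get_model_from_feature("sequence-classification", "distilbert-base-uncased")

# Verify the feature is supported for this architecture and get the ONNX config constructor.
model_kind, onnx_config_factory = FeaturesManager.check_supported_model_or_raise(
    model, feature="sequence-classification"
)
onnx_config = onnx_config_factory(model.config)
```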
+ if is_torch_available() or is_tf_available(): + for (name, model) in export_models_list: + for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type( + name + ).items(): + models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor)) + return sorted(models_to_test) + else: + # Returning some dummy test that should not be ever called because of the @require_torch / @require_tf + # decorators. # The reason for not returning an empty list is because parameterized.expand complains when it's empty. return [("dummy", "dummy", "dummy", "dummy", OnnxConfig.from_model_config)] - for (name, model) in export_models_list: - for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type( - name - ).items(): - models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor)) - return sorted(models_to_test) class OnnxExportTestCaseV2(TestCase): @@ -211,7 +236,7 @@ class OnnxExportTestCaseV2(TestCase): Integration tests ensuring supported models are correctly exported """ - def _pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor): + def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor): from transformers.onnx import export tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -245,13 +270,13 @@ def _pytorch_export(self, test_name, name, model_name, feature, onnx_config_clas @slow @require_torch def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor): - self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor) + self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor) @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS)) @slow @require_torch def test_pytorch_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor): - self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor) + self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor) @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS)) @slow @@ -259,4 +284,24 @@ def test_pytorch_export_with_past(self, test_name, name, model_name, feature, on def test_pytorch_export_seq2seq_with_past( self, test_name, name, model_name, feature, onnx_config_class_constructor ): - self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor) + self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor) + + @parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_DEFAULT_MODELS)) + @slow + @require_tf + def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor): + self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor) + + @parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS)) + @slow + @require_tf + def test_tensorflow_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor): + self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor) + + @parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS)) + @slow + @require_tf + def test_tensorflow_export_seq2seq_with_past( + self, test_name, name, model_name, feature, onnx_config_class_constructor + ): + 
diff --git a/utils/check_table.py b/utils/check_table.py
index 449ad02c3b21..9d948fbb6d9f 100644
--- a/utils/check_table.py
+++ b/utils/check_table.py
@@ -211,7 +211,7 @@ def check_onnx_model_list(overwrite=False):
     current_list, start_index, end_index, lines = _find_text_in_file(
         filename=os.path.join(PATH_TO_DOCS, "serialization.mdx"),
         start_prompt="",
-        end_prompt="The ONNX conversion is supported for the PyTorch versions of the models.",
+        end_prompt="In the next two sections, we'll show you how to:",
     )
 
     new_list = get_onnx_model_list()