diff --git a/optimum/exporters/onnx/__init__.py b/optimum/exporters/onnx/__init__.py index 74cad88256..b9dadaa5e6 100644 --- a/optimum/exporters/onnx/__init__.py +++ b/optimum/exporters/onnx/__init__.py @@ -15,4 +15,9 @@ from .base import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast # noqa from .config import TextDecoderOnnxConfig, TextEncoderOnnxConfig, TextSeq2SeqOnnxConfig # noqa -from .convert import export, validate_model_outputs # noqa +from .convert import ( # noqa + export, + export_encoder_decoder_model, + validate_encoder_decoder_model_outputs, + validate_model_outputs, +) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 36c15fe94f..5b1c1d4239 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -22,7 +22,12 @@ from ...utils import logging from ..tasks import TasksManager from .base import OnnxConfigWithPast -from .convert import export, validate_model_outputs +from .convert import ( + export, + export_encoder_decoder_model, + validate_encoder_decoder_model_outputs, + validate_model_outputs, +) logger = logging.get_logger() # pylint: disable=invalid-name @@ -64,6 +69,14 @@ def main(): ), ) parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.") + parser.add_argument( + "--for-ort", + action="store_true", + help=( + "This exports models ready to be run with optimum.onnxruntime. Useful for encoder-decoder models for " + "conditional generation. If enabled, the encoder and decoder of the model are exported separately." + ), + ) parser.add_argument("output", type=Path, help="Path indicating the directory where to store generated ONNX model.") # Retrieve CLI arguments @@ -115,12 +128,17 @@ def main(): f"At least {onnx_config.DEFAULT_ONNX_OPSET} is required." ) - onnx_inputs, onnx_outputs = export( - model, - onnx_config, - args.opset, - args.output, - ) + if model.config.is_encoder_decoder and args.for_ort: + onnx_inputs, onnx_outputs = export_encoder_decoder_model( + model, + onnx_config, + args.opset, + args.output.parent.joinpath("encoder_model.onnx"), + args.output.parent.joinpath("decoder_model.onnx"), + args.output.parent.joinpath("decoder_with_past_model.onnx"), + ) + else: + onnx_inputs, onnx_outputs = export(model, onnx_config, args.opset, args.output) # Saving the model config as this is needed sometimes.
model.config.save_pretrained(args.output.parent) @@ -144,11 +162,22 @@ def main(): args.atol = args.atol[task.replace("-with-past", "")] try: - validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol) + if model.config.is_encoder_decoder and args.for_ort: + validate_encoder_decoder_model_outputs( + onnx_config, + model, + onnx_outputs, + args.atol, + args.output.parent.joinpath("encoder_model.onnx"), + args.output.parent.joinpath("decoder_model.onnx"), + args.output.parent.joinpath("decoder_with_past_model.onnx"), + ) + else: + validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol) except ValueError: - logger.error(f"An error occured, but the model was saved at: {args.output.as_posix()}") + logger.error(f"An error occurred, but the model was saved at: {args.output.parent.as_posix()}") return - logger.info(f"All good, model saved at: {args.output.as_posix()}") + logger.info(f"All good, model saved at: {args.output.parent.as_posix()}") if __name__ == "__main__": diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 802a856003..f69bad215b 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -303,6 +303,18 @@ def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) -> """ return {f"{name}.{idx}": item for idx, item in enumerate(itertools.chain.from_iterable(field))} + def generate_dummy_inputs_for_validation(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Generate inputs for ONNX Runtime using the reference model inputs. Override this to run inference with seq2seq + models which have the encoder and decoder exported as separate ONNX files. + Args: + reference_model_inputs (`Mapping[str, Tensor]`): + Reference inputs for the model. + Returns: + `Mapping[str, Tensor]`: The mapping holding the kwargs to provide to the model's forward function. + """ + return reference_model_inputs + class OnnxConfigWithPast(OnnxConfig, ABC): PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH = True @@ -454,3 +466,43 @@ def flatten_past_key_values(self, flattened_output, name, idx, t): flattened_output[f"{name}.{idx}.decoder.value"] = t[1] flattened_output[f"{name}.{idx}.encoder.key"] = t[2] flattened_output[f"{name}.{idx}.encoder.value"] = t[3] + + def get_encoder_onnx_config(self, config: "PretrainedConfig") -> OnnxConfig: + """ + Returns the ONNX encoder config for `Seq2Seq` models. Override this method to export the encoder + of the model separately. + + Args: + config (`PretrainedConfig`): + The encoder model's configuration to use when exporting to ONNX. + + Returns: + `OnnxConfig`: An instance of the ONNX configuration object. + """ + raise NotImplementedError( + f"{config.model_type} encoder export is not supported yet. ", + f"If you want to support {config.model_type} please propose a PR or open up an issue.", + ) + + def get_decoder_onnx_config( + self, config: "PretrainedConfig", task: str = "default", use_past: bool = False + ) -> OnnxConfig: + """ + Returns the ONNX decoder config for `Seq2Seq` models. Override this method to export the decoder + of the model separately. + + Args: + config (`PretrainedConfig`): + The decoder model's configuration to use when exporting to ONNX. + task (`str`, defaults to `"default"`): + The task the model should be exported for. + use_past (`bool`, defaults to `False`): + Whether to export the model with past_key_values. + + Returns: + `OnnxConfig`: An instance of the ONNX configuration object.
+ """ + raise NotImplementedError( + f"{config.model_type} decoder export is not supported yet. ", + f"If you want to support {config.model_type} please propose a PR or open up an issue.", + ) diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index 306370d17d..b01a946cd9 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -17,14 +17,14 @@ from inspect import signature from itertools import chain from pathlib import Path -from typing import Iterable, List, Tuple, Union +from typing import Iterable, List, Optional, Tuple, Union import numpy as np from transformers.utils import is_tf_available, is_torch_available from ...utils import logging from .base import OnnxConfig -from .utils import MIN_TORCH_VERSION, is_torch_onnx_support_available +from .utils import MIN_TORCH_VERSION, get_encoder_decoder_models_for_export, is_torch_onnx_support_available if is_torch_available(): @@ -61,6 +61,58 @@ def check_dummy_inputs_are_allowed( ) +def validate_encoder_decoder_model_outputs( + config: OnnxConfig, + reference_model: Union["PreTrainedModel", "TFPreTrainedModel"], + onnx_named_outputs: List[str], + atol: float, + encoder_onnx_model: Path, + decoder_onnx_model: Path, + decoder_with_past_onnx_model: Optional[Path] = None, +): + """ + Validates the export by checking that the outputs from both the reference and the exported model match. + The following method validates the ONNX models exported using the `export_encoder_decoder_model` method. + + Args: + config ([`~OnnxConfig`]: + The configuration used to export the model. + reference_model ([`~PreTrainedModel`] or [`~TFPreTrainedModel`]): + The model used for the export. + onnx_named_outputs (`List[str]`): + The names of the outputs to check. + atol (`float`): + The absolute tolerance in terms of outputs difference between the reference and the exported model. + encoder_onnx_model (`Path`): + The path to the exported encoder ONNX model. + decoder_onnx_model (`Path`): + The path to the exported decoder ONNX model. + decoder_with_past_onnx_model (`Optional[Path]`, defaults to `None`): + The path to the exported decoder with past ONNX model. Required when `past_key_values` are exported. + Raises: + ValueError: If the outputs shapes or values do not match between the reference and the exported model. + """ + models_for_validation = get_encoder_decoder_models_for_export(reference_model, config) + + if len(onnx_named_outputs) != len(models_for_validation.keys()): + raise ValueError( + f"Invalid number of ONNX named outputs. 
Required {len(models_for_validation.keys())}, Provided {len(onnx_named_outputs)}" + ) + + # Validate encoder + model, onnx_config = models_for_validation["encoder"] + validate_model_outputs(onnx_config, model, encoder_onnx_model, onnx_named_outputs[0], atol) + + # Validate decoder + model, onnx_config = models_for_validation["decoder"] + validate_model_outputs(onnx_config, model, decoder_onnx_model, onnx_named_outputs[1], atol) + + if config.use_past: + # Validate decoder with past + model, onnx_config = models_for_validation["decoder_with_past"] + validate_model_outputs(onnx_config, model, decoder_with_past_onnx_model, onnx_named_outputs[2], atol) + + def validate_model_outputs( config: OnnxConfig, reference_model: Union["PreTrainedModel", "TFPreTrainedModel"], @@ -115,9 +167,12 @@ def validate_model_outputs( else: ref_outputs_dict[name] = value + # Create ONNX Runtime inputs from the reference model inputs + reference_model_inputs_for_validation = config.generate_dummy_inputs_for_validation(reference_model_inputs) + # We flatten potential collection of inputs (i.e. past_keys) onnx_inputs = {} - for name, value in reference_model_inputs.items(): + for name, value in reference_model_inputs_for_validation.items(): if isinstance(value, (list, tuple)): value = config.flatten_output_collection_property(name, value) onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()}) @@ -223,7 +278,9 @@ def export_pytorch( device = torch.device(device) if device.type == "cuda" and torch.cuda.is_available(): model.to(device) - dummy_inputs = tree_map(lambda value: value.to(device), dummy_inputs) + dummy_inputs = tree_map( + lambda value: value.to(device) if isinstance(value, torch.Tensor) else value, dummy_inputs + ) check_dummy_inputs_are_allowed(model, dummy_inputs) inputs = config.ordered_inputs(model) input_names = list(inputs.keys()) @@ -321,6 +378,60 @@ def export_tensorflow( return input_names, output_names + +def export_encoder_decoder_model( + model: Union["PreTrainedModel", "TFPreTrainedModel"], + config: OnnxConfig, + opset: int, + encoder_output: Path, + decoder_output: Path, + decoder_with_past_output: Optional[Path] = None, + device: str = "cpu", +) -> Tuple[List[List[str]], List[List[str]]]: + """ + Exports a PyTorch or TensorFlow encoder-decoder model to an ONNX Intermediate Representation. + The following method exports the encoder and decoder components of the model as separate + ONNX files. + + Args: + model ([`PreTrainedModel`] or [`TFPreTrainedModel`]): + The model to export. + config ([`~exporters.onnx.config.OnnxConfig`]): + The ONNX configuration associated with the exported model. + opset (`int`): + The version of the ONNX operator set to use. + encoder_output (`Path`): + Path to store the exported encoder ONNX model. + decoder_output (`Path`): + Path to store the exported decoder ONNX model. + decoder_with_past_output (`Optional[Path]`, defaults to `None`): + Path to store the exported decoder with past ONNX model. Required when `past_key_values` are exported. + device (`str`, *optional*, defaults to `cpu`): + The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for + export on CUDA devices. + Returns: + `Tuple[List[List[str]], List[List[str]]]`: A tuple with an ordered list of the model's inputs, and the named + outputs from the ONNX configuration.
+ """ + models_for_export = get_encoder_decoder_models_for_export(model, config) + outputs = [] + + # export encoder + model, onnx_config = models_for_export["encoder"] + outputs.append(export(model, onnx_config, opset, encoder_output, device=device)) + + # export decoder + model, onnx_config = models_for_export["decoder"] + outputs.append(export(model, onnx_config, opset, decoder_output, device=device)) + + if config.use_past: + # export decoder with past + model, onnx_config = models_for_export["decoder_with_past"] + outputs.append(export(model, onnx_config, opset, decoder_with_past_output, device=device)) + + outputs = list(map(list, zip(*outputs))) + return outputs + + def export( model: Union["PreTrainedModel", "TFPreTrainedModel"], config: OnnxConfig, diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 46926118a3..af0c6c88ef 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -49,6 +49,70 @@ from .base import PatchingSpec +class Seq2SeqEncoderOnnxConfig(TextEncoderOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + } + + +class Seq2SeqDecoderOnnxConfig(TextSeq2SeqOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig + + DUMMY_INPUT_GENERATOR_CLASSES = ( + DummyTextInputGenerator, + DummySeq2SeqDecoderTextInputGenerator, + DummySeq2SeqPastKeyValuesGenerator, + ) + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = { + "decoder_input_ids": {0: "batch_size", 1: "past_decoder_sequence_length + sequence_length"}, + "encoder_outputs": {0: "batch_size", 1: "encoder_sequence_length"}, + "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"}, + } + + if self.use_past: + self.add_past_key_values(common_inputs, direction="inputs") + + return common_inputs + + @property + def torch_to_onnx_input_map(self) -> Mapping[str, str]: + return { + "decoder_input_ids": "input_ids", + "encoder_outputs": "encoder_hidden_states", + "attention_mask": "encoder_attention_mask", + } + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + common_outputs = super().outputs + self.add_past_key_values(common_outputs, direction="outputs") + return common_outputs + + @property + def values_override(self) -> Optional[Mapping[str, Any]]: + # Needed here because the configuration will actually be used with both use_past = True and use_past = False, + # but the cache must always be used regardless. 
+ if hasattr(self._config, "use_cache"): + return {"use_cache": True} + + return None + + def generate_dummy_inputs_for_validation(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]: + reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids") + reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0] + reference_model_inputs["encoder_attention_mask"] = reference_model_inputs.pop("attention_mask") + + return reference_model_inputs + + class BertOnnxConfig(TextEncoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig ATOL_FOR_VALIDATION = 1e-4 @@ -224,6 +288,23 @@ def generate(self, input_name: str, framework: str = "pt"): ] +class T5DecoderOnnxConfig(Seq2SeqDecoderOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args( + hidden_size="d_model", + num_attention_heads="num_heads", + encoder_num_layers="num_layers", + decoder_num_layers="num_decoder_layers", + key_value_dim="d_kv", + allow_new=True, + ) + + DUMMY_INPUT_GENERATOR_CLASSES = ( + DummyTextInputGenerator, + DummySeq2SeqDecoderTextInputGenerator, + T5DummySeq2SeqPastKeyValuesGenerator, + ) + + class T5OnnxConfig(TextSeq2SeqOnnxConfig): DEFAULT_ONNX_OPSET = 13 DUMMY_INPUT_GENERATOR_CLASSES = TextSeq2SeqOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[:-1] + ( @@ -238,6 +319,14 @@ class T5OnnxConfig(TextSeq2SeqOnnxConfig): allow_new=True, ) + def get_encoder_onnx_config(self, config: "PretrainedConfig") -> Seq2SeqEncoderOnnxConfig: + return Seq2SeqEncoderOnnxConfig(config, task="default") + + def get_decoder_onnx_config( + self, config: "PretrainedConfig", task: str = "default", use_past: bool = False + ) -> T5DecoderOnnxConfig: + return T5DecoderOnnxConfig(config, task, use_past=use_past) + class MT5OnnxConfig(T5OnnxConfig): ATOL_FOR_VALIDATION = 1e-4 @@ -286,6 +375,17 @@ def generate(self, input_name: str, framework: str = "pt"): return int_tensor +class BartDecoderOnnxConfig(Seq2SeqDecoderOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args( + encoder_num_layers="encoder_layers", + decoder_num_layers="decoder_layers", + num_layers="decoder_layers", # Used for the causal-lm task past key values input generation. 
+ encoder_num_attention_heads="encoder_attention_heads", + decoder_num_attention_heads="decoder_attention_heads", + eos_token_id="eos_token_id", + ) + + class BartOnnxConfig(TextSeq2SeqOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args( encoder_num_layers="encoder_layers", @@ -425,6 +525,14 @@ def flatten_past_key_values(self, flattened_output, name, idx, t): flattened_output, name, idx, t ) + def get_encoder_onnx_config(self, config: "PretrainedConfig") -> Seq2SeqEncoderOnnxConfig: + return Seq2SeqEncoderOnnxConfig(config, task="default") + + def get_decoder_onnx_config( + self, config: "PretrainedConfig", task: str = "default", use_past: bool = False + ) -> BartDecoderOnnxConfig: + return BartDecoderOnnxConfig(config, task, use_past=use_past) + class MBartOnnxConfig(BartOnnxConfig): pass @@ -442,8 +550,16 @@ class BlenderbotSmallOnnxConfig(BartOnnxConfig): pass +class BigBirdPegasusEncoderOnnxConfig(Seq2SeqEncoderOnnxConfig): + def generate_dummy_inputs_for_validation(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]: + # TODO: check why the attention mask is not present in the exported model + reference_model_inputs.pop("attention_mask") + return reference_model_inputs + + class BigBirdPegasusOnnxConfig(BartOnnxConfig): - pass + def get_encoder_onnx_config(self, config: "PretrainedConfig") -> BigBirdPegasusEncoderOnnxConfig: + return BigBirdPegasusEncoderOnnxConfig(config, task="default") class MarianOnnxConfig(BartOnnxConfig): @@ -639,26 +755,6 @@ def generate_dummy_inputs(self, framework: str = "pt"): return dummy_inputs -class WhisperOnnxConfig(TextAndAudioOnnxConfig): - NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig - ATOL_FOR_VALIDATION = 1e-3 - - @property - def inputs(self) -> Mapping[str, Mapping[int, str]]: - common_inputs = { - "input_features": {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"}, - } - if self.use_past: - common_inputs["decoder_input_ids"] = {0: "batch_size"} - else: - common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} - - if self.use_past: - self.add_past_key_values(common_inputs, direction="inputs") - - return common_inputs - - class SpeechSeq2SeqEncoderOnnxConfig(AudioOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedConfig @@ -669,12 +765,12 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: } -class SpeechSeq2SeqDecoderOnnxConfig(OnnxSeq2SeqConfigWithPast): +class SpeechSeq2SeqDecoderOnnxConfig(Seq2SeqDecoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig DUMMY_INPUT_GENERATOR_CLASSES = ( + DummyTextInputGenerator, DummySeq2SeqDecoderTextInputGenerator, - DummyDecoderTextInputGenerator, DummySeq2SeqPastKeyValuesGenerator, ) @@ -690,21 +786,36 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: return common_inputs - @property - def torch_to_onnx_input_map(self) -> Mapping[str, str]: - return {"decoder_input_ids": "input_ids", "encoder_outputs": "encoder_hidden_states"} + def generate_dummy_inputs_for_validation(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]: + reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids") + reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0] - @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - common_outputs = super().outputs - self.add_past_key_values(common_outputs, direction="outputs") - return common_outputs + return reference_model_inputs + + +class WhisperOnnxConfig(TextAndAudioOnnxConfig): + 
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig + ATOL_FOR_VALIDATION = 1e-3 @property - def values_override(self) -> Optional[Mapping[str, Any]]: - # Needed here because the configuration will actually be used with both use_past = True and use_past = False, - # but the cache must always be used regardless. - if hasattr(self._config, "use_cache"): - return {"use_cache": True} + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = { + "input_features": {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"}, + } + if self.use_past: + common_inputs["decoder_input_ids"] = {0: "batch_size"} + else: + common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} - return None + if self.use_past: + self.add_past_key_values(common_inputs, direction="inputs") + + return common_inputs + + def get_encoder_onnx_config(self, config: "PretrainedConfig") -> SpeechSeq2SeqEncoderOnnxConfig: + return SpeechSeq2SeqEncoderOnnxConfig(config, task="default") + + def get_decoder_onnx_config( + self, config: "PretrainedConfig", task: str = "default", use_past: bool = False + ) -> SpeechSeq2SeqDecoderOnnxConfig: + return SpeechSeq2SeqDecoderOnnxConfig(config, task, use_past=use_past) diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 87c2acc4ef..247f1e6729 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -16,11 +16,21 @@ from ctypes import c_float, sizeof from enum import Enum +from typing import TYPE_CHECKING, Dict, Tuple, Union import packaging -from transformers.utils import is_torch_available +from transformers.utils import is_tf_available, is_torch_available +if TYPE_CHECKING: + from .base import OnnxConfig + + if is_torch_available(): + from transformers.modeling_utils import PreTrainedModel + + if is_tf_available(): + from transformers.modeling_tf_utils import TFPreTrainedModel + MIN_TORCH_VERSION = packaging.version.parse("1.11.0") TORCH_VERSION = None if is_torch_available(): @@ -69,3 +79,38 @@ def check_onnxruntime_requirements(minimum_version: packaging.version.Version): f"but we require the version to be >= {minimum_version} to enable all the conversions options.\n" "Please update ONNX Runtime by running `pip install --upgrade onnxruntime`" ) + + +def get_encoder_decoder_models_for_export( + model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "OnnxConfig" +) -> Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel"], "OnnxConfig"]]: + """ + Returns the encoder and decoder parts of the model and their corresponding ONNX configs. + + Args: + model ([`PreTrainedModel`] or [`TFPreTrainedModel`]): + The model to export. + config ([`~exporters.onnx.config.OnnxConfig`]): + The ONNX configuration associated with the exported model. + + Returns: + `Dict[str, Tuple[Union[`PreTrainedModel`, `TFPreTrainedModel`], `OnnxConfig`]]`: A dict containing the model and + ONNX configs for the encoder and decoder parts of the model.
+ """ + models_for_export = dict() + + encoder_model = model.get_encoder() + encoder_onnx_config = config.get_encoder_onnx_config(encoder_model.config) + models_for_export["encoder"] = (encoder_model, encoder_onnx_config) + + decoder_model = model.get_decoder() + decoder_onnx_config = config.get_decoder_onnx_config(decoder_model.config, config.task, use_past=False) + models_for_export["decoder"] = (model, decoder_onnx_config) + + if config.use_past: + decoder_onnx_config_with_past = config.get_decoder_onnx_config( + decoder_model.config, config.task, use_past=True + ) + models_for_export["decoder_with_past"] = (model, decoder_onnx_config_with_past) + + return models_for_export diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 56e7ba8109..f1592144d0 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -325,7 +325,7 @@ def __init__( def generate(self, input_name: str, framework: str = "pt"): if input_name == "encoder_outputs": shape = (self.batch_size, self.sequence_length, self.hidden_size) - return (self.random_float_tensor(shape, min_value=0, max_value=1, framework=framework), None, None) + return (self.random_float_tensor(shape, framework=framework), None, None) return super().generate(input_name, framework=framework) diff --git a/tests/exporters/test_onnx_export.py b/tests/exporters/test_onnx_export.py index 81f5063c1b..ee3a8b6307 100644 --- a/tests/exporters/test_onnx_export.py +++ b/tests/exporters/test_onnx_export.py @@ -21,7 +21,14 @@ from transformers import AutoConfig, is_tf_available, is_torch_available from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow -from optimum.exporters.onnx import OnnxConfig, OnnxConfigWithPast, export, validate_model_outputs +from optimum.exporters.onnx import ( + OnnxConfig, + OnnxConfigWithPast, + export, + export_encoder_decoder_model, + validate_encoder_decoder_model_outputs, + validate_model_outputs, +) from parameterized import parameterized @@ -108,6 +115,27 @@ ("roberta", "roberta-base"), } +PYTORCH_ENCODER_DECODER_MODELS_FOR_CONDITIONAL_GENERATION = { + ("bart", "facebook/bart-base", ("seq2seq-lm", "seq2seq-lm-with-past")), + ("mbart", "sshleifer/tiny-mbart", ("seq2seq-lm", "seq2seq-lm-with-past")), + ("t5", "t5-small"), + ("marian", "Helsinki-NLP/opus-mt-en-de", ("seq2seq-lm", "seq2seq-lm-with-past")), + # Not using google/mt5-small because it takes too much time for testing. + ("mt5", "lewtun/tiny-random-mt5"), + # Not using facebook/m2m100_418M because it takes too much time for testing. + ( + "m2m-100", + "hf-internal-testing/tiny-random-m2m_100", + ), + # Not using google/bigbird-pegasus-large-arxiv because it takes too much time for testing. + ( + "bigbird-pegasus", + "hf-internal-testing/tiny-random-bigbird_pegasus", + ("seq2seq-lm", "seq2seq-lm-with-past"), + ), + ("whisper", "openai/whisper-tiny.en"), +} + @require_onnx class OnnxUtilsTestCase(TestCase): @@ -219,7 +247,9 @@ class OnnxExportTestCase(TestCase): Integration tests ensuring supported models are correctly exported. 
""" - def _onnx_export(self, test_name, name, model_name, task, onnx_config_class_constructor, device="cpu"): + def _onnx_export( + self, test_name, name, model_name, task, onnx_config_class_constructor, device="cpu", for_ort=False + ): model_class = TasksManager.get_model_class_for_task(task) config = AutoConfig.from_pretrained(model_name) model = model_class.from_config(config) @@ -248,23 +278,52 @@ def _onnx_export(self, test_name, name, model_name, task, onnx_config_class_cons f" {onnx_config.MIN_TORCH_VERSION}, got: {TORCH_VERSION}" ) - with NamedTemporaryFile("w") as output: - try: - onnx_inputs, onnx_outputs = export( - model, onnx_config, onnx_config.DEFAULT_ONNX_OPSET, Path(output.name), device=device - ) - atol = onnx_config.ATOL_FOR_VALIDATION - if isinstance(atol, dict): - atol = atol[task.replace("-with-past", "")] - validate_model_outputs( - onnx_config, - model, - Path(output.name), - onnx_outputs, - atol, - ) - except (RuntimeError, ValueError) as e: - self.fail(f"{name}, {task} -> {e}") + atol = onnx_config.ATOL_FOR_VALIDATION + if isinstance(atol, dict): + atol = atol[task.replace("-with-past", "")] + + if for_ort: + with NamedTemporaryFile("w") as encoder_output, NamedTemporaryFile( + "w" + ) as decoder_output, NamedTemporaryFile("w") as decoder_with_past_output: + try: + onnx_inputs, onnx_outputs = export_encoder_decoder_model( + model, + onnx_config, + onnx_config.DEFAULT_ONNX_OPSET, + Path(encoder_output.name), + Path(decoder_output.name), + Path(decoder_with_past_output.name), + device=device, + ) + + validate_encoder_decoder_model_outputs( + onnx_config, + model, + onnx_outputs, + atol, + Path(encoder_output.name), + Path(decoder_output.name), + Path(decoder_with_past_output.name), + ) + except (RuntimeError, ValueError) as e: + self.fail(f"{name}, {task} -> {e}") + + else: + with NamedTemporaryFile("w") as output: + try: + onnx_inputs, onnx_outputs = export( + model, onnx_config, onnx_config.DEFAULT_ONNX_OPSET, Path(output.name), device=device + ) + validate_model_outputs( + onnx_config, + model, + Path(output.name), + onnx_outputs, + atol, + ) + except (RuntimeError, ValueError) as e: + self.fail(f"{name}, {task} -> {e}") @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS)) @slow @@ -280,6 +339,26 @@ def test_pytorch_export(self, test_name, name, model_name, task, onnx_config_cla def test_pytorch_export_on_cuda(self, test_name, name, model_name, task, onnx_config_class_constructor): self._onnx_export(test_name, name, model_name, task, onnx_config_class_constructor, device="cuda") + @parameterized.expand(_get_models_to_test(PYTORCH_ENCODER_DECODER_MODELS_FOR_CONDITIONAL_GENERATION)) + @slow + @require_torch + @require_vision + def test_pytorch_export_for_encoder_decoder_models_for_conditional_generation( + self, test_name, name, model_name, task, onnx_config_class_constructor + ): + self._onnx_export(test_name, name, model_name, task, onnx_config_class_constructor, for_ort=True) + + @parameterized.expand(_get_models_to_test(PYTORCH_ENCODER_DECODER_MODELS_FOR_CONDITIONAL_GENERATION)) + @slow + @require_torch + @require_vision + def test_pytorch_export_for_encoder_decoder_models_for_conditional_generation_on_cuda( + self, test_name, name, model_name, task, onnx_config_class_constructor + ): + self._onnx_export( + test_name, name, model_name, task, onnx_config_class_constructor, device="cuda", for_ort=True + ) + @parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_MODELS)) @slow @require_tf