diff --git a/optimum/commands/export/onnx.py b/optimum/commands/export/onnx.py index 25922fa6d4..4b952bae77 100644 --- a/optimum/commands/export/onnx.py +++ b/optimum/commands/export/onnx.py @@ -91,6 +91,12 @@ def parse_args_onnx(parser): " and decoder-with-past models into a single ONNX model file to reduce memory usage." ), ) + optional_group.add_argument( + "--variant", + type=str, + default="default", + help=("Select a variant of the model to export."), + ) optional_group.add_argument( "--framework", type=str, @@ -233,5 +239,6 @@ def run(self): pad_token_id=self.args.pad_token_id, for_ort=self.args.for_ort, use_subprocess=True, + _variant=self.args.variant, **input_shapes, ) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 9b58a8836e..6aaa1c1a1a 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -36,6 +36,7 @@ _get_submodels_for_export_stable_diffusion, get_decoder_models_for_export, get_encoder_decoder_models_for_export, + get_sam_models_for_export, get_stable_diffusion_models_for_export, ) @@ -61,6 +62,7 @@ def _get_submodels_and_onnx_configs( monolith: bool, custom_onnx_configs: Dict, custom_architecture: bool, + _variant: str, fn_get_submodels: Optional[Callable] = None, preprocessors: Optional[List[Any]] = None, ): @@ -75,6 +77,12 @@ def _get_submodels_and_onnx_configs( ) onnx_config = onnx_config_constructor(model.config, preprocessors=preprocessors) + onnx_config.variant = _variant + all_variants = "\n".join( + [f"\t- {name}: {description}" for name, description in onnx_config.VARIANTS.items()] + ) + logger.info(f"Using the export variant {onnx_config.variant}. Available variants are:\n{all_variants}") + if ( model.config.is_encoder_decoder and task.startswith(TasksManager._ENCODER_DECODER_TASKS) @@ -83,6 +91,8 @@ def _get_submodels_and_onnx_configs( models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config) elif task.startswith("text-generation") and not monolith: models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config) + elif model.config.model_type == "sam": + models_and_onnx_configs = get_sam_models_for_export(model, onnx_config) else: models_and_onnx_configs = {"model": (model, onnx_config)} @@ -156,6 +166,7 @@ def main_export( custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None, fn_get_submodels: Optional[Callable] = None, use_subprocess: bool = False, + _variant: str = "default", **kwargs_shapes, ): """ @@ -230,6 +241,8 @@ def main_export( exporting on CUDA device, where ORT does not release memory at inference session destruction. When set to `True`, the `main_export` call should be guarded in `if __name__ == "__main__":` block. + _variant (`str`, defaults to `default`): + Specify the variant of the ONNX export to use. **kwargs_shapes (`Dict`): Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export. 
@@ -373,6 +386,7 @@ def main_export(
         custom_architecture=custom_architecture,
         fn_get_submodels=fn_get_submodels,
         preprocessors=preprocessors,
+        _variant=_variant,
     )

     if not is_stable_diffusion:
diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py
index 09ea0b4466..a1cf87e5e6 100644
--- a/optimum/exporters/onnx/base.py
+++ b/optimum/exporters/onnx/base.py
@@ -124,6 +124,7 @@ class OnnxConfig(ExportConfig, ABC):
     MIN_TORCH_VERSION = GLOBAL_MIN_TORCH_VERSION
     MIN_TRANSFORMERS_VERSION = GLOBAL_MIN_TRANSFORMERS_VERSION
     PATCHING_SPECS: Optional[List["PatchingSpec"]] = None
+    VARIANTS = {"default": "The default ONNX variant."}
     _TASK_TO_COMMON_OUTPUTS = {
         "audio-classification": OrderedDict({"logits": {0: "batch_size"}}),
         "audio-frame-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
@@ -226,6 +227,22 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
         common_outputs = self._TASK_TO_COMMON_OUTPUTS[self.task]
         return copy.deepcopy(common_outputs)

+    @property
+    def variant(self) -> str:
+        """
+        For a given ONNX config, the variant of the model to export. This property allows defining variants of a given model, in case
+        different users would like to export the model differently (with different inputs/outputs, the model split into several ONNX files or not, etc.).
+        """
+        return self._variant
+
+    @variant.setter
+    def variant(self, value: str):
+        if value == "default" and hasattr(self, "DEFAULT_VARIANT"):
+            value = self.DEFAULT_VARIANT
+        if value not in self.VARIANTS:
+            raise ValueError(f"The variant {value} is not supported for the ONNX config {self.__class__.__name__}.")
+        self._variant = value
+
     def fix_dynamic_axes(
         self, model_path: "Path", device: str = "cpu", dtype: Optional[str] = None, input_shapes: Optional[Dict] = None
     ):
@@ -276,6 +293,7 @@ def fix_dynamic_axes(
             input_shapes = {}
         dummy_inputs = self.generate_dummy_inputs(framework="np", **input_shapes)
         dummy_inputs = self.generate_dummy_inputs_for_validation(dummy_inputs, onnx_input_names=onnx_input_names)
+
         onnx_inputs = {}
         for name, value in dummy_inputs.items():
             if isinstance(value, (list, tuple)):
@@ -825,10 +843,16 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
             inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch_size", 2: "encoder_sequence_length_out"}

     def flatten_past_key_values(self, flattened_output, name, idx, t):
+        if len(t) not in [2, 4]:
+            raise ValueError(
+                "past_key_values to flatten should be of length 2 (self-attention only) or 4 (self and cross attention)."
+ ) + flattened_output[f"{name}.{idx}.decoder.key"] = t[0] flattened_output[f"{name}.{idx}.decoder.value"] = t[1] - flattened_output[f"{name}.{idx}.encoder.key"] = t[2] - flattened_output[f"{name}.{idx}.encoder.value"] = t[3] + if len(t) == 4: + flattened_output[f"{name}.{idx}.encoder.key"] = t[2] + flattened_output[f"{name}.{idx}.encoder.value"] = t[3] def patch_model_for_export( self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None @@ -1017,10 +1041,16 @@ def flatten_decoder_past_key_values(self, flattened_output, name, idx, t): flattened_output[f"{name}.{idx}.value"] = t[1] def flatten_seq2seq_past_key_values(self, flattened_output, name, idx, t): - flattened_output[f"{name}.{idx}.decoder.key"] = t[0] - flattened_output[f"{name}.{idx}.decoder.value"] = t[1] - flattened_output[f"{name}.{idx}.encoder.key"] = t[2] - flattened_output[f"{name}.{idx}.encoder.value"] = t[3] + if len(t) not in [2, 4]: + raise ValueError( + "past_key_values to flatten should be of length 2 (self-attention only) or 4 (self and cross attention)." + ) + if len(t) == 2: + flattened_output[f"{name}.{idx}.decoder.key"] = t[0] + flattened_output[f"{name}.{idx}.decoder.value"] = t[1] + if len(t) == 4: + flattened_output[f"{name}.{idx}.encoder.key"] = t[2] + flattened_output[f"{name}.{idx}.encoder.value"] = t[3] def flatten_output_collection_property(self, name: str, field: Iterable[Any]) -> Dict[str, Any]: flattened_output = {} diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index cad2fdcb0f..265b81be44 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -326,11 +326,12 @@ def _run_validation( # Some models may modify in place the inputs, hence the copy. copy_reference_model_inputs = copy.deepcopy(reference_model_inputs) - if is_torch_available() and isinstance(reference_model, nn.Module): - with torch.inference_mode(): - ref_outputs = reference_model(**copy_reference_model_inputs, **model_kwargs) - else: - ref_outputs = reference_model(**copy_reference_model_inputs, **model_kwargs) + with config.patch_model_for_export(reference_model, model_kwargs=model_kwargs): + if is_torch_available() and isinstance(reference_model, nn.Module): + with torch.inference_mode(): + ref_outputs = reference_model(**copy_reference_model_inputs) + else: + ref_outputs = reference_model(**copy_reference_model_inputs) ref_outputs_dict = {} # We flatten potential collection of outputs (i.e. past_keys) to a flat structure @@ -564,10 +565,6 @@ def remap(value): if device.type == "cuda" and torch.cuda.is_available(): model.to(device) dummy_inputs = tree_map(remap, dummy_inputs) - check_dummy_inputs_are_allowed(model, dummy_inputs) - inputs = config.ordered_inputs(model) - input_names = list(inputs.keys()) - output_names = list(config.outputs.keys()) # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11, # so we check the torch version for backwards compatibility @@ -575,6 +572,12 @@ def remap(value): raise RuntimeError("The ONNX export using the PyTorch framework is only supported for v1.11+") else: with config.patch_model_for_export(model, model_kwargs=model_kwargs): + check_dummy_inputs_are_allowed(model, dummy_inputs) + + inputs = config.ordered_inputs(model) + input_names = list(inputs.keys()) + output_names = list(config.outputs.keys()) + # Export can work with named args but the dict containing named args has to be the last element of the args # tuple. 
                onnx_export(
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index eb5824316b..db0256e4d0 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -31,6 +31,7 @@
     DummySeq2SeqPastKeyValuesGenerator,
     DummyTextInputGenerator,
     DummyTimestepInputGenerator,
+    DummyVisionEmbeddingsGenerator,
     DummyVisionInputGenerator,
     GPTBigCodeDummyPastKeyValuesGenerator,
     NormalizedConfig,
@@ -53,7 +54,7 @@
     TextSeq2SeqOnnxConfig,
     VisionOnnxConfig,
 )
-from .model_patcher import WavLMModelPatcher
+from .model_patcher import SAMModelPatcher, WavLMModelPatcher


 if TYPE_CHECKING:
@@ -1217,34 +1218,62 @@ def inputs(self) -> Dict[str, Dict[int, str]]:

 class SamOnnxConfig(OnnxConfig):
     MIN_TRANSFORMERS_VERSION = version.parse("4.29.0.dev0")
+    # Since transformers 4.32.0, SAM uses a repeat_interleave op that is broken in PyTorch 2.0.1: https://github.com/pytorch/pytorch/issues/100429
+    MIN_TORCH_VERSION = version.parse("2.0.99")
     NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
-    DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyPointsGenerator)
-    DEFAULT_ONNX_OPSET = 12  # einsum op not supported with opset 11
-    MIN_TORCH_VERSION = version.parse("2.0.99")  # See: https://github.com/huggingface/optimum/pull/1301
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyPointsGenerator, DummyVisionEmbeddingsGenerator)
+    DEFAULT_ONNX_OPSET = 13  # With opset 12, repeat_interleave falls back on the opset 9 implementation, which raises: Unsupported: ONNX export of repeat_interleave in opset 9.
+    VARIANTS = {
+        "monolith": "All the SAM model components are exported as a single model.onnx.",
+        "split": "The vision encoder is exported as a separate vision_encoder.onnx, and the prompt encoder and mask decoder are exported as a prompt_encoder_mask_decoder.onnx. This allows encoding the image only once for multiple point queries.",
+    }
+    DEFAULT_VARIANT = "split"

     def __init__(
-        self, config: "PretrainedConfig", task: str = "feature-extraction", preprocessors: Optional[List[Any]] = None
+        self,
+        config: "PretrainedConfig",
+        task: str = "feature-extraction",
+        variant: str = "split",
+        vision_encoder: Optional[bool] = None,
+        preprocessors: Optional[List[Any]] = None,
     ):
         super().__init__(config, task, preprocessors=preprocessors)
+        self.variant = variant
+        self.vision_encoder = vision_encoder
         self._normalized_config.ENCODER_NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig(self._config.vision_config)

     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
-        inputs = {
-            "pixel_values": {0: "batch_size"},
-            "input_points": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
-        }
-
+        if self.variant == "monolith":
+            inputs = {
+                "pixel_values": {0: "batch_size"},
+                "input_points": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
+            }
+        else:
+            if self.vision_encoder:
+                inputs = {"pixel_values": {0: "batch_size"}}
+            else:
+                inputs = {
+                    "image_positional_embeddings": {0: "batch_size"},
+                    "image_embeddings": {0: "batch_size"},
+                    "input_points": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
+                }
         return inputs

     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
-        outputs = {
-            "iou_scores": {0: "batch_size", 1: "point_batch_size"},
-            "pred_masks": {0: "batch_size", 1: "point_batch_size"},
-        }
+        if self.variant == "split" and self.vision_encoder:
+            return {"image_embeddings": {0: "batch_size"}, "image_positional_embeddings": {0: "batch_size"}}
+        else:
+            return {
+                "iou_scores": {0: "batch_size", 1: "point_batch_size"},
+                "pred_masks": {0: "batch_size", 1: "point_batch_size"},
+            }

-        return outputs
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return SAMModelPatcher(self, model, model_kwargs=model_kwargs)


 class Pix2StructNormalizedConfig(NormalizedSeq2SeqConfig):
diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index 7b8f1a238b..f8353b5924 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -17,6 +17,12 @@
 import inspect
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

+from transformers.utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+
 from ...utils import logging


@@ -28,14 +34,19 @@
 logger = logging.get_logger(__name__)


-def overwride_arguments(args, kwargs, forward_signature, model_kwargs):
+def override_arguments(args, kwargs, forward_signature, model_kwargs: Dict[str, Any]):
+    """
+    Override the args and kwargs with the argument values from model_kwargs, following forward_signature, the signature that args and kwargs correspond to.
+ """ args = list(args) for argument in model_kwargs: if argument in forward_signature.parameters: argument_index = list(forward_signature.parameters.keys()).index(argument) - - args[argument_index] = model_kwargs[argument] + if argument in kwargs or len(args) <= argument_index: + kwargs[argument] = model_kwargs[argument] + else: + args[argument_index] = model_kwargs[argument] else: kwargs[argument] = model_kwargs[argument] @@ -97,7 +108,7 @@ def __init__( @functools.wraps(self.orig_forward) def patched_forward(*args, **kwargs): signature = inspect.signature(self.orig_forward) - args, kwargs = overwride_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs) + args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs) outputs = self.orig_forward(*args, **kwargs) @@ -159,7 +170,7 @@ def __init__( @functools.wraps(self.orig_forward) def patched_forward(*args, **kwargs): signature = inspect.signature(self.orig_forward) - args, kwargs = overwride_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs) + args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs) outputs = self.orig_forward(*args, **kwargs) @@ -212,7 +223,7 @@ def patched_forward(*args, **kwargs): # that calls https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/functional.py#L5334 model_kwargs["output_attentions"] = True signature = inspect.signature(self.orig_forward) - args, kwargs = overwride_arguments(args, kwargs, signature, model_kwargs=model_kwargs) + args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=model_kwargs) outputs = self.orig_forward(*args, **kwargs) @@ -228,3 +239,87 @@ def patched_forward(*args, **kwargs): return filterd_outputs self.patched_forward = patched_forward + + +class SAMModelPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + + def patched_forward( + pixel_values=None, + input_points=None, + image_embeddings=None, + image_positional_embeddings=None, + return_dict=True, + **kwargs, + ): + if config.variant == "monolith": + return self.orig_forward( + pixel_values=pixel_values, + input_points=input_points, + image_embeddings=image_embeddings, + return_dict=return_dict, + **kwargs, + ) + elif config.variant == "split": + # return_dict = get_argument(args, kwargs, signature, "return_dict") + if config.vision_encoder: + # pixel_values = get_argument(args, kwargs, signature, "pixel_values") + image_positional_embeddings = model.get_image_wide_positional_embeddings() + + # repeat with batch size + batch_size = pixel_values.shape[0] + image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) + + vision_outputs = model.vision_encoder( + pixel_values, + output_attentions=False, + output_hidden_states=False, + return_dict=return_dict, + ) + image_embeddings = vision_outputs[0] + + if not return_dict: + return (image_embeddings, image_positional_embeddings) + else: + return { + "image_embeddings": image_embeddings, + "image_positional_embeddings": image_positional_embeddings, + } + else: + if input_points is not None: + input_labels = torch.ones_like( + input_points[:, :, :, 0], dtype=torch.int, device=input_points.device + ) + else: + raise ValueError("input_points is required to export the prompt encoder / mask decoder.") + + sparse_embeddings, dense_embeddings = model.prompt_encoder( + 
input_points=input_points, + input_labels=input_labels, + input_boxes=None, # Not supported in the ONNX export + input_masks=None, # Not supported in the ONNX export + ) + + low_res_masks, iou_predictions, _ = model.mask_decoder( + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=True, # Not supported in the ONNX export + attention_similarity=None, # Not supported in the ONNX export + target_embedding=None, # Not supported in the ONNX export + output_attentions=False, + ) + + if not return_dict: + return (iou_predictions, low_res_masks) + else: + return {"iou_scores": iou_predictions, "pred_masks": low_res_masks} + + self.patched_forward = patched_forward diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 24a809a977..9bde57640c 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -310,6 +310,41 @@ def get_stable_diffusion_models_for_export( return models_for_export +def _get_submodels_for_export_sam(model, variant): + models_for_export = {} + + if variant == "monolith": + models_for_export["model"] = model + else: + # We use the model patcher to patch their forward method. + models_for_export["vision_encoder"] = model + models_for_export["prompt_encoder_mask_decoder"] = model + + return models_for_export + + +def get_sam_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "OnnxConfig"): + models_for_export = _get_submodels_for_export_sam(model, config.variant) + + if config.variant == "monolith": + onnx_config = config.__class__(model.config, task=config.task) + models_for_export["model"] = (models_for_export["model"], onnx_config) + else: + vision_encoder_onnx_config = config.__class__( + model.config, task=config.task, variant=config.variant, vision_encoder=True + ) + prompt_encoder_mask_decoder_onnx_config = config.__class__( + model.config, task=config.task, variant=config.variant, vision_encoder=False + ) + models_for_export["vision_encoder"] = (models_for_export["vision_encoder"], vision_encoder_onnx_config) + models_for_export["prompt_encoder_mask_decoder"] = ( + models_for_export["prompt_encoder_mask_decoder"], + prompt_encoder_mask_decoder_onnx_config, + ) + + return models_for_export + + def override_diffusers_2_0_attn_processors(model): for _, submodule in model.named_modules(): if isinstance(submodule, Attention): diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index 6117fe1895..d841ca1368 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -55,6 +55,7 @@ DummySeq2SeqPastKeyValuesGenerator, DummyTextInputGenerator, DummyTimestepInputGenerator, + DummyVisionEmbeddingsGenerator, DummyVisionInputGenerator, GPTBigCodeDummyPastKeyValuesGenerator, ) diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index f4ff908847..56cd577664 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -702,6 +702,35 @@ def generate(self, input_name: str, framework: str = "pt"): return self.random_float_tensor(shape, framework=framework) +class DummyVisionEmbeddingsGenerator(DummyInputGenerator): + SUPPORTED_INPUT_NAMES = ("image_positional_embeddings", "image_embeddings") + + def __init__( + self, + task: str, + normalized_config: NormalizedConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + image_embedding_size: Optional[int] = None, + 
output_channels: Optional[int] = None, + **kwargs, + ): + self.task = task + + self.batch_size = batch_size + self.image_embedding_size = ( + image_embedding_size + if image_embedding_size is not None + else normalized_config.prompt_encoder_config.image_embedding_size + ) + self.output_channels = ( + output_channels if output_channels is not None else normalized_config.vision_config.output_channels + ) + + def generate(self, input_name: str, framework: str = "pt"): + shape = [self.batch_size, self.output_channels, self.image_embedding_size, self.image_embedding_size] + return self.random_float_tensor(shape, framework=framework) + + class DummyPix2StructInputGenerator(DummyInputGenerator): """ Generates dummy time step inputs. diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py index 2d9ef98a26..1d25240c18 100644 --- a/tests/exporters/onnx/test_exporters_onnx_cli.py +++ b/tests/exporters/onnx/test_exporters_onnx_cli.py @@ -57,39 +57,70 @@ def _get_models_to_test(export_models_dict: Dict): for model_name, tasks in model_tasks.items(): for task in tasks: - models_to_test.append((f"{model_type}_{task}", model_type, model_name, task, False, False)) - - # -with-past and monolith cases are absurd, so we don't test them as not supported - if any( - task == ort_special_task - for ort_special_task in [ - "text-generation", - "text2text-generation", - "automatic-speech-recognition", - "image-to-text", - ] - ): - models_to_test.append( - (f"{model_type}_{task}_monolith", model_type, model_name, task, True, False) - ) + onnx_config_class = TasksManager.get_exporter_config_constructor( + "onnx", task=task, model_type=model_type + ) - # For other tasks, we don't test --no-post-process as there is none anyway - if task in [ - "feature-extraction-with-past", - "text-generation-with-past", - "automatic-speech-recognition-with-past", - "image-to-text-with-past", - "text2text-generation-with-past", - ]: + # Refer to https://github.com/huggingface/optimum/blob/0b08a1fd19005b7334aa923433b3544bd2b11ff2/optimum/exporters/tasks.py#L65 + if hasattr(onnx_config_class.func, "__self__"): + variants = onnx_config_class.func.__self__.VARIANTS + else: + variants = onnx_config_class.func.VARIANTS + + for variant in variants.keys(): models_to_test.append( - (f"{model_type}_{task}_no_postprocess", model_type, model_name, task, False, True) + (f"{model_type}_{task}_{variant}", model_type, model_name, task, variant, False, False) ) + # -with-past and monolith cases are absurd, so we don't test them as not supported + if any( + task == ort_special_task + for ort_special_task in [ + "text-generation", + "text2text-generation", + "automatic-speech-recognition", + "image-to-text", + ] + ): + models_to_test.append( + ( + f"{model_type}_{task}_monolith_{variant}", + model_type, + model_name, + task, + variant, + True, + False, + ) + ) + + # For other tasks, we don't test --no-post-process as there is none anyway + if task in [ + "feature-extraction-with-past", + "text-generation-with-past", + "automatic-speech-recognition-with-past", + "image-to-text-with-past", + "text2text-generation-with-past", + ]: + models_to_test.append( + ( + f"{model_type}_{task}_no_postprocess_{variant}", + model_type, + model_name, + task, + variant, + False, + True, + ) + ) + # TODO: segformer task can not be automatically inferred # TODO: xlm-roberta model auto-infers text-generation, but we don't support it # TODO: perceiver auto-infers default, but we don't support it (why?) 
if model_type not in ["segformer", "xlm-roberta", "perceiver", "vision-encoder-decoder"]: - models_to_test.append((f"{model_type}_no_task", model_type, model_name, "auto", False, False)) + models_to_test.append( + (f"{model_type}_no_task", model_type, model_name, "auto", "default", False, False) + ) return sorted(models_to_test) else: @@ -113,6 +144,7 @@ def _onnx_export( optimization_level: Optional[str] = None, device: str = "cpu", fp16: bool = False, + variant: str = "default", ): with TemporaryDirectory() as tmpdir: try: @@ -125,6 +157,7 @@ def _onnx_export( optimize=optimization_level, monolith=monolith, no_post_process=no_post_process, + _variant=variant, ) except MinimumVersionError as e: pytest.skip(f"Skipping due to minimum version requirements not met. Full error: {e}") @@ -158,7 +191,14 @@ def test_exporters_cli_fp16_stable_diffusion(self, model_type: str, model_name: @require_torch @require_vision def test_exporters_cli_pytorch_cpu( - self, test_name: str, model_type: str, model_name: str, task: str, monolith: bool, no_post_process: bool + self, + test_name: str, + model_type: str, + model_name: str, + task: str, + variant: str, + monolith: bool, + no_post_process: bool, ): # TODO: re-enable those tests # Failing due to https://github.com/huggingface/transformers/pull/22212 @@ -166,20 +206,31 @@ def test_exporters_cli_pytorch_cpu( # masked-im models use MaskedImageModelingOutput if model_type in ["vit", "deit"] and task == "masked-im": self.skipTest("Temporarily disabled upon transformers 4.28 release") - self._onnx_export(model_name, task, monolith, no_post_process) + self._onnx_export(model_name, task, monolith, no_post_process, variant=variant) @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY)) @require_vision @require_torch_gpu @pytest.mark.gpu_test def test_exporters_cli_pytorch_gpu( - self, test_name: str, model_type: str, model_name: str, task: str, monolith: bool, no_post_process: bool + self, + test_name: str, + model_type: str, + model_name: str, + task: str, + variant: str, + monolith: bool, + no_post_process: bool, ): - # TODO: disable due to a bug in PyTorch: https://github.com/pytorch/pytorch/issues/95377 + # TODO: refer to https://github.com/pytorch/pytorch/issues/95377 if model_type == "yolos": self.skipTest("Export on cuda device fails for yolos due to a bug in PyTorch") - self._onnx_export(model_name, task, monolith, no_post_process, device="cuda") + # TODO: refer to https://github.com/pytorch/pytorch/issues/107591 + if model_type == "sam": + self.skipTest("sam export on cuda is not supported due to a bug in PyTorch") + + self._onnx_export(model_name, task, monolith, no_post_process, device="cuda", variant=variant) @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY)) @require_torch @@ -187,11 +238,20 @@ def test_exporters_cli_pytorch_gpu( @slow @pytest.mark.run_slow def test_exporters_cli_pytorch_with_optimization( - self, test_name: str, model_type: str, model_name: str, task: str, monolith: bool, no_post_process: bool + self, + test_name: str, + model_type: str, + model_name: str, + task: str, + variant: str, + monolith: bool, + no_post_process: bool, ): for optimization_level in ["O1", "O2", "O3"]: try: - self._onnx_export(model_name, task, monolith, no_post_process, optimization_level=optimization_level) + self._onnx_export( + model_name, task, monolith, no_post_process, optimization_level=optimization_level, variant=variant + ) except NotImplementedError as e: if "Tried to use ORTOptimizer for the model type" in str( 
e @@ -207,14 +267,27 @@ def test_exporters_cli_pytorch_with_optimization( @pytest.mark.gpu_test @pytest.mark.run_slow def test_exporters_cli_pytorch_with_O4_optimization( - self, test_name: str, model_type: str, model_name: str, task: str, monolith: bool, no_post_process: bool + self, + test_name: str, + model_type: str, + model_name: str, + task: str, + variant: str, + monolith: bool, + no_post_process: bool, ): - # TODO: disable due to a bug in PyTorch: https://github.com/pytorch/pytorch/issues/95377 + # TODO: refer to https://github.com/pytorch/pytorch/issues/95377 if model_type == "yolos": self.skipTest("Export on cuda device fails for yolos due to a bug in PyTorch") + # TODO: refer to https://github.com/pytorch/pytorch/issues/107591 + if model_type == "sam": + self.skipTest("sam export on cuda is not supported due to a bug in PyTorch") + try: - self._onnx_export(model_name, task, monolith, no_post_process, optimization_level="O4", device="cuda") + self._onnx_export( + model_name, task, monolith, no_post_process, optimization_level="O4", device="cuda", variant=variant + ) except NotImplementedError as e: if "Tried to use ORTOptimizer for the model type" in str( e @@ -288,6 +361,10 @@ def test_export_on_fp16( if model_type == "yolos": self.skipTest("yolos export on fp16 not supported due to a pytorch bug") + # TODO: refer to https://github.com/pytorch/pytorch/issues/107591 + if model_type == "sam": + self.skipTest("sam export on cuda is not supported due to a pytorch bug") + # TODO: refer to https://huggingface.slack.com/archives/C014N4749J9/p1677245766278129 if model_type == "deberta": self.skipTest("deberta export on fp16 not supported due to a transformers bug")
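For context, a usage sketch of the export variant option introduced by this diff. The `--variant` flag, the `_variant` argument and the SAM variant names come from the diff itself; the checkpoint name and output directories are illustrative assumptions.

# CLI: export SAM with the default "split" variant (vision_encoder.onnx +
# prompt_encoder_mask_decoder.onnx), or force the single-file export:
#   optimum-cli export onnx --model facebook/sam-vit-base sam_onnx/
#   optimum-cli export onnx --model facebook/sam-vit-base --variant monolith sam_onnx_monolith/

# Programmatic equivalent through main_export:
from optimum.exporters.onnx import main_export

main_export(
    "facebook/sam-vit-base",  # illustrative checkpoint
    output="sam_onnx",
    task="feature-extraction",
    _variant="split",
)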
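The variant resolution added to `OnnxConfig` can be exercised directly on the new `SamOnnxConfig`; a minimal sketch, assuming `facebook/sam-vit-base` as the checkpoint:

from transformers import AutoConfig
from optimum.exporters.onnx.model_configs import SamOnnxConfig

config = AutoConfig.from_pretrained("facebook/sam-vit-base")
onnx_config = SamOnnxConfig(config, task="feature-extraction")

# "default" is remapped to DEFAULT_VARIANT ("split") by the variant setter.
onnx_config.variant = "default"
assert onnx_config.variant == "split"

onnx_config.variant = "monolith"  # accepted, listed in VARIANTS
# onnx_config.variant = "something-else"  # would raise ValueError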
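Finally, a rough sketch of how the two files produced by the `split` variant could be chained with onnxruntime. The file names and input/output names come from the ONNX configs in this diff; the preprocessing, tensor shapes, dtypes and output ordering are assumptions (in practice `SamProcessor` would produce `pixel_values` and the point prompts).

import numpy as np
import onnxruntime as ort

vision_encoder = ort.InferenceSession("sam_onnx/vision_encoder.onnx")
decoder = ort.InferenceSession("sam_onnx/prompt_encoder_mask_decoder.onnx")

# Encode the image once; pixel_values assumed to be 1x3x1024x1024 float32.
pixel_values = np.random.rand(1, 3, 1024, 1024).astype(np.float32)
image_embeddings, image_positional_embeddings = vision_encoder.run(None, {"pixel_values": pixel_values})

# Reuse the cached image embeddings for several point prompts.
for point in ([[450, 600]], [[200, 300]]):
    input_points = np.array([[point]], dtype=np.float32)  # (batch, point_batch, nb_points, 2)
    iou_scores, pred_masks = decoder.run(
        None,
        {
            "image_embeddings": image_embeddings,
            "image_positional_embeddings": image_positional_embeddings,
            "input_points": input_points,
        },
    )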