7 changes: 7 additions & 0 deletions optimum/commands/export/onnx.py
@@ -91,6 +91,12 @@ def parse_args_onnx(parser):
" and decoder-with-past models into a single ONNX model file to reduce memory usage."
),
)
optional_group.add_argument(
"--variant",
type=str,
default="default",
help=("Select a variant of the model to export."),
)
Comment on lines +94 to +99

Member

Maybe add a set of possible choices here.

Contributor Author

@fxmarty fxmarty Aug 22, 2023

It would be a bit tricky given that the choices are dynamic (dependent on the onnx config).

optional_group.add_argument(
"--framework",
type=str,
@@ -233,5 +239,6 @@ def run(self):
pad_token_id=self.args.pad_token_id,
for_ort=self.args.for_ort,
use_subprocess=True,
_variant=self.args.variant,
**input_shapes,
)
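
A minimal usage sketch of the new option (the checkpoint name and output directory are illustrative; --variant is forwarded to main_export as _variant, as wired in run() above):

# Programmatic equivalent of, e.g.:
#   optimum-cli export onnx --model facebook/sam-vit-base --variant split sam_onnx/
# (model name and output directory are only examples, not part of this PR)
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="facebook/sam-vit-base",
    output="sam_onnx",
    _variant="split",  # new in this PR; "default" resolves to the config's DEFAULT_VARIANT
)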
14 changes: 14 additions & 0 deletions optimum/exporters/onnx/__main__.py
@@ -36,6 +36,7 @@
_get_submodels_for_export_stable_diffusion,
get_decoder_models_for_export,
get_encoder_decoder_models_for_export,
get_sam_models_for_export,
get_stable_diffusion_models_for_export,
)

@@ -61,6 +62,7 @@ def _get_submodels_and_onnx_configs(
monolith: bool,
custom_onnx_configs: Dict,
custom_architecture: bool,
_variant: str,
Member

Why make it a protected parameter?

Contributor Author

I'm thinking to keep it "private" for now, and support it correctly once we move to this API fully instead of -with-past, monolith, etc.

fn_get_submodels: Optional[Callable] = None,
preprocessors: Optional[List[Any]] = None,
):
@@ -75,6 +77,12 @@
)
onnx_config = onnx_config_constructor(model.config, preprocessors=preprocessors)

onnx_config.variant = _variant
all_variants = "\n".join(
[f"\t- {name}: {description}" for name, description in onnx_config.VARIANTS.items()]
)
logger.info(f"Using the export variant {onnx_config.variant}. Available variants are:\n{all_variants}")

if (
model.config.is_encoder_decoder
and task.startswith(TasksManager._ENCODER_DECODER_TASKS)
@@ -83,6 +91,8 @@
models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config)
elif task.startswith("text-generation") and not monolith:
models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config)
elif model.config.model_type == "sam":
models_and_onnx_configs = get_sam_models_for_export(model, onnx_config)
else:
models_and_onnx_configs = {"model": (model, onnx_config)}

@@ -156,6 +166,7 @@ def main_export(
custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None,
fn_get_submodels: Optional[Callable] = None,
use_subprocess: bool = False,
_variant: str = "default",
**kwargs_shapes,
):
"""
Expand Down Expand Up @@ -230,6 +241,8 @@ def main_export(
exporting on CUDA device, where ORT does not release memory at inference session
destruction. When set to `True`, the `main_export` call should be guarded in
`if __name__ == "__main__":` block.
_variant (`str`, defaults to `default`):
Specify the variant of the ONNX export to use.
**kwargs_shapes (`Dict`):
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.

@@ -373,6 +386,7 @@
custom_architecture=custom_architecture,
fn_get_submodels=fn_get_submodels,
preprocessors=preprocessors,
_variant=_variant,
)

if not is_stable_diffusion:
42 changes: 36 additions & 6 deletions optimum/exporters/onnx/base.py
@@ -124,6 +124,7 @@ class OnnxConfig(ExportConfig, ABC):
MIN_TORCH_VERSION = GLOBAL_MIN_TORCH_VERSION
MIN_TRANSFORMERS_VERSION = GLOBAL_MIN_TRANSFORMERS_VERSION
PATCHING_SPECS: Optional[List["PatchingSpec"]] = None
VARIANTS = {"default": "The default ONNX variant."}
_TASK_TO_COMMON_OUTPUTS = {
"audio-classification": OrderedDict({"logits": {0: "batch_size"}}),
"audio-frame-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
@@ -226,6 +227,22 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs = self._TASK_TO_COMMON_OUTPUTS[self.task]
return copy.deepcopy(common_outputs)

@property
def variant(self) -> str:
"""
For a given ONNX config, the variant of the model to export. This property allows defining variants of a given model, in case
different users would like to export the model differently (with different inputs/outputs, the model split into several ONNX files or not, etc.).
"""
return self._variant

@variant.setter
def variant(self, value: str):
if value == "default" and hasattr(self, "DEFAULT_VARIANT"):
value = self.DEFAULT_VARIANT
if value not in self.VARIANTS:
raise ValueError(f"The variant {value} is not supported for the ONNX config {self.__class__.__name__}.")
self._variant = value

def fix_dynamic_axes(
self, model_path: "Path", device: str = "cpu", dtype: Optional[str] = None, input_shapes: Optional[Dict] = None
):
@@ -276,6 +293,7 @@ def fix_dynamic_axes(
input_shapes = {}
dummy_inputs = self.generate_dummy_inputs(framework="np", **input_shapes)
dummy_inputs = self.generate_dummy_inputs_for_validation(dummy_inputs, onnx_input_names=onnx_input_names)

onnx_inputs = {}
for name, value in dummy_inputs.items():
if isinstance(value, (list, tuple)):
@@ -825,10 +843,16 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch_size", 2: "encoder_sequence_length_out"}

def flatten_past_key_values(self, flattened_output, name, idx, t):
if len(t) not in [2, 4]:
raise ValueError(
"past_key_values to flatten should be of length 2 (self-attention only) or 4 (self and cross attention)."
)

flattened_output[f"{name}.{idx}.decoder.key"] = t[0]
flattened_output[f"{name}.{idx}.decoder.value"] = t[1]
flattened_output[f"{name}.{idx}.encoder.key"] = t[2]
flattened_output[f"{name}.{idx}.encoder.value"] = t[3]
if len(t) == 4:
flattened_output[f"{name}.{idx}.encoder.key"] = t[2]
flattened_output[f"{name}.{idx}.encoder.value"] = t[3]

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
@@ -1017,10 +1041,16 @@ def flatten_decoder_past_key_values(self, flattened_output, name, idx, t):
flattened_output[f"{name}.{idx}.value"] = t[1]

def flatten_seq2seq_past_key_values(self, flattened_output, name, idx, t):
flattened_output[f"{name}.{idx}.decoder.key"] = t[0]
flattened_output[f"{name}.{idx}.decoder.value"] = t[1]
flattened_output[f"{name}.{idx}.encoder.key"] = t[2]
flattened_output[f"{name}.{idx}.encoder.value"] = t[3]
if len(t) not in [2, 4]:
raise ValueError(
"past_key_values to flatten should be of length 2 (self-attention only) or 4 (self and cross attention)."
)
if len(t) == 2:
flattened_output[f"{name}.{idx}.decoder.key"] = t[0]
flattened_output[f"{name}.{idx}.decoder.value"] = t[1]
if len(t) == 4:
flattened_output[f"{name}.{idx}.encoder.key"] = t[2]
flattened_output[f"{name}.{idx}.encoder.value"] = t[3]

def flatten_output_collection_property(self, name: str, field: Iterable[Any]) -> Dict[str, Any]:
flattened_output = {}
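
To illustrate the variant mechanism added to OnnxConfig above, a small hypothetical subclass (MyOnnxConfig, its VARIANTS entries and the NormalizedTextConfig/gpt2 choices are made up for the example; only the VARIANTS / DEFAULT_VARIANT attributes and the variant setter behaviour come from this PR):

from transformers import AutoConfig
from optimum.exporters.onnx.base import OnnxConfig
from optimum.utils import NormalizedTextConfig

class MyOnnxConfig(OnnxConfig):
    # Hypothetical config: declares two export variants and picks a default one.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    VARIANTS = {
        "monolith": "Everything exported into a single ONNX file.",
        "split": "Submodels exported as separate ONNX files.",
    }
    DEFAULT_VARIANT = "split"

    @property
    def inputs(self):
        return {"input_ids": {0: "batch_size", 1: "sequence_length"}}

pretrained_config = AutoConfig.from_pretrained("gpt2")  # any text model config, for illustration
cfg = MyOnnxConfig(pretrained_config, task="feature-extraction")
cfg.variant = "default"    # resolved by the setter to DEFAULT_VARIANT, i.e. "split"
cfg.variant = "monolith"   # accepted, declared in VARIANTS
# cfg.variant = "foo"      # would raise ValueError (not in VARIANTS)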
21 changes: 12 additions & 9 deletions optimum/exporters/onnx/convert.py
@@ -326,11 +326,12 @@ def _run_validation(
# Some models may modify in place the inputs, hence the copy.
copy_reference_model_inputs = copy.deepcopy(reference_model_inputs)

if is_torch_available() and isinstance(reference_model, nn.Module):
with torch.inference_mode():
ref_outputs = reference_model(**copy_reference_model_inputs, **model_kwargs)
else:
ref_outputs = reference_model(**copy_reference_model_inputs, **model_kwargs)
with config.patch_model_for_export(reference_model, model_kwargs=model_kwargs):
if is_torch_available() and isinstance(reference_model, nn.Module):
with torch.inference_mode():
ref_outputs = reference_model(**copy_reference_model_inputs)
else:
ref_outputs = reference_model(**copy_reference_model_inputs)
ref_outputs_dict = {}

# We flatten potential collection of outputs (i.e. past_keys) to a flat structure
@@ -564,17 +565,19 @@ def remap(value):
if device.type == "cuda" and torch.cuda.is_available():
model.to(device)
dummy_inputs = tree_map(remap, dummy_inputs)
check_dummy_inputs_are_allowed(model, dummy_inputs)
inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())

# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
# so we check the torch version for backwards compatibility
if is_torch_less_than_1_11:
raise RuntimeError("The ONNX export using the PyTorch framework is only supported for v1.11+")
else:
with config.patch_model_for_export(model, model_kwargs=model_kwargs):
check_dummy_inputs_are_allowed(model, dummy_inputs)

inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())

# Export can work with named args but the dict containing named args has to be the last element of the args
# tuple.
onnx_export(
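
For context, the patch_model_for_export context manager used above follows the existing ModelPatcher pattern: the model's forward is temporarily replaced for the duration of the with block, so the export trace and the reference run during validation both go through the same patched forward. A rough sketch of that pattern (an illustration only, not the actual ModelPatcher or SAMModelPatcher implementation):

class ForwardPatcher:
    """Illustrative stand-in for the ModelPatcher pattern: swap forward inside a with block."""

    def __init__(self, model, patched_forward):
        self._model = model
        self._orig_forward = model.forward
        self._patched_forward = patched_forward

    def __enter__(self):
        self._model.forward = self._patched_forward
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always restore the original forward, even if export or validation failed.
        self._model.forward = self._orig_forward

Because _run_validation now enters this context before calling the reference model, models whose export-time forward differs from their eager forward (such as SAM after this PR) are validated against the patched behaviour.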
59 changes: 44 additions & 15 deletions optimum/exporters/onnx/model_configs.py
@@ -31,6 +31,7 @@
DummySeq2SeqPastKeyValuesGenerator,
DummyTextInputGenerator,
DummyTimestepInputGenerator,
DummyVisionEmbeddingsGenerator,
DummyVisionInputGenerator,
GPTBigCodeDummyPastKeyValuesGenerator,
NormalizedConfig,
@@ -53,7 +54,7 @@
TextSeq2SeqOnnxConfig,
VisionOnnxConfig,
)
from .model_patcher import WavLMModelPatcher
from .model_patcher import SAMModelPatcher, WavLMModelPatcher


if TYPE_CHECKING:
@@ -1217,34 +1218,62 @@ def inputs(self) -> Dict[str, Dict[int, str]]:

class SamOnnxConfig(OnnxConfig):
MIN_TRANSFORMERS_VERSION = version.parse("4.29.0.dev0")
# Since transformers 4.32.0, SAM uses a repeat_interleave op that is broken in PyTorch 2.0.1: https://github.com/pytorch/pytorch/issues/100429
MIN_TORCH_VERSION = version.parse("2.0.99")
NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyPointsGenerator)
DEFAULT_ONNX_OPSET = 12 # einsum op not supported with opset 11
MIN_TORCH_VERSION = version.parse("2.0.99") # See: https://github.com/huggingface/optimum/pull/1301
DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyPointsGenerator, DummyVisionEmbeddingsGenerator)
DEFAULT_ONNX_OPSET = 13  # Opset 12 for repeat_interleave falls back on the opset 9 implementation, which raises: Unsupported: ONNX export of repeat_interleave in opset 9.
VARIANTS = {
"monolith": "All the SAM model components are exported as a single model.onnx.",
"split": "The vision encoder is exported as a separate vision_encoder.onnx, and the prompt encoder and mask decoder are exported as a prompt_encoder_mask_decoder.onnx. This allows to encoder the image only once for multiple point queries.",
}
DEFAULT_VARIANT = "split"

def __init__(
self, config: "PretrainedConfig", task: str = "feature-extraction", preprocessors: Optional[List[Any]] = None
self,
config: "PretrainedConfig",
task: str = "feature-extraction",
variant: str = "split",
vision_encoder: Optional[bool] = None,
preprocessors: Optional[List[Any]] = None,
):
super().__init__(config, task, preprocessors=preprocessors)
self.variant = variant
self.vision_encoder = vision_encoder
self._normalized_config.ENCODER_NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig(self._config.vision_config)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
inputs = {
"pixel_values": {0: "batch_size"},
"input_points": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
}

if self.variant == "monolith":
inputs = {
"pixel_values": {0: "batch_size"},
"input_points": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
}
else:
if self.vision_encoder:
inputs = {"pixel_values": {0: "batch_size"}}
else:
inputs = {
"image_positional_embeddings": {0: "batch_size"},
"image_embeddings": {0: "batch_size"},
"input_points": {0: "batch_size", 1: "point_batch_size", 2: "nb_points_per_image"},
}
return inputs

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
outputs = {
"iou_scores": {0: "batch_size", 1: "point_batch_size"},
"pred_masks": {0: "batch_size", 1: "point_batch_size"},
}
if self.variant == "split" and self.vision_encoder:
return {"image_embeddings": {0: "batch_size"}, "image_positional_embeddings": {0: "batch_size"}}
else:
return {
"iou_scores": {0: "batch_size", 1: "point_batch_size"},
"pred_masks": {0: "batch_size", 1: "point_batch_size"},
}

return outputs
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return SAMModelPatcher(self, model, model_kwargs=model_kwargs)


class Pix2StructNormalizedConfig(NormalizedSeq2SeqConfig):
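
A quick sketch of what the two SAM variants expose, based on the inputs/outputs properties above (the checkpoint name is illustrative; the printed values are paraphrased in the comments):

from transformers import AutoConfig
from optimum.exporters.onnx.model_configs import SamOnnxConfig

sam_config = AutoConfig.from_pretrained("facebook/sam-vit-base")

# "split" variant: vision encoder and prompt-encoder/mask-decoder are separate ONNX models.
vision = SamOnnxConfig(sam_config, variant="split", vision_encoder=True)
print(vision.inputs)    # {'pixel_values': {0: 'batch_size'}}
print(vision.outputs)   # image_embeddings, image_positional_embeddings

decoder = SamOnnxConfig(sam_config, variant="split", vision_encoder=False)
print(decoder.inputs)   # image_embeddings, image_positional_embeddings, input_points
print(decoder.outputs)  # iou_scores, pred_masks

# "monolith" variant: a single model taking the image and the point prompts directly.
mono = SamOnnxConfig(sam_config, variant="monolith")
print(mono.inputs)      # pixel_values, input_points

Since DEFAULT_VARIANT is "split", exporting a SAM checkpoint without --variant would be expected to produce vision_encoder.onnx and prompt_encoder_mask_decoder.onnx, while --variant monolith keeps everything in a single model.onnx.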