18 changes: 10 additions & 8 deletions optimum/exporters/onnx/base.py
@@ -728,9 +728,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
                 new_axes_names[axis_idx] = axis_name
             common_outputs[name] = new_axes_names
 
-        if self._behavior is not ConfigBehavior.ENCODER:
-            common_outputs["encoder_last_hidden_state"] = {0: "batch_size", 1: "encoder_sequence_length"}
-
         if self.use_present_in_outputs:
             self.add_past_key_values(common_outputs, direction="outputs")
 
@@ -759,9 +756,6 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
             inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch_size", 2: "encoder_sequence_length"}
             inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch_size", 2: "encoder_sequence_length"}
 
-        if direction == "outputs" and "encoder_last_hidden_state" in inputs_or_outputs:
-            inputs_or_outputs.move_to_end("encoder_last_hidden_state")
-
     def flatten_past_key_values(self, flattened_output, name, idx, t):
         flattened_output[f"{name}.{idx}.decoder.key"] = t[0]
         flattened_output[f"{name}.{idx}.decoder.value"] = t[1]
@@ -823,8 +817,16 @@ def post_process_exported_models(
 
     def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str, Any]) -> Dict[str, Any]:
         if self._behavior is ConfigBehavior.DECODER:
-            reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
-            reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
+            if "decoder_input_ids" in reference_model_inputs:
+                reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
+
+            if "encoder_outputs" in reference_model_inputs:
+                if self.use_past_in_inputs is False or self.is_merged:
+                    # The without-past ONNX model uses encoder_hidden_states as an input even when we do not output it.
+                    reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
+                else:
+                    # The with-past ONNX model does not use encoder_hidden_states when we do not output it.
+                    reference_model_inputs.pop("encoder_outputs")
 
         return super().generate_dummy_inputs_for_validation(reference_model_inputs)
 
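To make the new branching easier to follow, here is a self-contained sketch of the remapping above — a minimal sketch, assuming the config flags are passed in explicitly; the helper name is ours, not part of the PR:

```python
from typing import Any, Dict


def remap_decoder_validation_inputs(
    reference_model_inputs: Dict[str, Any],
    use_past_in_inputs: bool,
    is_merged: bool,
) -> Dict[str, Any]:
    # The decoder ONNX model expects "input_ids" rather than "decoder_input_ids".
    if "decoder_input_ids" in reference_model_inputs:
        reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")

    if "encoder_outputs" in reference_model_inputs:
        if use_past_in_inputs is False or is_merged:
            # Without-past (or merged) decoder: the encoder's last hidden state
            # is an explicit input of the exported graph.
            reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
        else:
            # With-past decoder: encoder_hidden_states is not an input at all.
            reference_model_inputs.pop("encoder_outputs")

    return reference_model_inputs


# Toy values, not tensors:
print(remap_decoder_validation_inputs(
    {"decoder_input_ids": [[0]], "encoder_outputs": ([[0.0]],)},
    use_past_in_inputs=True,
    is_merged=False,
))  # {'input_ids': [[0]]}
```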
11 changes: 2 additions & 9 deletions optimum/exporters/onnx/config.py
@@ -365,6 +365,8 @@ def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str,
         else:
             if self._behavior is ConfigBehavior.DECODER:
                 reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
+
+            # for encoder-decoder custom models, always pass encoder_hidden_states as input
             reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
 
         return self._decoder_onnx_config.generate_dummy_inputs_for_validation(reference_model_inputs)
@@ -392,12 +394,3 @@ def post_process_exported_models(
             models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1]._decoder_onnx_config.is_merged = True
 
         return models_and_onnx_configs, onnx_files_subpaths
-
-    @property
-    def outputs(self) -> Dict[str, Dict[int, str]]:
-        common_outputs = super().outputs
-        # This is handled by OnnxSeq2SeqConfigWithPast, but not by OnnxConfigWithPast, so we take care of this here to
-        # make sure this output is moved at the end.
-        if "encoder_last_hidden_state" in common_outputs:
-            common_outputs.move_to_end("encoder_last_hidden_state")
-        return common_outputs
14 changes: 11 additions & 3 deletions optimum/exporters/onnx/model_configs.py
@@ -300,6 +300,16 @@ class T5OnnxConfig(TextSeq2SeqOnnxConfig):
         allow_new=True,
     )
 
+    def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str, Any]) -> Dict[str, Any]:
+        if self._behavior is ConfigBehavior.DECODER:
+            reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
+
+            # T5 requires encoder_hidden_states as an input for both the without-past and with-past models,
+            # unlike other architectures, which require it only in the without-past case.
+            reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
+
+        return super().generate_dummy_inputs_for_validation(reference_model_inputs)
+
 
 class MT5OnnxConfig(T5OnnxConfig):
     ATOL_FOR_VALIDATION = 1e-4
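A toy check (dummy values, not real tensors) of the T5 override above — for T5, the decoder validation inputs always end up with encoder_hidden_states, with or without past:

```python
reference_model_inputs = {"decoder_input_ids": [[0]], "encoder_outputs": ([[0.0]],)}
reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
assert sorted(reference_model_inputs) == ["encoder_hidden_states", "input_ids"]
```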
@@ -443,8 +453,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             common_outputs = super().outputs
         else:
             common_outputs = super(OnnxConfigWithPast, self).outputs
-            if self.task != "causal-lm":
-                common_outputs["encoder_last_hidden_state"] = {0: "batch_size", 1: "sequence_length"}
             if self.use_present_in_outputs:
                 for i in range(self._normalized_config.encoder_num_layers):
                     common_outputs[f"present.{i}.key"] = {0: "batch_size", 2: "past_sequence_length + sequence_length"}
@@ -944,7 +952,7 @@ class WhisperOnnxConfig(AudioToTextOnnxConfig):
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
         common_inputs = super().inputs
-        if self._behavior is ConfigBehavior.DECODER:
+        if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False:
             common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2"
         return common_inputs
 
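Whisper's encoder halves the time dimension (its second convolution has stride 2), hence the "/ 2" in the dynamic axis name; after this change, only the without-past decoder gets the halved label, since the with-past decoder no longer takes encoder_outputs at all. A toy illustration of the tweak, with plain booleans standing in for the config attributes:

```python
common_inputs = {"encoder_outputs": {0: "batch_size", 1: "encoder_sequence_length"}}

behavior_is_decoder = True   # stands in for self._behavior is ConfigBehavior.DECODER
use_past_in_inputs = False   # stands in for self.use_past_in_inputs

if behavior_is_decoder and use_past_in_inputs is False:
    common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2"

print(common_inputs)  # {'encoder_outputs': {0: 'batch_size', 1: 'encoder_sequence_length / 2'}}
```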
6 changes: 5 additions & 1 deletion optimum/exporters/onnx/model_patcher.py
@@ -130,7 +130,11 @@ def patched_forward(*args, **kwargs):
                     allow_past_in_outputs and name.startswith("past_key_values")
                 ):
                     if name != "past_key_values":
-                        filterd_outputs[name] = value
+                        if self.real_config._behavior == "decoder" and name == "encoder_last_hidden_state":
+                            # The encoder outputs are not needed in the decoder export.
+                            continue
+                        else:
+                            filterd_outputs[name] = value
                 else:
                     if self.real_config._behavior == "monolith" or (
                         self.real_config._behavior == "decoder" and self.real_config.use_past is False
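A simplified sketch of the filtering idea above; the `allow_past_in_outputs` plumbing and the torch-to-ONNX name remapping are omitted, and the helper name is ours:

```python
def filter_model_outputs(outputs: dict, onnx_output_names: set, behavior: str) -> dict:
    """Keep only the outputs that the exported ONNX graph declares."""
    filtered = {}
    for name, value in outputs.items():
        if name not in onnx_output_names and not name.startswith("past_key_values"):
            continue  # not declared by the ONNX config
        if behavior == "decoder" and name == "encoder_last_hidden_state":
            continue  # already produced by the exported encoder; redundant here
        filtered[name] = value
    return filtered
```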
21 changes: 14 additions & 7 deletions optimum/onnx/transformations_utils.py
@@ -160,9 +160,10 @@ def _unify_onnx_outputs(model1: ModelProto, model2: ModelProto, strict: bool):
 
     for idx in range(len(model1.graph.output)):
         model_output_1 = model1.graph.output[idx]
-        model_output_2 = model2.graph.output[idx]
-        if model_output_1 != model_output_2:
-            if not (
+        model_output_2 = model2.graph.output[idx] if idx < len(model2.graph.output) else None
+
+        if model_output_2 is None or model_output_1 != model_output_2:
+            if model_output_2 is None or not (
                 model_output_1.name == model_output_2.name
                 and model_output_1.type.tensor_type.elem_type == model_output_2.type.tensor_type.elem_type
             ):
@@ -205,10 +206,16 @@ def _unify_onnx_outputs(model1: ModelProto, model2: ModelProto, strict: bool):
                 )
                 model2.graph.output.insert(idx, constant_empty_output)
             else:
-                raise ValueError(
-                    f"Cannot match {model_output_1.name} with {model_output_2.name}. Make sure your"
-                    f" model protos have same outputs, have same data types and are in the same order."
-                )
+                if model_output_2 is not None:
+                    raise ValueError(
+                        f"Cannot match {model_output_1.name} with {model_output_2.name}. Make sure your"
+                        f" model protos have the same outputs, with the same data types, in the same order."
+                    )
+                else:
+                    raise ValueError(
+                        f"Too few outputs of model2 were found to match with {model_output_1.name}."
+                        f" Please try to pass strict=False, or file a bug report at https://github.com/huggingface/optimum."
+                    )
         else:
             model2.graph.output.remove(model_output_2)
 
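The crux of the fix is indexing model2's outputs defensively instead of letting an IndexError escape when model2 declares fewer outputs than model1. A minimal sketch of that guard (onnx is already a dependency of this module; the generator name is ours):

```python
from typing import Iterator, Optional, Tuple

from onnx import ModelProto, ValueInfoProto


def paired_outputs(model1: ModelProto, model2: ModelProto) -> Iterator[Tuple[ValueInfoProto, Optional[ValueInfoProto]]]:
    # Yield (output of model1, matching output of model2 or None). The None
    # case replaces the IndexError the previous code could raise when model2
    # has fewer outputs than model1; callers can then report a clear error
    # or, in non-strict mode, insert a constant empty output.
    for idx in range(len(model1.graph.output)):
        out2 = model2.graph.output[idx] if idx < len(model2.graph.output) else None
        yield model1.graph.output[idx], out2
```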
4 changes: 2 additions & 2 deletions tests/exporters/exporters_utils.py
@@ -63,7 +63,7 @@
     "levit": "hf-internal-testing/tiny-random-LevitModel",
     "layoutlm": "hf-internal-testing/tiny-random-LayoutLMModel",
     "layoutlmv3": "hf-internal-testing/tiny-random-LayoutLMv3Model",
-    "longt5": "hf-internal-testing/tiny-random-LongT5Model",
+    "longt5": "fxmarty/tiny-random-working-LongT5Model",
     # "longformer": "allenai/longformer-base-4096",
     "m2m-100": "hf-internal-testing/tiny-random-m2m_100",
     "marian": "sshleifer/tiny-marian-en-de",  # hf-internal-testing ones are broken
@@ -163,7 +163,7 @@
     "levit": "facebook/levit-128S",
     "layoutlm": "microsoft/layoutlm-base-uncased",
     "layoutlmv3": "microsoft/layoutlmv3-base",
-    "longt5": "hf-internal-testing/tiny-random-longt5",  # Not using google/long-t5-local-base because it takes too much time for testing.
+    "longt5": "fxmarty/tiny-random-working-LongT5Model",  # Not using google/long-t5-local-base because it takes too much time for testing.
     # "longformer": "allenai/longformer-base-4096",
     "m2m-100": "hf-internal-testing/tiny-random-m2m_100",  # Not using facebook/m2m100_418M because it takes too much time for testing.
     "marian": "Helsinki-NLP/opus-mt-en-de",