onnxruntime/python/tools/symbolic_shape_infer.py (2 additions & 2 deletions)

@@ -2090,8 +2090,8 @@ def _infer_Attention(self, node):  # noqa: N802
         # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len)
         # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length
         input_shape = self._get_shape(node, 0)
-        past_shape = self._get_shape(node, 4) if node.input[4] else []
-        mask_shape = self._get_shape(node, 3) if node.input[3] else []
+        past_shape = self._get_shape(node, 4) if len(node.input) > 4 and node.input[4] else []
+        mask_shape = self._get_shape(node, 3) if len(node.input) > 3 and node.input[3] else []

         if past_shape and len(past_shape) == 5:
             if mask_shape and len(mask_shape) in [2, 3]:
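
Note on the shape-inference guard above: ONNX allows optional trailing inputs to be omitted entirely, in which case node.input is simply shorter, so indexing node.input[4] on an Attention node without past state raises IndexError instead of returning an empty name. A minimal sketch of the failure mode (toy node built with onnx.helper, not code from this PR):

from onnx import helper

# An Attention node carrying only its three required inputs; the optional
# mask_index (input 3) and past (input 4) are omitted entirely.
node = helper.make_node(
    "Attention", ["input", "weights", "bias"], ["output"], domain="com.microsoft"
)
print(len(node.input))  # 3 -- so node.input[4] would raise IndexError
# The guarded form degrades gracefully to "no such input":
past = node.input[4] if len(node.input) > 4 and node.input[4] else ""
print(repr(past))  # ''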

onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py (4 additions & 3 deletions)

@@ -205,10 +205,11 @@ def export_onnx_models(
         if overwrite or not os.path.exists(onnx_path):
             logger.info(f"Exporting ONNX model to {onnx_path}")
             # We have to clone model before exporting onnx, otherwise verify_onnx will report large difference.
-            cloned_model = copy.deepcopy(model).to(device)
+            device_to_export = torch.device("cpu")
+            cloned_model = copy.deepcopy(model).to(device_to_export)
             WhisperHelper.export_onnx(
                 cloned_model,
-                device,
+                device_to_export,
                 onnx_path,
                 verbose,
                 use_external_data_format,
@@ -292,7 +293,7 @@ def main():
         assert args.use_gpu, "fp16 requires --use_gpu"

     if args.optimize_onnx:
-        logger.warning("Graph optimization for Whisper is not implemented yet.")
+        logger.warning("Applying graph optimization for Whisper...")

     output_paths = export_onnx_models(
         args.model_name_or_path,
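
Note on the export device change: exporting a deep copy pinned to CPU leaves the original model (possibly on GPU) untouched, so the later parity check compares against unmodified weights; that is the same rationale as the inline comment about cloning, with the export device now fixed to CPU. A minimal sketch of the pattern with a toy module (stand-in names, not the Whisper exporter):

import copy
import torch

model = torch.nn.Linear(4, 2)  # stand-in for the PyTorch model being exported
device_to_export = torch.device("cpu")
cloned_model = copy.deepcopy(model).to(device_to_export)

# Export the CPU clone; the original `model` keeps its device and state.
dummy_input = torch.randn(1, 4, device=device_to_export)
torch.onnx.export(cloned_model, dummy_input, "model.onnx", opset_version=17)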

onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py (1 addition)

@@ -107,5 +107,6 @@ def chain_model(args):
         save_as_external_data=True,
         all_tensors_to_one_file=True,
         convert_attribute=True,
+        location=f"{os.path.basename(args.beam_model_output_dir)}.data",
     )
     onnx.checker.check_model(args.beam_model_output_dir, full_check=True)
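
Note on the added location argument: when saving with external data and no explicit location, onnx picks the external tensor file name on its own; passing location pins the payload to a predictable "<model>.data" file next to the model. A runnable sketch with a tiny stand-in graph (not the Whisper beam-search model):

import os
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

# Tiny stand-in model; the weight is large enough to exceed the default
# external-data size threshold so it actually lands in the .data file.
w = numpy_helper.from_array(np.zeros((64, 64), dtype=np.float32), name="W")
graph = helper.make_graph(
    [helper.make_node("MatMul", ["X", "W"], ["Y"])],
    "g",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 64])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 64])],
    initializer=[w],
)
model = helper.make_model(graph)

model_path = "tiny_model.onnx"
onnx.save(
    model,
    model_path,
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    convert_attribute=True,
    location=f"{os.path.basename(model_path)}.data",  # weights go to tiny_model.onnx.data
)
onnx.checker.check_model(model_path, full_check=True)  # full_check also runs shape inference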

onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py (4 additions & 6 deletions)

@@ -229,10 +229,8 @@ def optimize_onnx(

         from fusion_options import FusionOptions

-        optimization_options = None
-        if is_float16:
-            optimization_options = FusionOptions("bart")
-            optimization_options.enable_skip_layer_norm = False
+        optimization_options = FusionOptions("bart")
+        optimization_options.use_multi_head_attention = True

         m = optimize_model(
             onnx_model_path,
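
Note on the fusion change: the options are now built unconditionally (previously only when is_float16 was set), SkipLayerNormalization fusion is no longer disabled, and MultiHeadAttention fusion is turned on. Roughly the resulting call pattern, using the packaged import path and a hypothetical model path, head count, and hidden size:

from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.optimizer import optimize_model

optimization_options = FusionOptions("bart")           # Whisper reuses the BART fusion rules
optimization_options.use_multi_head_attention = True   # fuse attention into MultiHeadAttention ops

m = optimize_model(
    "whisper_decoder.onnx",  # hypothetical path
    model_type="bart",
    num_heads=8,             # hypothetical head count
    hidden_size=512,         # hypothetical hidden size
    optimization_options=optimization_options,
    use_gpu=True,            # now honored (see the next hunk) instead of hard-coded False
)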

@@ -241,7 +239,7 @@
             hidden_size=hidden_size,
             opt_level=2 if not use_external_data_format else None,
             optimization_options=optimization_options,
-            use_gpu=False,
+            use_gpu=use_gpu,
             only_onnxruntime=False,
         )

@@ -262,4 +260,4 @@ def verify_onnx(
     ):
         """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
         # Not implemented for Whisper currently
-        return True
+        return 0
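
Note on the verify_onnx return value: since bool is an int subclass in Python, True compares as 1, so a caller that treats the result numerically (a max-difference threshold, or a nonzero-means-failure convention) would read the old stub as a failure; returning 0 reads as "no difference". This reading is an inference from the diff, not stated in the PR. A toy illustration:

def verify_onnx_stub_old():
    return True  # True == 1 when used numerically

def verify_onnx_stub_new():
    return 0     # reads as "no difference / no failures"

max_diff = verify_onnx_stub_old()
print(max_diff > 1e-4)  # True -- looks like a parity failure to a numeric check
max_diff = verify_onnx_stub_new()
print(max_diff > 1e-4)  # False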