diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py
index e05f68b2c9bcd..b46045444ee12 100755
--- a/onnxruntime/python/tools/symbolic_shape_infer.py
+++ b/onnxruntime/python/tools/symbolic_shape_infer.py
@@ -2090,8 +2090,8 @@ def _infer_Attention(self, node):  # noqa: N802
         # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len)
         # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length
         input_shape = self._get_shape(node, 0)
-        past_shape = self._get_shape(node, 4) if node.input[4] else []
-        mask_shape = self._get_shape(node, 3) if node.input[3] else []
+        past_shape = self._get_shape(node, 4) if len(node.input) > 4 and node.input[4] else []
+        mask_shape = self._get_shape(node, 3) if len(node.input) > 3 and node.input[3] else []
 
         if past_shape and len(past_shape) == 5:
             if mask_shape and len(mask_shape) in [2, 3]:
diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
index 436317599a75b..e8df6bc78533b 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
@@ -205,10 +205,11 @@ def export_onnx_models(
         if overwrite or not os.path.exists(onnx_path):
             logger.info(f"Exporting ONNX model to {onnx_path}")
             # We have to clone model before exporting onnx, otherwise verify_onnx will report large difference.
-            cloned_model = copy.deepcopy(model).to(device)
+            device_to_export = torch.device("cpu")
+            cloned_model = copy.deepcopy(model).to(device_to_export)
             WhisperHelper.export_onnx(
                 cloned_model,
-                device,
+                device_to_export,
                 onnx_path,
                 verbose,
                 use_external_data_format,
@@ -292,7 +293,7 @@ def main():
         assert args.use_gpu, "fp16 requires --use_gpu"
 
     if args.optimize_onnx:
-        logger.warning("Graph optimization for Whisper is not implemented yet.")
+        logger.warning("Applying graph optimization for Whisper...")
 
     output_paths = export_onnx_models(
         args.model_name_or_path,
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
index bbdafdcb3e9c2..f756adf5c9141 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
@@ -107,5 +107,6 @@ def chain_model(args):
         save_as_external_data=True,
         all_tensors_to_one_file=True,
         convert_attribute=True,
+        location=f"{os.path.basename(args.beam_model_output_dir)}.data",
     )
     onnx.checker.check_model(args.beam_model_output_dir, full_check=True)
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
index 6d444bdd87a74..c795e36498efe 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
@@ -229,10 +229,8 @@ def optimize_onnx(
 
         from fusion_options import FusionOptions
 
-        optimization_options = None
-        if is_float16:
-            optimization_options = FusionOptions("bart")
-            optimization_options.enable_skip_layer_norm = False
+        optimization_options = FusionOptions("bart")
+        optimization_options.use_multi_head_attention = True
 
         m = optimize_model(
             onnx_model_path,
@@ -241,7 +239,7 @@
             hidden_size=hidden_size,
             opt_level=2 if not use_external_data_format else None,
             optimization_options=optimization_options,
-            use_gpu=False,
+            use_gpu=use_gpu,
             only_onnxruntime=False,
         )
 
@@ -262,4 +260,4 @@ def verify_onnx(
     ):
         """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
         # Not implemented for Whisper currently
-        return True
+        return 0