From cb7ca378c5c84c585fcb3d621eb34892e3af31de Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Sat, 12 Jul 2025 09:01:16 +0000 Subject: [PATCH 1/5] Use num decoder layers instead of num hidden layers --- .../tools/transformers/models/whisper/convert_to_onnx.py | 2 +- .../tools/transformers/models/whisper/requirements.txt | 2 +- .../tools/transformers/models/whisper/whisper_decoder.py | 6 +++--- .../models/whisper/whisper_encoder_decoder_init.py | 4 ++-- .../tools/transformers/models/whisper/whisper_helper.py | 4 ++-- .../tools/transformers/models/whisper/whisper_inputs.py | 6 +++--- .../tools/transformers/models/whisper/whisper_jump_times.py | 2 +- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index ac696ff3788aa..e092285d57358 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -410,7 +410,7 @@ def export_onnx_models( precision == Precision.FLOAT16, model.config.encoder_attention_heads, model.config.d_model, - model.config.num_hidden_layers, + model.config.decoder_layers, use_external_data_format, use_gpu=use_gpu, provider=provider, diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index f1758cc52280f..37fc72cd26e07 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -1,5 +1,5 @@ torch>=2.7.0 -transformers>=4.52.3 +transformers==4.52.3 openai-whisper==20240927 ffmpeg-python datasets diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index fadf271ae913b..a1feb4b89cfee 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -187,7 +187,7 @@ def input_names(self): *list( chain.from_iterable( (f"past_key_self_{i}", f"past_value_self_{i}", f"past_key_cross_{i}", f"past_value_cross_{i}") - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] @@ -205,7 +205,7 @@ def output_names(self): f"present_key_cross_{i}", f"present_value_cross_{i}", ) - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] @@ -215,7 +215,7 @@ def output_names(self): *list( chain.from_iterable( (f"present_key_self_{i}", f"present_value_self_{i}") - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index 26dc3aee7018b..cd81edc1001be 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -127,7 +127,7 @@ def output_names(self): *list( chain.from_iterable( (f"present_key_cross_{i}", f"present_value_cross_{i}") - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] @@ -143,7 +143,7 @@ def output_names(self): f"present_key_cross_{i}", f"present_value_cross_{i}", ) - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index f66aa22eb0972..a236c4da1738e 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -763,7 +763,7 @@ def optimize_onnx( is_float16: bool, num_attention_heads: int, hidden_size: int, - num_layers: int, + num_decoder_layers: int, use_external_data_format: bool = False, use_gpu: bool = False, provider: str = "cpu", @@ -801,7 +801,7 @@ def optimize_onnx( m = add_cache_indirection_to_mha(m, past_seq_len_name) if output_qk: - m = add_output_qk_to_mha(m, skip_node_idxs=list(range(0, 2 * num_layers, 2))) + m = add_output_qk_to_mha(m, skip_node_idxs=list(range(0, 2 * num_decoder_layers, 2))) m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py index 0b0882eface72..8937fea900d14 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py @@ -94,14 +94,14 @@ def get_sample_past_key_values( torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), ) - for _ in range(config.num_hidden_layers) + for _ in range(config.decoder_layers) ] cross_attention_kv_caches = [ ( torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype), torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype), ) - for _ in range(config.num_hidden_layers) + for _ in range(config.decoder_layers) ] return flatten_past_key_values(self_attention_kv_caches, cross_attention_kv_caches) @@ -187,7 +187,7 @@ def get_sample_QKs( # noqa: N802 torch.rand( batch_size, num_heads, sequence_length, config.max_source_positions, device=device, dtype=torch_dtype ) - for _ in range(config.num_hidden_layers) + for _ in range(config.decoder_layers) ] return QKs diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py index a7c0d3538b8da..4dd5d7de1752b 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py @@ -156,7 +156,7 @@ def input_names(self): "alignment_heads", "sot_sequence_length", "segment_length", - *[f"cross_qk_{i}" for i in range(self.config.num_hidden_layers)], + *[f"cross_qk_{i}" for i in range(self.config.decoder_layers)], ] return input_names From be2b7073f259e444720497a776ac2c340c0fe93e Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Sat, 12 Jul 2025 17:48:17 +0000 Subject: [PATCH 2/5] Add changes suggested by linter --- .../python/tools/transformers/models/whisper/benchmark.py | 2 +- .../tools/transformers/models/whisper/whisper_decoder.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index a111db1edc257..625cfd863d3cd 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -16,7 +16,6 @@ import numpy as np import psutil import torch -import whisper from benchmark_helper import measure_memory, setup_logger from onnxruntime_extensions import get_library_path from optimum.onnxruntime import ORTModelForSpeechSeq2Seq @@ -25,6 +24,7 @@ from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor import onnxruntime as ort +import whisper logger = logging.getLogger(__name__) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index a1feb4b89cfee..e10e616d35d38 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -214,8 +214,7 @@ def output_names(self): "logits", *list( chain.from_iterable( - (f"present_key_self_{i}", f"present_value_self_{i}") - for i in range(self.config.decoder_layers) + (f"present_key_self_{i}", f"present_value_self_{i}") for i in range(self.config.decoder_layers) ) ), ] From c45d9b5219301571bacb6c735a6ce3e20057c4e7 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Sat, 12 Jul 2025 18:04:41 +0000 Subject: [PATCH 3/5] Move imports up --- .../python/tools/transformers/models/whisper/benchmark.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index 625cfd863d3cd..ebac2ff9f82dc 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -13,9 +13,11 @@ import sys import time +import onnxruntime as ort import numpy as np import psutil import torch +import whisper from benchmark_helper import measure_memory, setup_logger from onnxruntime_extensions import get_library_path from optimum.onnxruntime import ORTModelForSpeechSeq2Seq @@ -23,9 +25,6 @@ from tqdm import trange from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor -import onnxruntime as ort -import whisper - logger = logging.getLogger(__name__) From 02853cc58e4561862fd6b7f50fbb2f624a206029 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Sat, 12 Jul 2025 18:21:11 +0000 Subject: [PATCH 4/5] Ignore ruff warning --- .../python/tools/transformers/models/whisper/benchmark.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index ebac2ff9f82dc..a04ad9ea17266 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -4,7 +4,7 @@ # license information. # -------------------------------------------------------------------------- -import argparse +import argparse # noqa: I001 import ast import datetime import gc @@ -13,7 +13,6 @@ import sys import time -import onnxruntime as ort import numpy as np import psutil import torch @@ -25,6 +24,8 @@ from tqdm import trange from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor +import onnxruntime as ort + logger = logging.getLogger(__name__) From fb90ab0ef6b08251eef1f3083d84a570cde4c3fe Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Sat, 12 Jul 2025 18:47:03 +0000 Subject: [PATCH 5/5] Remove noqa for unsorted import block error --- .../python/tools/transformers/models/whisper/benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index a04ad9ea17266..a111db1edc257 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -4,7 +4,7 @@ # license information. # -------------------------------------------------------------------------- -import argparse # noqa: I001 +import argparse import ast import datetime import gc