33 changes: 30 additions & 3 deletions onnxruntime/python/tools/transformers/models/whisper/README.md
@@ -75,6 +75,24 @@ $ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --o
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision fp16 --provider cuda --use_gpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk
```

Export + Quantize for INT8 CUDA
```
# From source:
$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision int8 --quantize_symmetric --provider cuda --use_gpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk

# From wheel:
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision int8 --quantize_symmetric --provider cuda --use_gpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk
```

Export + Quantize for INT8 CPU
```
# From source:
$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision int8 --quantize_symmetric --provider cpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk

# From wheel:
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision int8 --quantize_symmetric --provider cpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk
```
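
To spot-check an exported INT8/INT4 model, one quick check is that the quantizer replaced `MatMul` nodes with `MatMulNBits`. A minimal sketch (the model path below is a placeholder for one of the files written to `--output`):
```
# Minimal sanity check; substitute the path of an exported ONNX file.
import collections

import onnx

# load_external_data=False keeps this cheap: only the graph structure is read.
model = onnx.load("whisper-turbo/model.onnx", load_external_data=False)
counts = collections.Counter(node.op_type for node in model.graph.node)
print("MatMul:", counts["MatMul"], "MatMulNBits:", counts["MatMulNBits"])
```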

## Exporting Whisper with Beam Search

There are several ways to export Whisper with beam search.
@@ -143,13 +161,22 @@ $ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --o
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda
```

Export + Quantize for INT8
Export + Quantize for INT8 CUDA
```
# From source:
$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_symmetric --use_gpu --provider cuda

# From wheel:
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_symmetric --use_gpu --provider cuda
```

Export + Quantize for INT8 CPU
```
# From source:
$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_embedding_layer
$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_symmetric --provider cpu

# From wheel:
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_embedding_layer
$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_symmetric --provider cpu
```

Note: INT8 CPU is not compatible with `--output_cross_qk`.
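
For reference, the INT8/INT4 path applies weight-only quantization with `MatMulNBitsQuantizer` rather than dynamic quantization. A rough standalone sketch of that step follows (paths are placeholders; see `convert_to_onnx.py` for the exact call):
```
# Rough sketch of the weight-only quantization applied for --precision int8;
# paths are placeholders. For --precision int4, drop the per-node 8-bit override.
import onnx
from onnxruntime.quantization.matmul_nbits_quantizer import (
    KQuantWeightOnlyQuantConfig,
    MatMulNBitsQuantizer,
    QuantFormat,
)

onnx_path = "whisper-turbo/model.onnx"  # placeholder: an exported Whisper ONNX file
model = onnx.load(onnx_path, load_external_data=True)

# k_quant with an explicit 8-bit override on every MatMul node.
matmul_nodes = [n.name for n in model.graph.node if n.op_type == "MatMul"]
algo_config = KQuantWeightOnlyQuantConfig(
    customized_weight_config={name: {"bits": 8} for name in matmul_nodes}
)

quantizer = MatMulNBitsQuantizer(
    model=model,
    block_size=32,
    is_symmetric=True,   # --quantize_symmetric
    accuracy_level=4,    # 4 for the CPU EP; 0 (default) for CUDA
    quant_format=QuantFormat.QOperator,
    op_types_to_quantize=("MatMul",),
    algo_config=algo_config,
)
quantizer.process()
onnx.save_model(
    quantizer.model.model,
    "whisper-turbo/model_int8.onnx",  # placeholder output path
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="model_int8.onnx.data",
    size_threshold=0,
)
```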
@@ -449,7 +449,7 @@ def parse_args():
type=str,
required=True,
default="fp32",
choices=["int8", "fp16", "fp32"],
choices=["int4", "int8", "fp16", "fp32"],
help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
)

@@ -579,7 +579,7 @@ def main():
config = WhisperConfig.from_pretrained(args.model_name)
processor = WhisperProcessor.from_pretrained(args.model_name)
target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
use_fp16 = args.precision == "fp16"
use_fp16 = args.precision == "fp16" or (args.precision in {"int8", "int4"} and args.device != "cpu")

setattr(args, "processor", processor) # noqa: B010
setattr(args, "target_device", target_device) # noqa: B010
@@ -97,7 +97,7 @@ def get_args():
"--precision",
type=str,
required=True,
choices=["int8", "fp16", "fp32"],
choices=["int4", "int8", "fp16", "fp32"],
help="Precision to run model",
)

@@ -8,13 +8,18 @@
import logging
import os

import onnx
import torch
from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
from whisper_chain import chain_model
from whisper_encoder import WhisperEncoder
from whisper_helper import PRETRAINED_WHISPER_MODELS, WhisperHelper

from onnxruntime import quantization
from onnxruntime.quantization.matmul_nbits_quantizer import (
KQuantWeightOnlyQuantConfig,
MatMulNBitsQuantizer,
QuantFormat,
)

logger = logging.getLogger("")

@@ -94,8 +99,8 @@ def parse_arguments(argv=None):
required=False,
type=Precision,
default=Precision.FLOAT32,
choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8],
help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization",
choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8, Precision.INT4],
help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8/int4 for quantization",
)

conversion_args.add_argument(
@@ -289,28 +294,20 @@ parse_arguments(argv=None):
###################################

quant_args.add_argument(
"--quantize_embedding_layer",
required=False,
action="store_true",
help="Quantize MatMul, GEMM, and Gather.",
)
quant_args.set_defaults(quantize_embedding_layer=False)

quant_args.add_argument(
"--quantize_per_channel",
"--accuracy_level",
default=0,
required=False,
action="store_true",
help="Quantize weights per each channel.",
type=int,
help="Accuracy level of the 4-bit quantized MatMul computation.",
)
quant_args.set_defaults(quantize_per_channel=False)

quant_args.add_argument(
"--quantize_reduce_range",
"--quantize_symmetric",
required=False,
action="store_true",
help="Quantize weights with 7 bits.",
help="Quantize weights symmetrically",
)
quant_args.set_defaults(quantize_reduce_range=False)
quant_args.set_defaults(quantize_symmetric=False)

args = parser.parse_args(argv)

@@ -323,6 +320,22 @@ parse_arguments(argv=None):
return args


# quant_method is reserved for mixed precision in the future
def make_quant_algo_config(precision, quant_method: str, matmul_nodes=None):
customized_weight_config = {}
quant_algo_config = None

# need to use k_quant for int8
if precision == Precision.INT8:
for node_name in matmul_nodes:
customized_weight_config[node_name] = {"bits": 8}
quant_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
else:
quant_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)

return quant_algo_config
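# For illustration (toy node names, not from a real model): with Precision.INT8
# the helper above pins every listed MatMul node to 8 bits, e.g.
#   make_quant_algo_config(Precision.INT8, "k_quant", ["MatMul_0"])
#     -> customized_weight_config == {"MatMul_0": {"bits": 8}}
# whereas any other precision (INT4) gets the default 4-bit k_quant config with
# no per-node overrides.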


def export_onnx_models(
model_name_or_path,
model_impl,
@@ -340,19 +353,21 @@
output_qk: bool = False,
overwrite: bool = False,
use_int32_inputs: bool = True,
quantize_embedding_layer: bool = False,
quantize_per_channel: bool = False,
quantize_reduce_range: bool = False,
accuracy_level: int = 0,
quantize_symmetric: bool = False,
provider: str = "cpu",
):
device = torch.device("cuda" if use_gpu else "cpu")
if not use_gpu:
accuracy_level = 4 # change to 4 for CPU EP
use_fp16_inputs = precision == Precision.FLOAT16 or (precision in (Precision.INT8, Precision.INT4) and use_gpu)

models = WhisperHelper.load_model(
model_name_or_path,
model_impl,
cache_dir,
device,
torch.float16 if precision == Precision.FLOAT16 else torch.float32,
torch.float16 if use_fp16_inputs else torch.float32,
merge_encoder_and_decoder_init,
no_beam_search_op,
output_qk,
@@ -384,7 +399,7 @@ def export_onnx_models(
PROVIDERS[provider],
verbose,
use_external_data_format,
use_fp16_inputs=(precision == Precision.FLOAT16),
use_fp16_inputs=use_fp16_inputs,
use_int32_inputs=use_int32_inputs,
use_encoder_hidden_states=(name == "decoder_init"),
use_kv_cache_inputs=(name == "decoder"),
@@ -430,27 +445,43 @@ def export_onnx_models(
model.verify_onnx(
onnx_path,
PROVIDERS[provider],
use_fp16_inputs=(precision == Precision.FLOAT16),
use_fp16_inputs=use_fp16_inputs,
)
else:
model.verify_onnx(
onnx_path,
PROVIDERS[provider],
use_fp16_inputs=(precision == Precision.FLOAT16),
use_fp16_inputs=use_fp16_inputs,
use_int32_inputs=use_int32_inputs,
)

if precision == Precision.INT8:
quantization.quantize_dynamic(
onnx_path,
if precision in (Precision.INT8, Precision.INT4):
onnx_model = onnx.load(onnx_path, load_external_data=True)
matmul_nodes = [node.name for node in onnx_model.graph.node if node.op_type == "MatMul"]
quant_algo_config = make_quant_algo_config(precision, "k_quant", matmul_nodes)

quant = MatMulNBitsQuantizer(
model=onnx_model,
block_size=32,
is_symmetric=quantize_symmetric,
accuracy_level=accuracy_level,
quant_format=QuantFormat.QOperator,
op_types_to_quantize=("MatMul",),
algo_config=quant_algo_config,
)
quant.process()
if os.path.exists(output_path):
os.remove(output_path)
if os.path.exists(output_path + ".data"):
os.remove(output_path + ".data")
onnx.save_model(
quant.model.model,
output_path,
op_types_to_quantize=(
["MatMul", "Gemm", "Gather"] if quantize_embedding_layer else ["MatMul", "Gemm"]
),
use_external_data_format=use_external_data_format,
per_channel=quantize_per_channel,
reduce_range=quantize_reduce_range,
extra_options={"MatMulConstBOnly": True},
save_as_external_data=True,
all_tensors_to_one_file=True,
location=os.path.basename(output_path) + ".data",
size_threshold=0,
convert_attribute=False,
)
else:
logger.info(f"Skip optimizing: existing ONNX model {onnx_path}")
@@ -493,9 +524,8 @@ def main(argv=None):
args.output_cross_qk,
args.overwrite,
not args.use_int64_inputs,
args.quantize_embedding_layer,
args.quantize_per_channel,
args.quantize_reduce_range,
args.accuracy_level,
args.quantize_symmetric,
args.provider,
)

@@ -1,4 +1,4 @@
torch>=2.7.0
torch==2.7.0
transformers==4.52.3
openai-whisper==20240927
ffmpeg-python
@@ -54,16 +54,19 @@ def chain_model(args):
config = WhisperConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)

use_fp16_inputs = args.precision == Precision.FLOAT16 or (
args.precision in (Precision.INT8, Precision.INT4) and args.use_gpu
)
# Create inputs/outputs for WhisperBeamSearch op
temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature"
temperature_name = "temperature_fp16" if use_fp16_inputs else "temperature"
beam_inputs = [
"input_features_fp16" if args.precision == Precision.FLOAT16 else "input_features",
"input_features_fp16" if use_fp16_inputs else "input_features",
"max_length",
"min_length",
"num_beams",
"num_return_sequences",
"length_penalty_fp16" if args.precision == Precision.FLOAT16 else "length_penalty",
"repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "repetition_penalty",
"length_penalty_fp16" if use_fp16_inputs else "length_penalty",
"repetition_penalty_fp16" if use_fp16_inputs else "repetition_penalty",
"vocab_mask" if args.use_vocab_mask else "",
"prefix_vocab_mask" if args.use_prefix_vocab_mask else "",
"", # attention mask
@@ -74,8 +77,8 @@
temperature_name if args.use_temperature else "",
]

sequence_scores_name = "sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores"
scores_name = "scores_fp16" if args.precision == Precision.FLOAT16 else "scores"
sequence_scores_name = "sequence_scores_fp16" if use_fp16_inputs else "sequence_scores"
scores_name = "scores_fp16" if use_fp16_inputs else "scores"
beam_outputs = [
"sequences",
sequence_scores_name if args.output_sequence_scores else "",
@@ -85,7 +88,7 @@
]

graph_nodes = []
if args.precision == Precision.FLOAT16:
if use_fp16_inputs:
input_features_cast_node = helper.make_node(
"Cast",
inputs=["input_features"],