diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md
index 598eeea8d2e49..9056ac07cc286 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/README.md
+++ b/onnxruntime/python/tools/transformers/models/whisper/README.md
@@ -75,6 +75,24 @@ $ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --o
 $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision fp16 --provider cuda --use_gpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk
 ```

+Export + Quantize for INT8 CUDA
+```
+# From source:
+$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision int8 --quantize_symmetric --provider cuda --use_gpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk
+
+# From wheel:
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision int8 --quantize_symmetric --provider cuda --use_gpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk
+```
+
+Export + Quantize for INT8 CPU
+```
+# From source:
+$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision int8 --quantize_symmetric --provider cpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk
+
+# From wheel:
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision int8 --quantize_symmetric --provider cpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk
+```
+
 ## Exporting Whisper with Beam Search

 There are several ways to export Whisper with beam search.
@@ -143,13 +161,22 @@ $ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --o
 $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda
 ```

-Export + Quantize for INT8
+Export + Quantize for INT8 CUDA
+```
+# From source:
+$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_symmetric --use_gpu --provider cuda
+
+# From wheel:
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_symmetric --use_gpu --provider cuda
+```
+
+Export + Quantize for INT8 CPU
 ```
 # From source:
-$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_embedding_layer
+$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_symmetric --provider cpu

 # From wheel:
-$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_embedding_layer
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_symmetric --provider cpu
 ```

 Note: INT8 CPU is not compatible with `--output_cross_qk`.
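For reference, a minimal sketch of loading one of the exported models under ONNX Runtime to confirm the export/quantization produced something usable; the output file name below is hypothetical and depends on the export flags used:

```python
# Illustrative only: the model path is an assumption based on the --output whisper-turbo
# directory used in the README commands above; adjust to your actual export layout.
import onnxruntime as ort

session = ort.InferenceSession(
    "whisper-turbo/whisper-large-v3-turbo_beamsearch.onnx",  # hypothetical file name
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
print([inp.name for inp in session.get_inputs()])  # expected model inputs
print(session.get_providers())  # which execution providers were actually enabled
```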
diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py
index a111db1edc257..88fdad01baf92 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py
@@ -449,7 +449,7 @@ def parse_args():
         type=str,
         required=True,
         default="fp32",
-        choices=["int8", "fp16", "fp32"],
+        choices=["int4", "int8", "fp16", "fp32"],
         help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
     )
@@ -579,7 +579,7 @@ def main():
     config = WhisperConfig.from_pretrained(args.model_name)
     processor = WhisperProcessor.from_pretrained(args.model_name)
     target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
-    use_fp16 = args.precision == "fp16"
+    use_fp16 = args.precision == "fp16" or (args.precision in {"int8", "int4"} and args.device != "cpu")
     setattr(args, "processor", processor)  # noqa: B010
     setattr(args, "target_device", target_device)  # noqa: B010
diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py
index d2eb0d5259254..95d4b60fead99 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py
@@ -97,7 +97,7 @@ def get_args():
         "--precision",
         type=str,
         required=True,
-        choices=["int8", "fp16", "fp32"],
+        choices=["int4", "int8", "fp16", "fp32"],
         help="Precision to run model",
     )
diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
index e092285d57358..38fbd73e9c119 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
@@ -8,13 +8,18 @@
 import logging
 import os

+import onnx
 import torch
 from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
 from whisper_chain import chain_model
 from whisper_encoder import WhisperEncoder
 from whisper_helper import PRETRAINED_WHISPER_MODELS, WhisperHelper

-from onnxruntime import quantization
+from onnxruntime.quantization.matmul_nbits_quantizer import (
+    KQuantWeightOnlyQuantConfig,
+    MatMulNBitsQuantizer,
+    QuantFormat,
+)

 logger = logging.getLogger("")
@@ -94,8 +99,8 @@ def parse_arguments(argv=None):
         required=False,
         type=Precision,
         default=Precision.FLOAT32,
-        choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8],
-        help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization",
+        choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8, Precision.INT4],
+        help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8/int4 for quantization",
     )

     conversion_args.add_argument(
@@ -289,28 +294,20 @@ def parse_arguments(argv=None):
     ###################################

     quant_args.add_argument(
-        "--quantize_embedding_layer",
-        required=False,
-        action="store_true",
-        help="Quantize MatMul, GEMM, and Gather.",
-    )
-    quant_args.set_defaults(quantize_embedding_layer=False)
-
-    quant_args.add_argument(
-        "--quantize_per_channel",
+        "--accuracy_level",
+        default=0,
         required=False,
-        action="store_true",
-        help="Quantize weights per each channel.",
+        type=int,
+        help="Accuracy level of the 4-bit quantized MatMul computation.",
     )
-    quant_args.set_defaults(quantize_per_channel=False)

     quant_args.add_argument(
-        "--quantize_reduce_range",
+        "--quantize_symmetric",
         required=False,
         action="store_true",
-        help="Quantize weights with 7 bits.",
+        help="Quantize weights symmetrically",
     )
-    quant_args.set_defaults(quantize_reduce_range=False)
+    quant_args.set_defaults(quantize_symmetric=False)

     args = parser.parse_args(argv)
@@ -323,6 +320,22 @@ def parse_arguments(argv=None):
     return args


+# quant_method is reserved for mixed precision in future
+def make_quant_algo_config(precision, quant_method: str, matmul_nodes=None):
+    customized_weight_config = {}
+    quant_algo_config = None
+
+    # need to use k_quant for int8
+    if precision == Precision.INT8:
+        for node_name in matmul_nodes:
+            customized_weight_config[node_name] = {"bits": 8}
+        quant_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
+    else:
+        quant_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
+
+    return quant_algo_config
+
+
 def export_onnx_models(
     model_name_or_path,
     model_impl,
@@ -340,19 +353,21 @@ def export_onnx_models(
     output_qk: bool = False,
     overwrite: bool = False,
     use_int32_inputs: bool = True,
-    quantize_embedding_layer: bool = False,
-    quantize_per_channel: bool = False,
-    quantize_reduce_range: bool = False,
+    accuracy_level: int = 0,
+    quantize_symmetric: bool = False,
     provider: str = "cpu",
 ):
     device = torch.device("cuda" if use_gpu else "cpu")

+    if not use_gpu:
+        accuracy_level = 4  # change to 4 for CPU EP
+    use_fp16_inputs = precision == Precision.FLOAT16 or (precision in (Precision.INT8, Precision.INT4) and use_gpu)
     models = WhisperHelper.load_model(
         model_name_or_path,
         model_impl,
         cache_dir,
         device,
-        torch.float16 if precision == Precision.FLOAT16 else torch.float32,
+        torch.float16 if use_fp16_inputs else torch.float32,
         merge_encoder_and_decoder_init,
         no_beam_search_op,
         output_qk,
@@ -384,7 +399,7 @@ def export_onnx_models(
                 PROVIDERS[provider],
                 verbose,
                 use_external_data_format,
-                use_fp16_inputs=(precision == Precision.FLOAT16),
+                use_fp16_inputs=use_fp16_inputs,
                 use_int32_inputs=use_int32_inputs,
                 use_encoder_hidden_states=(name == "decoder_init"),
                 use_kv_cache_inputs=(name == "decoder"),
@@ -430,27 +445,43 @@ def export_onnx_models(
                 model.verify_onnx(
                     onnx_path,
                     PROVIDERS[provider],
-                    use_fp16_inputs=(precision == Precision.FLOAT16),
+                    use_fp16_inputs=use_fp16_inputs,
                 )
             else:
                 model.verify_onnx(
                     onnx_path,
                     PROVIDERS[provider],
-                    use_fp16_inputs=(precision == Precision.FLOAT16),
+                    use_fp16_inputs=use_fp16_inputs,
                     use_int32_inputs=use_int32_inputs,
                 )

-        if precision == Precision.INT8:
-            quantization.quantize_dynamic(
-                onnx_path,
+        if precision in (Precision.INT8, Precision.INT4):
+            onnx_model = onnx.load(onnx_path, load_external_data=True)
+            matmul_nodes = [node.name for node in onnx_model.graph.node if node.op_type == "MatMul"]
+            quant_algo_config = make_quant_algo_config(precision, "k_quant", matmul_nodes)
+
+            quant = MatMulNBitsQuantizer(
+                model=onnx_model,
+                block_size=32,
+                is_symmetric=quantize_symmetric,
+                accuracy_level=accuracy_level,
+                quant_format=QuantFormat.QOperator,
+                op_types_to_quantize=("MatMul",),
+                algo_config=quant_algo_config,
+            )
+            quant.process()
+            if os.path.exists(output_path):
+                os.remove(output_path)
+            if os.path.exists(output_path + ".data"):
+                os.remove(output_path + ".data")
+            onnx.save_model(
+                quant.model.model,
                 output_path,
-                op_types_to_quantize=(
-                    ["MatMul", "Gemm", "Gather"] if quantize_embedding_layer else ["MatMul", "Gemm"]
-                ),
-                use_external_data_format=use_external_data_format,
-                per_channel=quantize_per_channel,
-                reduce_range=quantize_reduce_range,
-                extra_options={"MatMulConstBOnly": True},
+                save_as_external_data=True,
+                all_tensors_to_one_file=True,
+                location=os.path.basename(output_path) + ".data",
+                size_threshold=0,
+                convert_attribute=False,
             )
         else:
             logger.info(f"Skip optimizing: existing ONNX model {onnx_path}")
@@ -493,9 +524,8 @@ def main(argv=None):
         args.output_cross_qk,
         args.overwrite,
         not args.use_int64_inputs,
-        args.quantize_embedding_layer,
-        args.quantize_per_channel,
-        args.quantize_reduce_range,
+        args.accuracy_level,
+        args.quantize_symmetric,
         args.provider,
     )
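The quantization flow added above, condensed into a standalone sketch. The model paths are placeholders, and the symmetric/accuracy settings mirror what the script derives from `--quantize_symmetric`, `--accuracy_level`, and the target execution provider:

```python
# Condensed from the convert_to_onnx.py changes above: weight-only quantization of
# all MatMul nodes with MatMulNBitsQuantizer. Paths are placeholders.
import onnx
from onnxruntime.quantization.matmul_nbits_quantizer import (
    KQuantWeightOnlyQuantConfig,
    MatMulNBitsQuantizer,
    QuantFormat,
)

onnx_model = onnx.load("whisper_decoder.onnx", load_external_data=True)  # placeholder path
matmul_nodes = [node.name for node in onnx_model.graph.node if node.op_type == "MatMul"]

# For INT8, every MatMul is pinned to 8 bits via the per-node config; INT4 uses the defaults.
algo_config = KQuantWeightOnlyQuantConfig(
    customized_weight_config={name: {"bits": 8} for name in matmul_nodes}
)

quant = MatMulNBitsQuantizer(
    model=onnx_model,
    block_size=32,
    is_symmetric=True,       # corresponds to --quantize_symmetric
    accuracy_level=4,        # the script sets 4 for the CPU EP, 0 otherwise
    quant_format=QuantFormat.QOperator,
    op_types_to_quantize=("MatMul",),
    algo_config=algo_config,
)
quant.process()
onnx.save_model(
    quant.model.model,
    "whisper_decoder_int8.onnx",  # placeholder path
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="whisper_decoder_int8.onnx.data",
    size_threshold=0,
)
```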
diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt
index 37fc72cd26e07..37b23d9daabf4 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt
+++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt
@@ -1,4 +1,4 @@
-torch>=2.7.0
+torch==2.7.0
 transformers==4.52.3
 openai-whisper==20240927
 ffmpeg-python
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
index 365a69ee4ec67..c28fa06e13c76 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
@@ -54,16 +54,19 @@ def chain_model(args):
     config = WhisperConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
     tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)

+    use_fp16_inputs = args.precision == Precision.FLOAT16 or (
+        args.precision in (Precision.INT8, Precision.INT4) and args.use_gpu
+    )
     # Create inputs/outputs for WhisperBeamSearch op
-    temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature"
+    temperature_name = "temperature_fp16" if use_fp16_inputs else "temperature"
     beam_inputs = [
-        "input_features_fp16" if args.precision == Precision.FLOAT16 else "input_features",
+        "input_features_fp16" if use_fp16_inputs else "input_features",
         "max_length",
         "min_length",
         "num_beams",
         "num_return_sequences",
-        "length_penalty_fp16" if args.precision == Precision.FLOAT16 else "length_penalty",
-        "repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "repetition_penalty",
+        "length_penalty_fp16" if use_fp16_inputs else "length_penalty",
+        "repetition_penalty_fp16" if use_fp16_inputs else "repetition_penalty",
         "vocab_mask" if args.use_vocab_mask else "",
         "prefix_vocab_mask" if args.use_prefix_vocab_mask else "",
         "",  # attention mask
@@ -74,8 +77,8 @@ def chain_model(args):
         temperature_name if args.use_temperature else "",
     ]

-    sequence_scores_name = "sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores"
-    scores_name = "scores_fp16" if args.precision == Precision.FLOAT16 else "scores"
+    sequence_scores_name = "sequence_scores_fp16" if use_fp16_inputs else "sequence_scores"
+    scores_name = "scores_fp16" if use_fp16_inputs else "scores"
     beam_outputs = [
         "sequences",
         sequence_scores_name if args.output_sequence_scores else "",
@@ -85,7 +88,7 @@ def chain_model(args):
     ]

     graph_nodes = []
-    if args.precision == Precision.FLOAT16:
+    if use_fp16_inputs:
         input_features_cast_node = helper.make_node(
             "Cast",
             inputs=["input_features"],
"sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores" - scores_name = "scores_fp16" if args.precision == Precision.FLOAT16 else "scores" + sequence_scores_name = "sequence_scores_fp16" if use_fp16_inputs else "sequence_scores" + scores_name = "scores_fp16" if use_fp16_inputs else "scores" beam_outputs = [ "sequences", sequence_scores_name if args.output_sequence_scores else "", @@ -85,7 +88,7 @@ def chain_model(args): ] graph_nodes = [] - if args.precision == Precision.FLOAT16: + if use_fp16_inputs: input_features_cast_node = helper.make_node( "Cast", inputs=["input_features"], diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py index a53ddbf500ffa..f76a6e036c661 100644 --- a/onnxruntime/test/python/transformers/test_generation.py +++ b/onnxruntime/test/python/transformers/test_generation.py @@ -8,6 +8,7 @@ import os import shutil import unittest +from importlib.util import find_spec import onnx import pytest @@ -20,12 +21,16 @@ from benchmark_helper import Precision from convert_generation import main as run from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models - from models.whisper.convert_to_onnx import main as run_whisper + + if not find_spec("onnxruntime.training"): + from models.whisper.convert_to_onnx import main as run_whisper else: from onnxruntime.transformers.benchmark_helper import Precision from onnxruntime.transformers.convert_generation import main as run from onnxruntime.transformers.models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models - from onnxruntime.transformers.models.whisper.convert_to_onnx import main as run_whisper + + if not find_spec("onnxruntime.training"): + from onnxruntime.transformers.models.whisper.convert_to_onnx import main as run_whisper def has_cuda_environment(): @@ -464,7 +469,7 @@ def setUp(self): self.int8_cpu_arguments = [ "--precision", "int8", - "--quantize_embedding_layer", + "--quantize_symmetric", ] def tearDown(self): @@ -509,21 +514,33 @@ def run_configs(self, optional_arguments): if "--model_impl" not in arguments: self.run_export(arguments) + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_2bits" + ) @pytest.mark.slow def test_required_args(self): optional_args = [] self.run_configs(optional_args) + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_2bits" + ) @pytest.mark.slow def test_forced_decoder_ids(self): decoder_input_ids = ["--use_forced_decoder_ids"] self.run_configs(decoder_input_ids) + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_2bits" + ) @pytest.mark.slow def test_logits_processor(self): logits_processor = ["--use_logits_processor"] self.run_configs(logits_processor) + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_2bits" + ) @pytest.mark.slow def test_cross_qk_overall(self): cross_qk_input_args = [ @@ -540,6 +557,9 @@ def test_cross_qk_overall(self): ] self.run_configs(cross_qk_input_args + cross_qk_output_args) + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_2bits" + ) @pytest.mark.slow def test_openai_impl_whisper(self): optional_args = ["--model_impl", "openai"]