Skip to content

Conversation

@petermcaughan
Copy link
Contributor

Description

The current ONNX export of Whisper utilizes hard-coded values for token_ids when configuring the BeamSearch node. This PR removes these literals and instead takes these values straight from the WhisperConfig.

Motivation and Context

Hard-coding these values can cause some parity issues when comparing to default PyTorch behavior - this change to take from WhisperConfig resolves these.

Copy link
Contributor

@tianleiwu tianleiwu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@petermcaughan
Copy link
Contributor Author

I notice that whisper has the following configuration: https://huggingface.co/openai/whisper-large/blob/27b5bb3f092ae34d8079fdc073b8195e6815dec7/config.json#L46

And https://huggingface.co/openai/whisper-large/blob/27b5bb3f092ae34d8079fdc073b8195e6815dec7/config.json#L26

How to handle them in beam search?

I notice that whisper has the following configuration: https://huggingface.co/openai/whisper-large/blob/27b5bb3f092ae34d8079fdc073b8195e6815dec7/config.json#L46

And https://huggingface.co/openai/whisper-large/blob/27b5bb3f092ae34d8079fdc073b8195e6815dec7/config.json#L26

How to handle them in beam search?

Good catch - I see they apply these values by setting logit values like this link to HF after obtaining logits. We could replicate this by adding attributes to the BeamSearch op and doing a similar transformation before processing logits - this could make a good PR & would be relevant to all transformer models it seems.

@petermcaughan petermcaughan merged commit d0cca91 into main Apr 6, 2023
@petermcaughan petermcaughan deleted the petermca/whisper_export_config branch April 6, 2023 18:01
hanbitmyths pushed a commit that referenced this pull request Apr 19, 2023
### Description
This PR contains fusion-level and kernel-level optimizations for
[OpenAI's Whisper](https://github.com/openai/whisper).

Some of the added optimizations include:

- Pruning of duplicate/unnecessary inputs and outputs
- Fusion support for Whisper models with or without these inputs/outputs
(e.g. with these inputs/outputs if exporting with an older official
Optimum version, without these inputs/outputs if exporting with Optimum
from source)
- Attention fusions
   - For Whisper's encoder and decoder
- Modified symbolic shape inference for present output when no past
input exists (for decoder)
- Multi-head attention fusions
   - For Whisper's decoder and decoder with past
- Packed MatMul for the 3 MatMuls excluded in multi-head attention
fusion
- Attention kernel changes
   - CPU:
      - Different Q and KV sequence lengths
      - Parallel memset for large sequence lengths
- Convert broadcast add after MatMul of Q and K (add_qk) to element-wise
add
- Separate present key-value output into present key and present value
(for multi-head attention spec)
   - CUDA:
- Use memory efficient attention compute kernel with present state (for
decoder)
- Multi-head attention kernel changes
   - CPU:
- Introduction of multi-head attention CPU kernel (previously did not
exist)
- Use AddBiasReshape instead of AddBiasTranspose when sequence length =
1 (for decoder with past)
      - Different Q, K, V input shapes
      - Pass past key and past value directly as key and value
   - CUDA:
- Use memory efficient attention compute kernel with past and/or present
state (for decoder with past)

### Usage
To use the optimizations, run the ORT transformer optimizer script as
follows:
```
$ cd onnxruntime/onnxruntime/python/tools/transformers/
$ python3 optimizer.py --input <filename>.onnx --output <filename>.onnx --model_type bart --num_heads <number of attention heads, depends on the size of the whisper model used> --hidden_size <attention hidden size, depends on the size of the whisper model used> --use_external_data_format --use_multi_head_attention
```

Once optimized, here's an example of how to run Whisper with [Hugging
Face's Optimum](https://github.com/huggingface/optimum):
```
from transformers.onnx.utils import get_preprocessor
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
from optimum.pipelines import pipeline as ort_pipeline

import whisper # Installed from OpenAI's repo - setup instructions at https://github.com/openai/whisper/

directory = './whisper_opt' # Where the optimized ONNX models are located
model_name = 'openai/whisper-tiny'
device = 'cpu'

# Get pipeline
processor = get_preprocessor(model_name)
model = ORTModelForSpeechSeq2Seq.from_pretrained(
    directory,
    use_io_binding=(device == 'cuda'),
    provider='CPUExecutionProvider',
).to(device)
pipe = ort_pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    device=(-1 if device == 'cpu' else 0),
)

# Load audio file and run pipeline
audio = whisper.load_audio('tests/jfk.flac')
audio = whisper.pad_or_trim(audio)
outputs = pipe([audio])
print(outputs)
```

Note: In order to use these changes with Optimum, it is recommended to
use Optimum from source to have the following changes:
- huggingface/optimum#872
- huggingface/optimum#920

### Motivation and Context
This PR helps the following issues:
- #15100
- #15235
- huggingface/optimum#869 (work in progress)

This PR can be used with the other currently merged Whisper PRs:
- #15247
- #15339
- #15362
- #15365
- #15427

This PR uses changes from the following merged PRs:
- #14198
- #14146
- #14201
- #14928 (this introduced
the new multi-head attention spec)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants