Update docs for torch-directml 0.2.2 (#593)
* update docs for next torch-directml release
* Minor readme spacing issues

---------

Co-authored-by: Sheil Kumar <[email protected]>
Co-authored-by: Dwayne Robinson <[email protected]>
1 parent 4d65cad, commit 372a622
Showing 32 changed files with 106,715 additions and 634 deletions.
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 OpenAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,94 @@
# Speech Recognition with Whisper
This sample shows how to run OpenAI's automatic speech recognition (ASR) [Whisper model](https://github.com/openai/whisper/blob/main/README.md) with the DirectML backend.

- [About Whisper](#about-whisper)
- [Setup](#setup)
- [Run the Whisper model](#run-the-whisper-model)
- [Basic Settings](#basic-settings)
- [External Links](#external-links)
- [Model License](#model-license)


## About Whisper

The [OpenAI Whisper](https://github.com/openai/whisper/) model is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is a multitasking model that can perform multilingual speech recognition, speech translation, and language identification.

Whisper supports five model sizes, four with English-only versions and all five with multilingual versions.

| Size | Parameters | English-only model | Multilingual model | Required VRAM |
|:---------:|:----------:|:------------------:|:------------------:|:-------------:|
| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB |
| base | 74 M | `base.en` | `base` | ~1 GB |
| small | 244 M | `small.en` | `small` | ~2 GB |
| medium | 769 M | `medium.en` | `medium` | ~5 GB |
| large v3 | 1550 M | N/A | `large-v3` | ~10 GB |
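
The table above can be turned into a small lookup for choosing a model size. The following is a minimal sketch, not part of the sample: `pick_model` and `VRAM_GB` are hypothetical names, and the VRAM figures are the approximate values from the table, so treat the result as a starting point rather than a guarantee.

```python
from typing import Optional

# Approximate VRAM requirements in GB, copied from the table above.
# Ordered from the smallest to the largest model.
VRAM_GB = {
    "tiny": 1,
    "base": 1,
    "small": 2,
    "medium": 5,
    "large-v3": 10,
}


def pick_model(available_gb: float, english_only: bool = False) -> Optional[str]:
    """Return the largest model whose approximate VRAM need fits the budget."""
    best = None
    for size, need in VRAM_GB.items():
        # large-v3 has no English-only variant, so skip it in that case.
        if english_only and size == "large-v3":
            continue
        if need <= available_gb:
            best = size
    if best is None:
        return None
    return f"{best}.en" if english_only else best


print(pick_model(6))                      # medium
print(pick_model(12, english_only=True))  # medium.en
```

The returned string can be passed to `run.py` as the `--model_size` value.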

For more information on the model, please refer to the [OpenAI Whisper GitHub repo](https://github.com/openai/whisper/).


## Setup
Once you've set up `torch-directml` following our [Windows](https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-windows) and [WSL](https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-wsl) guidance, install the following requirements for running the app:


```
conda install ffmpeg
pip install -r requirements.txt
```
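
Before running the sample, it can help to confirm that the `ffmpeg` executable installed above is actually on `PATH`. A minimal sketch using only the standard library; `ffmpeg_available` is a hypothetical helper, not part of the sample, and it assumes Whisper's audio loading relies on the `ffmpeg` binary, as the upstream project does.

```python
import shutil


def ffmpeg_available(executable: str = "ffmpeg") -> bool:
    """Return True if the given executable can be found on PATH."""
    return shutil.which(executable) is not None


if not ffmpeg_available():
    print("ffmpeg not found on PATH; try `conda install ffmpeg` first.")
```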


## Run the Whisper model
Run Whisper on a sample audio file with the DirectML backend using the following command:
```bash
python run.py --input_file <audio_file> --model_size "tiny.en"
```


For example, you should see output similar to the following:
```
> python run.py --input_file test/samples_jfk.wav --model_size "tiny.en"
100%|█████████████████████████████████████| 72.1M/72.1M [00:09<00:00, 7.90MiB/s]
test/samples_jfk.wav
And so my fellow Americans ask not what your country can do for you ask what you can do for your country.
```


By default, [OpenAI Whisper](https://github.com/openai/whisper/) uses a naive implementation of scaled dot product attention. To improve performance further with DirectML's scaled dot product attention, run `run.py` with the `--use_dml_attn` flag:

```bash
python run.py --input_file <audio_file> --model_size "tiny.en" --use_dml_attn
```
Based on this flag, the `MultiHeadAttention` module in `model.py` chooses between Whisper's naive scaled dot product attention and DirectML's scaled dot product attention:
```python
if use_dml_attn:
    wv, qk = self.dml_sdp_attn(q, k, v, mask, cross_attention=cross_attention)
else:
    wv, qk = self.qkv_attention(q, k, v, mask)
```
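
For readers unfamiliar with the operation being swapped, scaled dot product attention computes `softmax(Q K^T / sqrt(d)) V`. The following is a minimal pure-Python sketch of that formula for a single head without masking. It is illustrative only: it is not the code in `model.py`, and `naive_sdp_attention` is a hypothetical name.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of scores."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def naive_sdp_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Single-head attention: softmax(q @ k^T / sqrt(d)) @ v, no mask."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q_row in q:
        # Similarity of this query with every key, scaled by 1/sqrt(d).
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        weights = softmax(scores)
        # Weighted average of the value rows.
        out.append([
            sum(w * v_row[j] for w, v_row in zip(weights, v))
            for j in range(len(v[0]))
        ])
    return out


# With identical keys every value gets equal weight, so the output is the mean of v.
q = [[1.0, 0.0]]
k = [[1.0, 1.0], [1.0, 1.0]]
v = [[0.0, 2.0], [4.0, 6.0]]
print(naive_sdp_attention(q, k, v))  # [[2.0, 4.0]]
```

A fused backend implementation produces the same mathematical result but avoids materializing the intermediate score matrix step by step, which is where the performance difference typically comes from.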

## Basic Settings

The following table lists the basic settings supported by `run.py`:

| Flag | Description | Default |
| ---------------------- | ------------------------------------------------------------ | ------- |
| `--help` | Show this help message. | - |
| `--input_file` | [Required] Path to input audio file | - |
| `--model_size` | Size of Whisper model to use. Options: [`tiny.en`, `tiny`, `base.en`, `base`, `small.en`, `small`, `medium.en`, `medium`, `large-v3`] | `tiny.en` |
| `--fp16` | Runs inference with fp16 precision. | False |
| `--use_dml_attn` | Runs inference with DirectML's scaled dot product attention implementation. | False |
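
Both boolean flags in `run.py` are declared with `action="store_true"`, which means they are off unless passed on the command line. A minimal sketch of that behavior, mirroring the argument definitions in `run.py`; `build_parser` is a hypothetical wrapper added here only to make the example self-contained.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Run Whisper on an audio file.")
    parser.add_argument("--model_size", type=str, default="tiny.en")
    parser.add_argument("--input_file", type=str, required=True)
    # store_true flags default to False and become True only when present.
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--use_dml_attn", action="store_true")
    return parser


args = build_parser().parse_args(["--input_file", "audio.wav", "--fp16"])
print(args.model_size, args.fp16, args.use_dml_attn)  # tiny.en True False
```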


## External Links
- [Whisper Base Hugging Face Repository](https://huggingface.co/openai/whisper-base.en)
- [Whisper Tiny Hugging Face Repository](https://huggingface.co/openai/whisper-tiny.en)
- [Whisper Small Hugging Face Repository](https://huggingface.co/openai/whisper-small.en)
- [Whisper Medium Hugging Face Repository](https://huggingface.co/openai/whisper-medium.en)
- [Whisper Large v3 Hugging Face Repository](https://huggingface.co/openai/whisper-large-v3)
- [Whisper GitHub Repo](https://github.com/openai/whisper)


## Model License

Whisper's code and model weights are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details.
@@ -0,0 +1,6 @@
numba
numpy
tqdm
more-itertools
tiktoken
ffmpeg-python
@@ -0,0 +1,45 @@
# Copyright (c) Microsoft Corporation.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import whisper
import torch_directml
import argparse


def main(args):
    device = torch_directml.device(torch_directml.default_device())
    model = whisper.load_model(args.model_size, device=device, use_dml_attn=args.use_dml_attn)

    # Load audio and pad/trim it to fit 30 seconds
    audio = whisper.load_audio(args.input_file)
    audio = whisper.pad_or_trim(audio)

    # large-v3 uses 128 mel bins; the other sizes use 80.
    # "large" is an alias for large-v3 in whisper._MODELS.
    n_mels = 80
    if args.model_size in ("large", "large-v3"):
        n_mels = 128

    mel = whisper.log_mel_spectrogram(audio, n_mels=n_mels).to(model.device)

    # English-only checkpoints end in ".en"; detect the language for multilingual ones.
    language = "en"
    if not args.model_size.endswith(".en"):
        _, probs = model.detect_language(mel)
        language = max(probs, key=probs.get)
        print(f"Detected language: {language}")

    options = whisper.DecodingOptions(language=language, fp16=args.fp16)
    result = whisper.decode(model, mel, options)

    print(result.text)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run Whisper model on specified audio file.')
    parser.add_argument('--model_size', type=str, default='tiny.en', help='Size of the Whisper model to use.')
    parser.add_argument('--input_file', type=str, required=True, help='Path to the input audio file.')
    parser.add_argument('--fp16', action="store_true", help='Runs inference with fp16 precision.')
    parser.add_argument('--use_dml_attn', action="store_true", help='Use DirectML attention implementation.')
    args = parser.parse_args()

    main(args)
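
The pad/trim step in `run.py` fixes the audio to a 30-second window before the mel spectrogram is computed. As a rough illustration of what that means, here is a pure-Python sketch. It assumes Whisper's usual 16 kHz sample rate (so 30 s is 480,000 samples), and `pad_or_trim_sketch` is a hypothetical stand-in, not the library's `whisper.pad_or_trim`.

```python
from typing import List

SAMPLE_RATE = 16_000                     # assumed: Whisper works on 16 kHz audio
CHUNK_SECONDS = 30                       # window length mentioned in run.py
N_SAMPLES = SAMPLE_RATE * CHUNK_SECONDS  # 480,000 samples


def pad_or_trim_sketch(samples: List[float], length: int = N_SAMPLES) -> List[float]:
    """Cut the waveform to `length` samples, or zero-pad it at the end."""
    if len(samples) >= length:
        return samples[:length]
    return samples + [0.0] * (length - len(samples))


short = pad_or_trim_sketch([0.5, -0.5], length=4)
long = pad_or_trim_sketch([0.1] * 10, length=4)
print(short)      # [0.5, -0.5, 0.0, 0.0]
print(len(long))  # 4
```

One consequence worth noting: because `run.py` decodes a single padded or trimmed window, audio beyond the first 30 seconds is not transcribed by this script.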
Binary file not shown.
@@ -0,0 +1,160 @@
import hashlib
import io
import os
import urllib.request
import warnings
from typing import List, Optional, Union

import torch
from tqdm import tqdm

from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
# from .version import __version__

_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
}

# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
}
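
The comment above describes each `_ALIGNMENT_HEADS` value as a base85-encoded boolean array of shape `(n_layers, n_heads)`. The sketch below round-trips a toy mask through that kind of encoding using only the standard library. It assumes, as in upstream Whisper, that the bytes are gzip-compressed before base85 encoding; `encode_mask` and `decode_mask` are hypothetical helpers, and the real decoding happens in the model's `set_alignment_heads`.

```python
import base64
import gzip
from typing import List


def encode_mask(mask: List[List[bool]]) -> bytes:
    """Flatten a boolean grid to one byte per cell, gzip it, then base85-encode it."""
    flat = bytes(int(cell) for row in mask for cell in row)
    return base64.b85encode(gzip.compress(flat))


def decode_mask(dump: bytes, n_layers: int, n_heads: int) -> List[List[bool]]:
    """Reverse encode_mask and reshape the result to (n_layers, n_heads)."""
    flat = gzip.decompress(base64.b85decode(dump))
    if len(flat) != n_layers * n_heads:
        raise ValueError("decoded mask does not match the requested shape")
    return [
        [bool(flat[layer * n_heads + head]) for head in range(n_heads)]
        for layer in range(n_layers)
    ]


# Toy mask with 2 layers and 3 heads; True marks an alignment head.
toy = [[False, True, False], [True, False, False]]
assert decode_mask(encode_mask(toy), n_layers=2, n_heads=3) == toy
```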


def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    os.makedirs(root, exist_ok=True)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        with open(download_target, "rb") as f:
            model_bytes = f.read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return model_bytes if in_memory else download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    # Re-read the downloaded file and verify it against the digest embedded in the URL
    with open(download_target, "rb") as f:
        model_bytes = f.read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
        )

    return model_bytes if in_memory else download_target
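
`_download` relies on the convention that the second-to-last path segment of each model URL is the expected SHA256 digest of the file. A minimal sketch of that check in isolation, using a made-up URL and payload; `expected_digest` and `matches_url_digest` are hypothetical helpers for illustration.

```python
import hashlib


def expected_digest(url: str) -> str:
    """Return the second-to-last path segment, which holds the SHA256 hex digest."""
    return url.split("/")[-2]


def matches_url_digest(data: bytes, url: str) -> bool:
    """Check downloaded bytes against the digest embedded in the URL."""
    return hashlib.sha256(data).hexdigest() == expected_digest(url)


payload = b"toy checkpoint bytes"
digest = hashlib.sha256(payload).hexdigest()
url = f"https://example.invalid/models/{digest}/toy.pt"  # made-up URL

print(matches_url_digest(payload, url))      # True
print(matches_url_digest(b"tampered", url))  # False
```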


def available_models() -> List[str]:
    """Returns the names of available models"""
    return list(_MODELS.keys())


def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
    use_dml_attn: bool = False,
) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory
    use_dml_attn: bool
        whether to use DirectML's scaled dot product attention instead of
        Whisper's naive implementation

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        if in_memory:
            with open(name, "rb") as f:
                checkpoint_file = f.read()
        else:
            checkpoint_file = name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    if in_memory:
        # checkpoint_file holds raw bytes here; mmap needs a file path, so load from a buffer
        checkpoint = torch.load(io.BytesIO(checkpoint_file), weights_only=True)
    else:
        checkpoint = torch.load(checkpoint_file, mmap=True, weights_only=True)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims, use_dml_attn=use_dml_attn)

    model.load_state_dict(checkpoint["model_state_dict"])
    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)
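
When `download_root` is not given, `load_model` stores checkpoints under `$XDG_CACHE_HOME/whisper`, falling back to `~/.cache/whisper`. The sketch below isolates that lookup; `default_download_root` is a hypothetical helper, and it takes the environment as a parameter only so the behavior is easy to try out.

```python
import os
from typing import Mapping, Optional


def default_download_root(env: Optional[Mapping[str, str]] = None) -> str:
    """Resolve the checkpoint cache directory the same way load_model does."""
    env = os.environ if env is None else env
    fallback = os.path.join(os.path.expanduser("~"), ".cache")
    return os.path.join(env.get("XDG_CACHE_HOME", fallback), "whisper")


# With XDG_CACHE_HOME set, the cache lives under that directory.
print(default_download_root({"XDG_CACHE_HOME": "/tmp/cache"}))
# Without it, the cache falls back to the user's home directory.
print(default_download_root({}))
```

Passing an explicit `download_root` to `load_model` bypasses this lookup entirely, which can be useful when the home directory is on a small or slow volume.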