ONNX converter and onnxruntime-based transcriber
keveman committed Oct 23, 2024
1 parent 2e04ed3 commit 888f729
Showing 2 changed files with 124 additions and 0 deletions.
74 changes: 74 additions & 0 deletions moonshine/tools/convert_to_onnx.py
@@ -0,0 +1,74 @@
import sys
import keras
import moonshine
from pathlib import Path


def convert_and_store(model, input_signature, output_file):
    from tf2onnx.convert import from_keras
    import onnx

    # Trace the Keras callable with tf2onnx and serialize the resulting graph.
    onnx_model, external_storage_dict = from_keras(
        model, input_signature=input_signature
    )
    assert external_storage_dict is None, "External storage for onnx not supported"
    onnx.save_model(onnx_model, output_file)


def main():
    assert (
        len(sys.argv) == 3
    ), "Usage: convert_to_onnx.py <moonshine model name> <output directory name>"
    assert (
        keras.config.backend() == "tensorflow"
    ), "Should be run with the tensorflow backend"

    import tensorflow as tf

    model_name = sys.argv[1]
    model = moonshine.load_model(model_name)
    output_dir = sys.argv[2]
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Preprocessor: raw audio samples -> feature sequence.
    convert_and_store(
        model.preprocessor.preprocess,
        input_signature=[tf.TensorSpec([None, None], dtype=tf.float32)],
        output_file=f"{output_dir}/preprocess.onnx",
    )

    seq_len_spec = tf.TensorSpec([1], dtype=tf.int32)

    # Encoder: feature sequence plus its length -> context vectors.
    convert_and_store(
        model.encoder.encoder,
        input_signature=[
            tf.TensorSpec([None, None, model.dim], dtype=tf.float32),
            seq_len_spec,
        ],
        output_file=f"{output_dir}/encode.onnx",
    )

    input_spec = tf.TensorSpec([None, None], dtype=tf.int32)
    context_spec = tf.TensorSpec([None, None, model.dim], dtype=tf.float32)
    # Four key/value cache tensors per decoder layer.
    cache_spec = [
        tf.TensorSpec(
            [None, None, model.n_head, model.inner_dim // model.n_head],
            dtype=tf.float32,
        )
        for _ in range(model.dec_n_layers * 4)
    ]

    # Decoder without a cache, used for the first decoding step.
    convert_and_store(
        model.decoder.uncached_call,
        input_signature=[input_spec, context_spec, seq_len_spec],
        output_file=f"{output_dir}/uncached_decode.onnx",
    )

    # Decoder with a key/value cache, used for every subsequent step.
    convert_and_store(
        model.decoder.cached_call,
        input_signature=[input_spec, context_spec, seq_len_spec] + cache_spec,
        output_file=f"{output_dir}/cached_decode.onnx",
    )


if __name__ == "__main__":
    main()
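
The script writes four graphs (preprocess.onnx, encode.onnx, uncached_decode.onnx, and cached_decode.onnx) into the output directory, invoked as the usage string above describes. A quick way to sanity-check the export is to reopen each file with onnxruntime and list its input names and shapes. The snippet below is a minimal sketch, not part of this commit; the onnx_models directory name is a placeholder for whatever output directory was passed to convert_to_onnx.py.

import onnxruntime

# Hypothetical output directory; use the path passed to convert_to_onnx.py.
models_dir = "onnx_models"

for name in ["preprocess", "encode", "uncached_decode", "cached_decode"]:
    session = onnxruntime.InferenceSession(f"{models_dir}/{name}.onnx")
    print(name)
    for inp in session.get_inputs():
        # Dynamic axes appear as symbolic names rather than fixed sizes.
        print(f"  {inp.name}: {inp.shape} ({inp.type})")
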
50 changes: 50 additions & 0 deletions moonshine/tools/onnx_model.py
@@ -0,0 +1,50 @@
import onnxruntime
import moonshine


class MoonshineOnnxModel(object):
    def __init__(self, models_dir):
        # Load the four graphs exported by convert_to_onnx.py.
        self.preprocess = onnxruntime.InferenceSession(f"{models_dir}/preprocess.onnx")
        self.encode = onnxruntime.InferenceSession(f"{models_dir}/encode.onnx")
        self.uncached_decode = onnxruntime.InferenceSession(
            f"{models_dir}/uncached_decode.onnx"
        )
        self.cached_decode = onnxruntime.InferenceSession(
            f"{models_dir}/cached_decode.onnx"
        )
        self.tokenizer = moonshine.load_tokenizer()

    def generate(self, audio, max_len=None):
        audio = moonshine.load_audio(audio, return_numpy=True)
        if max_len is None:
            # Allow at most 6 tokens per second of 16 kHz audio.
            max_len = int((audio.shape[-1] / 16_000) * 6)
        preprocessed = self.preprocess.run([], dict(args_0=audio))[0]
        seq_len = [preprocessed.shape[-2]]

        context = self.encode.run([], dict(args_0=preprocessed, args_1=seq_len))[0]
        inputs = [[1]]
        seq_len = [1]

        # Decoding starts from token id 1; the first step runs without a cache
        # and returns the initial key/value cache tensors alongside the logits.
        tokens = [1]
        logits, *cache = self.uncached_decode.run(
            [], dict(args_0=inputs, args_1=context, args_2=seq_len)
        )
        for i in range(max_len):
            # Greedy decoding: take the highest-scoring token at each step.
            next_token = logits.squeeze().argmax()
            tokens.extend([next_token])
            if next_token == 2:
                # Token id 2 marks the end of the sequence.
                break

            seq_len[0] += 1
            inputs = [[next_token]]
            logits, *cache = self.cached_decode.run(
                [],
                dict(
                    args_0=inputs,
                    args_1=context,
                    args_2=seq_len,
                    **{f"args_{i+3}": x for i, x in enumerate(cache)},
                ),
            )
        return [tokens]
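
Since generate returns raw token ids, callers decode them with the tokenizer held on the model. Below is a minimal usage sketch, not part of this commit: onnx_models is assumed to be a directory produced by convert_to_onnx.py, speech.wav is a placeholder for any audio file that moonshine.load_audio accepts (the model assumes 16 kHz samples), and the decode_batch call assumes the loaded tokenizer behaves like a Hugging Face tokenizers object; adjust the decode call if the tokenizer API differs.

from onnx_model import MoonshineOnnxModel

# "onnx_models" is hypothetical: the directory written by convert_to_onnx.py.
model = MoonshineOnnxModel("onnx_models")

# "speech.wav" is a placeholder input audio file.
tokens = model.generate("speech.wav")

# Assumes a decode_batch method on the tokenizer; swap in the equivalent call otherwise.
print(model.tokenizer.decode_batch(tokens))
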
