From 888f72910edb43ae3aa4d88fb86acd3e983399b9 Mon Sep 17 00:00:00 2001
From: Manjunath Kudlur
Date: Wed, 23 Oct 2024 16:20:58 -0700
Subject: [PATCH] ONNX converter and onnxruntime based transcriber

---
 moonshine/tools/convert_to_onnx.py | 74 ++++++++++++++++++++++++++++++
 moonshine/tools/onnx_model.py      | 50 ++++++++++++++++++++
 2 files changed, 124 insertions(+)
 create mode 100644 moonshine/tools/convert_to_onnx.py
 create mode 100644 moonshine/tools/onnx_model.py

diff --git a/moonshine/tools/convert_to_onnx.py b/moonshine/tools/convert_to_onnx.py
new file mode 100644
index 0000000..d4c2182
--- /dev/null
+++ b/moonshine/tools/convert_to_onnx.py
@@ -0,0 +1,74 @@
+import sys
+import keras
+import moonshine
+from pathlib import Path
+
+
+def convert_and_store(model, input_signature, output_file):
+    from tf2onnx.convert import from_keras
+    import onnx
+
+    onnx_model, external_storage_dict = from_keras(
+        model, input_signature=input_signature
+    )
+    assert external_storage_dict is None, f"External storage for onnx not supported"
+    onnx.save_model(onnx_model, output_file)
+
+
+def main():
+    assert (
+        len(sys.argv) == 3
+    ), "Usage: convert_to_onnx.py <model name> <output directory>"
+    assert (
+        keras.config.backend() == "tensorflow"
+    ), "Should be run with the tensorflow backend"
+
+    import tensorflow as tf
+
+    model_name = sys.argv[1]
+    model = moonshine.load_model(model_name)
+    output_dir = sys.argv[2]
+    Path(output_dir).mkdir(parents=True, exist_ok=True)
+
+    convert_and_store(
+        model.preprocessor.preprocess,
+        input_signature=[tf.TensorSpec([None, None], dtype=tf.float32)],
+        output_file=f"{output_dir}/preprocess.onnx",
+    )
+
+    seq_len_spec = tf.TensorSpec([1], dtype=tf.int32)
+
+    convert_and_store(
+        model.encoder.encoder,
+        input_signature=[
+            tf.TensorSpec([None, None, model.dim], dtype=tf.float32),
+            seq_len_spec,
+        ],
+        output_file=f"{output_dir}/encode.onnx",
+    )
+
+    input_spec = tf.TensorSpec([None, None], dtype=tf.int32)
+    context_spec = tf.TensorSpec([None, None, model.dim], dtype=tf.float32)
+    cache_spec = [
+        tf.TensorSpec(
+            [None, None, model.n_head, model.inner_dim // model.n_head],
+            dtype=tf.float32,
+        )
+        for _ in range(model.dec_n_layers * 4)
+    ]
+
+    convert_and_store(
+        model.decoder.uncached_call,
+        input_signature=[input_spec, context_spec, seq_len_spec],
+        output_file=f"{output_dir}/uncached_decode.onnx",
+    )
+
+    convert_and_store(
+        model.decoder.cached_call,
+        input_signature=[input_spec, context_spec, seq_len_spec] + cache_spec,
+        output_file=f"{output_dir}/cached_decode.onnx",
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/moonshine/tools/onnx_model.py b/moonshine/tools/onnx_model.py
new file mode 100644
index 0000000..2f82b0c
--- /dev/null
+++ b/moonshine/tools/onnx_model.py
@@ -0,0 +1,50 @@
+import onnxruntime
+import moonshine
+
+
+class MoonshineOnnxModel(object):
+    def __init__(self, models_dir):
+        self.preprocess = onnxruntime.InferenceSession(f"{models_dir}/preprocess.onnx")
+        self.encode = onnxruntime.InferenceSession(f"{models_dir}/encode.onnx")
+        self.uncached_decode = onnxruntime.InferenceSession(
+            f"{models_dir}/uncached_decode.onnx"
+        )
+        self.cached_decode = onnxruntime.InferenceSession(
+            f"{models_dir}/cached_decode.onnx"
+        )
+        self.tokenizer = moonshine.load_tokenizer()
+
+    def generate(self, audio, max_len=None):
+        audio = moonshine.load_audio(audio, return_numpy=True)
+        if max_len is None:
+            # max 6 tokens per second of audio
+            max_len = int((audio.shape[-1] / 16_000) * 6)
+        preprocessed = self.preprocess.run([], dict(args_0=audio))[0]
+        seq_len = [preprocessed.shape[-2]]
+
+        context = self.encode.run([], dict(args_0=preprocessed, args_1=seq_len))[0]
+        inputs = [[1]]
+        seq_len = [1]
+
+        tokens = [1]
+        logits, *cache = self.uncached_decode.run(
+            [], dict(args_0=inputs, args_1=context, args_2=seq_len)
+        )
+        for i in range(max_len):
+            next_token = logits.squeeze().argmax()
+            tokens.extend([next_token])
+            if next_token == 2:
+                break
+
+            seq_len[0] += 1
+            inputs = [[next_token]]
+            logits, *cache = self.cached_decode.run(
+                [],
+                dict(
+                    args_0=inputs,
+                    args_1=context,
+                    args_2=seq_len,
+                    **{f"args_{i+3}": x for i, x in enumerate(cache)},
+                ),
+            )
+        return [tokens]
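
Usage sketch (not part of the patch): the converter must run under the TensorFlow Keras backend and writes four ONNX graphs (preprocess, encode, uncached_decode, cached_decode) into the output directory; MoonshineOnnxModel then loads them with onnxruntime and greedily decodes token ids. The model name "moonshine/base", the output directory "onnx_models", the input file "example.wav", and decoding via the tokenizer's decode_batch method are illustrative assumptions, not part of this change.

    # Export the graphs first, e.g.:
    #   KERAS_BACKEND=tensorflow python moonshine/tools/convert_to_onnx.py moonshine/base onnx_models
    # Then transcribe with the onnxruntime-based model:
    from onnx_model import MoonshineOnnxModel

    model = MoonshineOnnxModel("onnx_models")   # directory written by the converter (assumed name)
    tokens = model.generate("example.wav")      # any input moonshine.load_audio accepts
    print(model.tokenizer.decode_batch(tokens)) # token ids -> text (assumes a `tokenizers` Tokenizer)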