diff --git a/openwakeword/model.py b/openwakeword/model.py index 6029963..621987a 100755 --- a/openwakeword/model.py +++ b/openwakeword/model.py @@ -119,16 +119,26 @@ def tflite_predict(tflite_interpreter, input_index, output_index, x): return tflite_interpreter.get_tensor(output_index)[None, ] except ImportError: - logging.warning("Tried to import the tflite runtime, but it was not found. " - "Trying to switching to onnxruntime instead, if appropriate models are available.") - if wakeword_models != [] and all(['.onnx' in i for i in wakeword_models]): - inference_framework = "onnx" - elif wakeword_models != [] and all([os.path.exists(i.replace('.tflite', '.onnx')) for i in wakeword_models]): - inference_framework = "onnx" - wakeword_models = [i.replace('.tflite', '.onnx') for i in wakeword_models] + from importlib.util import find_spec + if find_spec("tensorflow") is not None and find_spec("tflite_runtime") is None: + logging.warning("Tried to import the tflite runtime, but it was not found. Using tensorflow instead.") + from tensorflow.lite.python import interpreter as tflite + + def tflite_predict(tflite_interpreter, input_index, output_index, x): + tflite_interpreter.set_tensor(input_index, x) + tflite_interpreter.invoke() + return tflite_interpreter.get_tensor(output_index)[None, ] else: - raise ValueError("Tried to import the tflite runtime for provided tflite models, but it was not found. " - "Please install it using `pip install tflite-runtime`") + logging.warning("Tried to import the tflite runtime, but it was not found. " + "Trying to switching to onnxruntime instead, if appropriate models are available.") + if wakeword_models != [] and all(['.onnx' in i for i in wakeword_models]): + inference_framework = "onnx" + elif wakeword_models != [] and all([os.path.exists(i.replace('.tflite', '.onnx')) for i in wakeword_models]): + inference_framework = "onnx" + wakeword_models = [i.replace('.tflite', '.onnx') for i in wakeword_models] + else: + raise ValueError("Tried to import the tflite runtime for provided tflite models, but it was not found. " + "Please install it using `pip install tflite-runtime`") if inference_framework == "onnx": try: diff --git a/openwakeword/utils.py b/openwakeword/utils.py index 4964706..e45c86c 100644 --- a/openwakeword/utils.py +++ b/openwakeword/utils.py @@ -96,8 +96,14 @@ def __init__(self, try: import tflite_runtime.interpreter as tflite except ImportError: - raise ValueError("Tried to import the TFLite runtime, but it was not found." - "Please install it using `pip install tflite-runtime`") + from importlib.util import find_spec + if find_spec("tensorflow") is not None and find_spec("tflite_runtime") is None: + logging.warning("Tried to import the tflite runtime, but it was not found. Using tensorflow instead.") + from tensorflow.lite.python import interpreter as tflite + else: + raise ValueError("Tried to import the TFLite runtime, but it was not found." + "Neither was the TensorFlow interpreter." + "Please install TFLite runtime using `pip install tflite-runtime`") if melspec_model_path == "": melspec_model_path = os.path.join(pathlib.Path(__file__).parent.resolve(),