diff --git a/plugins/coreml/src/coreml/__init__.py b/plugins/coreml/src/coreml/__init__.py index 90360bf2ac..4d71a54231 100644 --- a/plugins/coreml/src/coreml/__init__.py +++ b/plugins/coreml/src/coreml/__init__.py @@ -1,16 +1,20 @@ from __future__ import annotations + +import asyncio +import concurrent.futures +import os +import platform import re -import scrypted_sdk -from scrypted_sdk import SettingValue, Setting from typing import Any, Tuple -from predict import PredictPlugin, Prediction, Rectangle + import coremltools as ct -import os +import numpy as np +import scrypted_sdk from PIL import Image -import asyncio -import concurrent.futures +from scrypted_sdk import Setting, SettingValue + import yolo -import numpy as np +from predict import Prediction, PredictPlugin, Rectangle predictExecutor = concurrent.futures.ThreadPoolExecutor(8, "CoreML-Predict") @@ -33,12 +37,16 @@ def __init__(self, nativeId: str | None = None): model = self.storage.getItem("model") or "Default" if model == "Default": - model = "ssdlite_mobilenet_v2" - self.yolo = 'yolo' in model - self.yolov8 = 'yolov8' in model + # model = "ssdlite_mobilenet_v2" + if "arm" in platform.processor(): + model = "yolov8" + else: + model = "ssdlite_mobilenet_v2" + self.yolo = "yolo" in model + self.yolov8 = "yolov8" in model model_version = "v1" - print(f'model: {model}') + print(f"model: {model}") if not self.yolo: # todo convert these to mlpackage @@ -52,10 +60,10 @@ def __init__(self, nativeId: str | None = None): ) else: if self.yolov8: - modelFile = self.downloadFile( - f"https://github.com/koush/coreml-models/raw/main/{model}/{model}.mlmodel", - f"{model}.mlmodel", - ) + modelFile = self.downloadFile( + f"https://github.com/koush/coreml-models/raw/main/{model}/{model}.mlmodel", + f"{model}.mlmodel", + ) else: files = [ f"{model}/{model}.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json", @@ -125,7 +133,7 @@ async def detect_once(self, input: Image.Image, settings: Any, src_size, cvss): # run in executor if this is the plugin loop if self.yolo: - input_name = 'image' if self.yolov8 else 'input_1' + input_name = "image" if self.yolov8 else "input_1" if asyncio.get_event_loop() is self.loop: out_dict = await asyncio.get_event_loop().run_in_executor( predictExecutor, lambda: self.model.predict({input_name: input}) @@ -135,7 +143,7 @@ async def detect_once(self, input: Image.Image, settings: Any, src_size, cvss): if self.yolov8: out_blob = out_dict["var_914"] - var_914 = out_dict['var_914'] + var_914 = out_dict["var_914"] results = var_914[0] keep = np.argwhere(results[4:] > 0.2) for indices in keep: