Skip to content

Commit

Permalink
coreml: yolov8 default on apple silicon
Browse files Browse the repository at this point in the history
  • Loading branch information
koush committed Jun 15, 2023
1 parent d9a575c commit 6f7fa54
Showing 1 changed file with 25 additions and 17 deletions.
42 changes: 25 additions & 17 deletions plugins/coreml/src/coreml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from __future__ import annotations

import asyncio
import concurrent.futures
import os
import platform
import re
import scrypted_sdk
from scrypted_sdk import SettingValue, Setting
from typing import Any, Tuple
from predict import PredictPlugin, Prediction, Rectangle

import coremltools as ct
import os
import numpy as np
import scrypted_sdk
from PIL import Image
import asyncio
import concurrent.futures
from scrypted_sdk import Setting, SettingValue

import yolo
import numpy as np
from predict import Prediction, PredictPlugin, Rectangle

predictExecutor = concurrent.futures.ThreadPoolExecutor(8, "CoreML-Predict")

Expand All @@ -33,12 +37,16 @@ def __init__(self, nativeId: str | None = None):

model = self.storage.getItem("model") or "Default"
if model == "Default":
model = "ssdlite_mobilenet_v2"
self.yolo = 'yolo' in model
self.yolov8 = 'yolov8' in model
# model = "ssdlite_mobilenet_v2"
if "arm" in platform.processor():
model = "yolov8"
else:
model = "ssdlite_mobilenet_v2"
self.yolo = "yolo" in model
self.yolov8 = "yolov8" in model
model_version = "v1"

print(f'model: {model}')
print(f"model: {model}")

if not self.yolo:
# todo convert these to mlpackage
Expand All @@ -52,10 +60,10 @@ def __init__(self, nativeId: str | None = None):
)
else:
if self.yolov8:
modelFile = self.downloadFile(
f"https://github.com/koush/coreml-models/raw/main/{model}/{model}.mlmodel",
f"{model}.mlmodel",
)
modelFile = self.downloadFile(
f"https://github.com/koush/coreml-models/raw/main/{model}/{model}.mlmodel",
f"{model}.mlmodel",
)
else:
files = [
f"{model}/{model}.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json",
Expand Down Expand Up @@ -125,7 +133,7 @@ async def detect_once(self, input: Image.Image, settings: Any, src_size, cvss):

# run in executor if this is the plugin loop
if self.yolo:
input_name = 'image' if self.yolov8 else 'input_1'
input_name = "image" if self.yolov8 else "input_1"
if asyncio.get_event_loop() is self.loop:
out_dict = await asyncio.get_event_loop().run_in_executor(
predictExecutor, lambda: self.model.predict({input_name: input})
Expand All @@ -135,7 +143,7 @@ async def detect_once(self, input: Image.Image, settings: Any, src_size, cvss):

if self.yolov8:
out_blob = out_dict["var_914"]
var_914 = out_dict['var_914']
var_914 = out_dict["var_914"]
results = var_914[0]
keep = np.argwhere(results[4:] > 0.2)
for indices in keep:
Expand Down

0 comments on commit 6f7fa54

Please sign in to comment.