Skip to content

Commit

Permalink
predict: refactor, add support for yolov8 on tflite
Browse files Browse the repository at this point in the history
  • Loading branch information
koush committed Jun 16, 2023
1 parent b10b4d0 commit 2b9a0f0
Show file tree
Hide file tree
Showing 10 changed files with 181 additions and 88 deletions.
4 changes: 2 additions & 2 deletions plugins/coreml/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion plugins/coreml/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,5 @@
"devDependencies": {
"@scrypted/sdk": "file:../../sdk"
},
"version": "0.1.18"
"version": "0.1.19"
}
2 changes: 1 addition & 1 deletion plugins/coreml/src/coreml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, nativeId: str | None = None):
if model == "Default":
# model = "ssdlite_mobilenet_v2"
if "arm" in platform.processor():
model = "yolov8"
model = "yolov8n"
else:
model = "ssdlite_mobilenet_v2"
self.yolo = "yolo" in model
Expand Down
4 changes: 2 additions & 2 deletions plugins/openvino/.vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
{
// docker installation
// "scrypted.debugHost": "koushik-ubuntu",
// "scrypted.serverRoot": "/server",
"scrypted.serverRoot": "/server",

// pi local installation
// "scrypted.debugHost": "192.168.2.119",
Expand All @@ -12,7 +12,7 @@
// "scrypted.debugHost": "127.0.0.1",
// "scrypted.serverRoot": "/Users/koush/.scrypted",
"scrypted.debugHost": "koushik-windows",
"scrypted.serverRoot": "C:\\Users\\koush\\.scrypted",
// "scrypted.serverRoot": "C:\\Users\\koush\\.scrypted",

"scrypted.pythonRemoteRoot": "${config:scrypted.serverRoot}/volume/plugin.zip",
"python.analysis.extraPaths": [
Expand Down
14 changes: 7 additions & 7 deletions plugins/openvino/src/yolo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@

from predict import Prediction, Rectangle

def parse_yolov8(results):
def parse_yolov8(results, scale = 1):
objs = []
keep = np.argwhere(results[4:] > 0.2)
keep = np.argwhere(results[4:] > .2)
for indices in keep:
class_id = indices[0]
index = indices[1]
confidence = results[class_id + 4, index]
x = results[0][index].astype(float)
y = results[1][index].astype(float)
w = results[2][index].astype(float)
h = results[3][index].astype(float)
x = results[0][index].astype(float) * scale
y = results[1][index].astype(float) * scale
w = results[2][index].astype(float) * scale
h = results[3][index].astype(float) * scale
obj = Prediction(
class_id,
int(class_id),
confidence.astype(float),
Rectangle(
x - w / 2,
Expand Down
8 changes: 4 additions & 4 deletions plugins/tensorflow-lite/.vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
{
// docker installation
// "scrypted.debugHost": "koushik-ubuntu",
// "scrypted.serverRoot": "/server",
"scrypted.serverRoot": "/server",

// pi local installation
// "scrypted.debugHost": "192.168.2.119",
// "scrypted.serverRoot": "/home/pi/.scrypted",

// local checkout
"scrypted.debugHost": "127.0.0.1",
"scrypted.serverRoot": "/Users/koush/.scrypted",
// "scrypted.debugHost": "koushik-windows",
// "scrypted.debugHost": "127.0.0.1",
// "scrypted.serverRoot": "/Users/koush/.scrypted",
"scrypted.debugHost": "koushik-windows",
// "scrypted.serverRoot": "C:\\Users\\koush\\.scrypted",

"scrypted.pythonRemoteRoot": "${config:scrypted.serverRoot}/volume/plugin.zip",
Expand Down
4 changes: 2 additions & 2 deletions plugins/tensorflow-lite/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion plugins/tensorflow-lite/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@
"devDependencies": {
"@scrypted/sdk": "file:../../sdk"
},
"version": "0.1.17"
"version": "0.1.19"
}
228 changes: 160 additions & 68 deletions plugins/tensorflow-lite/src/tflite/__init__.py
Original file line number Diff line number Diff line change
@@ -1,101 +1,166 @@
from __future__ import annotations
from .common import *

from PIL import Image
from pycoral.adapters import detect

from .common import *

loaded_py_coral = False
try:
from pycoral.utils.edgetpu import list_edge_tpus
from pycoral.utils.edgetpu import make_interpreter
from pycoral.utils.edgetpu import list_edge_tpus, make_interpreter

loaded_py_coral = True
print('coral edge tpu library loaded successfully')
print("coral edge tpu library loaded successfully")
except Exception as e:
print('coral edge tpu library load failed', e)
print("coral edge tpu library load failed", e)
pass
import tflite_runtime.interpreter as tflite
import asyncio
import concurrent.futures
import queue
import re
import scrypted_sdk
from scrypted_sdk.types import Setting
import traceback
from typing import Any, Tuple

import scrypted_sdk
import tflite_runtime.interpreter as tflite
from scrypted_sdk.types import Setting, SettingValue

import yolo
from predict import PredictPlugin
import concurrent.futures
import queue
import asyncio


def parse_label_contents(contents: str):
lines = contents.splitlines()
ret = {}
for row_number, content in enumerate(lines):
pair = re.split(r'[:\s]+', content.strip(), maxsplit=1)
pair = re.split(r"[:\s]+", content.strip(), maxsplit=1)
if len(pair) == 2 and pair[0].strip().isdigit():
ret[int(pair[0])] = pair[1].strip()
else:
ret[row_number] = content.strip()
return ret

class TensorFlowLitePlugin(PredictPlugin, scrypted_sdk.BufferConverter, scrypted_sdk.Settings):

class TensorFlowLitePlugin(
PredictPlugin, scrypted_sdk.BufferConverter, scrypted_sdk.Settings
):
def __init__(self, nativeId: str | None = None):
super().__init__(nativeId=nativeId)

tfliteFile = self.downloadFile('https://raw.githubusercontent.com/google-coral/test_data/master/ssd_mobilenet_v2_coco_quant_postprocess.tflite', 'ssd_mobilenet_v2_coco_quant_postprocess.tflite')
edgetpuFile = self.downloadFile('https://raw.githubusercontent.com/google-coral/test_data/master/ssd_mobilenet_v2_coco_quant_postprocess_edgetpu.tflite', 'ssd_mobilenet_v2_coco_quant_postprocess_edgetpu.tflite')
labelsFile = self.downloadFile('https://raw.githubusercontent.com/google-coral/test_data/master/coco_labels.txt', 'coco_labels.txt')
edge_tpus = None
try:
edge_tpus = list_edge_tpus()
print("edge tpus", edge_tpus)
if not len(edge_tpus):
raise Exception("no edge tpu found")
except Exception as e:
print("unable to use Coral Edge TPU", e)
edge_tpus = None
pass

model = self.storage.getItem("model") or "Default"
if model == "Default":
if edge_tpus:
model = "yolov8n_full_integer_quant"
else:
model = "ssd_mobilenet_v2_coco_quant_postprocess"
self.yolo = "yolo" in model
self.yolov8 = "yolov8" in model

print(f'model: {model}')

model_version = "v5"

if self.yolo:
labelsFile = self.downloadFile(
"https://raw.githubusercontent.com/koush/tflite-models/main/coco_80cl.txt",
f"{model_version}/coco_80cl.txt",
)
else:
labelsFile = self.downloadFile(
"https://raw.githubusercontent.com/koush/tflite-models/main/coco_labels.txt",
f"{model_version}/coco_labels.txt",
)

labels_contents = open(labelsFile, 'r').read()
labels_contents = open(labelsFile, "r").read()
self.labels = parse_label_contents(labels_contents)
self.interpreters = queue.Queue()
self.interpreter_count = 0

def downloadModel():
return self.downloadFile(
f"https://github.com/koush/tflite-models/raw/main/{model}/{model}{suffix}.tflite",
f"{model_version}/{model}{suffix}.tflite",
)

try:
edge_tpus = list_edge_tpus()
print('edge tpus', edge_tpus)
if not len(edge_tpus):
raise Exception('no edge tpu found')
self.edge_tpu_found = str(edge_tpus)
# todo co-compile
# https://coral.ai/docs/edgetpu/compiler/#co-compiling-multiple-models
# face_model = scrypted_sdk.zip.open(
# 'fs/mobilenet_ssd_v2_face_quant_postprocess.tflite').read()
for idx, edge_tpu in enumerate(edge_tpus):
try:
interpreter = make_interpreter(edgetpuFile, ":%s" % idx)
interpreter.allocate_tensors()
_, height, width, channels = interpreter.get_input_details()[
0]['shape']
self.input_details = int(width), int(height), int(channels)
self.interpreters.put(interpreter)
self.interpreter_count = self.interpreter_count + 1
print('added tpu %s' % (edge_tpu))
except Exception as e:
print('unable to use Coral Edge TPU', e)

if not self.interpreter_count:
raise Exception('all tpus failed to load')
# self.face_interpreter = make_interpreter(face_model)
if edge_tpus:
suffix = "_edgetpu"
modelFile = downloadModel()
self.edge_tpu_found = str(edge_tpus)
for idx, edge_tpu in enumerate(edge_tpus):
try:
interpreter = make_interpreter(modelFile, ":%s" % idx)
interpreter.allocate_tensors()
_, height, width, channels = interpreter.get_input_details()[0][
"shape"
]
self.input_details = int(width), int(height), int(channels)
self.interpreters.put(interpreter)
self.interpreter_count = self.interpreter_count + 1
print("added tpu %s" % (edge_tpu))
except Exception as e:
print("unable to use Coral Edge TPU", e)

if not self.interpreter_count:
raise Exception("all tpus failed to load")
else:
raise Exception()
except Exception as e:
print('unable to use Coral Edge TPU', e)
self.edge_tpu_found = 'Edge TPU not found'
# face_model = scrypted_sdk.zip.open(
# 'fs/mobilenet_ssd_v2_face_quant_postprocess.tflite').read()
interpreter = tflite.Interpreter(model_path=tfliteFile)
self.edge_tpu_found = "Edge TPU not found"
suffix = ""
modelFile = downloadModel()
interpreter = tflite.Interpreter(model_path=modelFile)
interpreter.allocate_tensors()
_, height, width, channels = interpreter.get_input_details()[
0]['shape']
_, height, width, channels = interpreter.get_input_details()[0]["shape"]
self.input_details = int(width), int(height), int(channels)
self.interpreters.put(interpreter)
self.interpreter_count = self.interpreter_count + 1
# self.face_interpreter = make_interpreter(face_model)

self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.interpreter_count, thread_name_prefix="tflite", )
print(modelFile, labelsFile)

self.executor = concurrent.futures.ThreadPoolExecutor(
max_workers=self.interpreter_count,
thread_name_prefix="tflite",
)

async def putSetting(self, key: str, value: SettingValue):
self.storage.setItem(key, value)
await self.onDeviceEvent(scrypted_sdk.ScryptedInterface.Settings.value, None)
await scrypted_sdk.deviceManager.requestRestart()

async def getSettings(self) -> list[Setting]:
coral: Setting = {
'title': 'Detected Edge TPU',
'description': 'The device paths of the Coral Edge TPUs that will be used for detections.',
'value': self.edge_tpu_found,
'readonly': True,
'key': 'coral',
}
return [coral]
model = self.storage.getItem("model") or "Default"
return [
{
"title": "Detected Edge TPU",
"description": "The device paths of the Coral Edge TPUs that will be used for detections.",
"value": self.edge_tpu_found,
"readonly": True,
"key": "coral",
},
{
"key": "model",
"title": "Model",
"description": "The detection model used to find objects.",
"choices": [
"Default",
"ssd_mobilenet_v2_coco_quant_postprocess",
"yolov8n_full_integer_quant",
],
"value": model,
},
]

# width, height, channels
def get_input_details(self) -> Tuple[int, int, int]:
Expand All @@ -108,17 +173,44 @@ async def detect_once(self, input: Image.Image, settings: Any, src_size, cvss):
def predict():
interpreter = self.interpreters.get()
try:
common.set_input(
interpreter, input)
scale = (1, 1)
# _, scale = common.set_resized_input(
# self.interpreter, cropped.size, lambda size: cropped.resize(size, Image.ANTIALIAS))
interpreter.invoke()
objs = detect.get_objects(
interpreter, score_threshold=.2, image_scale=scale)
if self.yolo:
tensor_index = input_details(interpreter, 'index')

im = np.stack([input])
i = interpreter.get_input_details()[0]
if i['dtype'] == np.int8:
im = im.view(np.int8)
else:
im = im.astype(np.float32) / 255.0
interpreter.set_tensor(tensor_index, im)
interpreter.invoke()
output_details = interpreter.get_output_details()
y = []
for output in output_details:
x = interpreter.get_tensor(output['index'])
if output['dtype'] == np.int8:
scale, zero_point = output['quantization']
x = (x.astype(np.float32) - zero_point) * scale # re-scale
y.append(x)

if len(y) == 2: # segment with (det, proto) output order reversed
if len(y[1].shape) != 4:
y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32)
y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160)
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
objs = yolo.parse_yolov8(y[0][0], scale=640)
else:
common.set_input(interpreter, input)
interpreter.invoke()
objs = detect.get_objects(
interpreter, score_threshold=0.2, image_scale=(1, 1)
)
return objs
except:
print('tensorflow-lite encountered an error while detecting. requesting plugin restart.')
traceback.print_exc()
print(
"tensorflow-lite encountered an error while detecting. requesting plugin restart."
)
self.requestRestart()
raise e
finally:
Expand Down
Loading

0 comments on commit 2b9a0f0

Please sign in to comment.