Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions homeassistant/components/sighthound/image_processing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Person detection using Sighthound cloud service."""
import io
import logging
import os

from PIL import Image, ImageDraw
import simplehound.core as hound
import voluptuous as vol

Expand All @@ -14,6 +17,7 @@
from homeassistant.const import ATTR_ENTITY_ID, CONF_API_KEY
from homeassistant.core import split_entity_id
import homeassistant.helpers.config_validation as cv
from homeassistant.util.pil import draw_box

_LOGGER = logging.getLogger(__name__)

Expand All @@ -22,13 +26,15 @@
ATTR_BOUNDING_BOX = "bounding_box"
ATTR_PEOPLE = "people"
CONF_ACCOUNT_TYPE = "account_type"
CONF_SAVE_FILE_FOLDER = "save_file_folder"
DEV = "dev"
PROD = "prod"

PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
{
vol.Required(CONF_API_KEY): cv.string,
vol.Optional(CONF_ACCOUNT_TYPE, default=DEV): vol.In([DEV, PROD]),
vol.Optional(CONF_SAVE_FILE_FOLDER): cv.isdir,
}
)

Expand All @@ -45,10 +51,14 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
_LOGGER.error("Sighthound error %s setup aborted", exc)
return

save_file_folder = config.get(CONF_SAVE_FILE_FOLDER)
if save_file_folder:
save_file_folder = os.path.join(save_file_folder, "") # If no trailing / add it

entities = []
for camera in config[CONF_SOURCE]:
sighthound = SighthoundEntity(
api, camera[CONF_ENTITY_ID], camera.get(CONF_NAME)
api, camera[CONF_ENTITY_ID], camera.get(CONF_NAME), save_file_folder
)
entities.append(sighthound)
add_entities(entities)
Expand All @@ -57,7 +67,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
class SighthoundEntity(ImageProcessingEntity):
"""Create a sighthound entity."""

def __init__(self, api, camera_entity, name):
def __init__(self, api, camera_entity, name, save_file_folder):
"""Init."""
self._api = api
self._camera = camera_entity
Expand All @@ -69,6 +79,8 @@ def __init__(self, api, camera_entity, name):
self._state = None
self._image_width = None
self._image_height = None
if save_file_folder:
self._save_file_folder = save_file_folder

def process_image(self, image):
"""Process an image."""
Expand All @@ -81,6 +93,8 @@ def process_image(self, image):
self._image_height = metadata["image_height"]
for person in people:
self.fire_person_detected_event(person)
if hasattr(self, "_save_file_folder") and self._state > 0:
self.save_image(image, people, self._save_file_folder)

def fire_person_detected_event(self, person):
"""Send event with detected total_persons."""
Expand All @@ -94,6 +108,21 @@ def fire_person_detected_event(self, person):
},
)

def save_image(self, image, people, directory):
"""Save a timestamped image with bounding boxes around targets."""

img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
draw = ImageDraw.Draw(img)

for person in people:
box = hound.bbox_to_tf_style(
person["boundingBox"], self._image_width, self._image_height
)
draw_box(draw, box, self._image_width, self._image_height)

latest_save_path = directory + "{}_latest.jpg".format(self._name)
img.save(latest_save_path)

@property
def camera_entity(self):
"""Return camera entity id from process pictures."""
Expand Down
18 changes: 18 additions & 0 deletions tests/components/sighthound/test_image_processing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests for the Sighthound integration."""
import os
from unittest.mock import patch

import pytest
Expand All @@ -10,6 +11,8 @@
from homeassistant.core import callback
from homeassistant.setup import async_setup_component

TEST_DIR = os.path.join(os.path.dirname(__file__))

VALID_CONFIG = {
ip.DOMAIN: {
"platform": "sighthound",
Expand Down Expand Up @@ -91,3 +94,18 @@ def capture_person_event(event):
state = hass.states.get(VALID_ENTITY_ID)
assert state.state == "2"
assert len(person_events) == 2


async def test_save_image(hass, mock_image, mock_detections):
"""Save a processed image."""
VALID_CONFIG.update({sh.CONF_SAVE_FILE_FOLDER: TEST_DIR})
await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG)
assert hass.states.get(VALID_ENTITY_ID)

data = {ATTR_ENTITY_ID: VALID_ENTITY_ID}
with patch("PIL.Image.Image.save") as mock_save:
await hass.services.async_call(ip.DOMAIN, ip.SERVICE_SCAN, service_data=data)
await hass.async_block_till_done()
state = hass.states.get(VALID_ENTITY_ID)
assert state.state == "2"
mock_save.assert_called_with("test.jpg")