Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 109 additions & 22 deletions homeassistant/components/image_processing/facebox.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,26 @@
import requests
import voluptuous as vol

from homeassistant.const import ATTR_NAME
from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_NAME)
from homeassistant.core import split_entity_id
import homeassistant.helpers.config_validation as cv
from homeassistant.components.image_processing import (
PLATFORM_SCHEMA, ImageProcessingFaceEntity, ATTR_CONFIDENCE, CONF_SOURCE,
CONF_ENTITY_ID, CONF_NAME)
CONF_ENTITY_ID, CONF_NAME, DOMAIN)
from homeassistant.const import (CONF_IP_ADDRESS, CONF_PORT)

_LOGGER = logging.getLogger(__name__)

ATTR_BOUNDING_BOX = 'bounding_box'
ATTR_CLASSIFIER = 'classifier'
ATTR_IMAGE_ID = 'image_id'
ATTR_MATCHED = 'matched'
CLASSIFIER = 'facebox'
DATA_FACEBOX = 'facebox_classifiers'
EVENT_CLASSIFIER_TEACH = 'image_processing.teach_classifier'
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this our standard format for event names? Have you checked other platforms or components that fire events?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see a consistent format, so I took my lead from EVENT_DETECT_FACE = 'image_processing.detect_face'. I think teach_classifier is a pretty accurate description, but I am open to other suggestions.

FILE_PATH = 'file_path'
SERVICE_TEACH_FACE = 'facebox_teach_face'
TIMEOUT = 9


Expand All @@ -32,6 +38,12 @@
vol.Required(CONF_PORT): cv.port,
})

SERVICE_TEACH_SCHEMA = vol.Schema({
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
vol.Required(ATTR_NAME): cv.string,
vol.Required(FILE_PATH): cv.string,
})


def encode_image(image):
"""base64 encode an image stream."""
Expand Down Expand Up @@ -63,26 +75,74 @@ def parse_faces(api_faces):
return known_faces


def post_image(url, image):
    """Post an image to the classifier.

    Returns the requests.Response on success, or None when the
    classifier cannot be reached or does not answer within TIMEOUT.
    """
    try:
        response = requests.post(
            url,
            json={"base64": encode_image(image)},
            timeout=TIMEOUT
        )
        return response
    except requests.exceptions.Timeout:
        # timeout=TIMEOUT raises Timeout, which the ConnectionError
        # handler below does not cover and would crash the caller.
        _LOGGER.error(
            "Timeout: no response from %s within %s seconds",
            CLASSIFIER, TIMEOUT)
    except requests.exceptions.ConnectionError:
        _LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER)


def valid_file_path(file_path):
    """Return True if file_path points to a readable file, else False.

    Logs an error (tagged with the classifier name) for invalid paths.
    """
    try:
        cv.isfile(file_path)
    except vol.Invalid:
        _LOGGER.error(
            "%s error: Invalid file path: %s", CLASSIFIER, file_path)
        return False
    return True


def setup_platform(hass, config, add_devices, discovery_info=None):
    """Set up the facebox classifier entities and the teach service."""
    # Keep a module-wide registry of classifiers so the teach service
    # can target them across multiple platform instances.
    if DATA_FACEBOX not in hass.data:
        hass.data[DATA_FACEBOX] = []

    entities = []
    for camera in config[CONF_SOURCE]:
        facebox = FaceClassifyEntity(
            config[CONF_IP_ADDRESS],
            config[CONF_PORT],
            camera[CONF_ENTITY_ID],
            camera.get(CONF_NAME))
        entities.append(facebox)
        hass.data[DATA_FACEBOX].append(facebox)
    add_devices(entities)

    def service_handle(service):
        """Handle the facebox_teach_face service call."""
        # Use the imported ATTR_ENTITY_ID constant, consistent with
        # SERVICE_TEACH_SCHEMA, instead of the magic string 'entity_id'.
        entity_ids = service.data.get(ATTR_ENTITY_ID)

        classifiers = hass.data[DATA_FACEBOX]
        if entity_ids:
            classifiers = [
                c for c in classifiers if c.entity_id in entity_ids]

        # name and file_path are vol.Required in SERVICE_TEACH_SCHEMA,
        # so they are always present; look them up once, not per entity.
        name = service.data.get(ATTR_NAME)
        file_path = service.data.get(FILE_PATH)
        for classifier in classifiers:
            classifier.teach(name, file_path)

    hass.services.register(
        DOMAIN,
        SERVICE_TEACH_FACE,
        service_handle,
        schema=SERVICE_TEACH_SCHEMA)


class FaceClassifyEntity(ImageProcessingFaceEntity):
"""Perform a face classification."""

def __init__(self, ip, port, camera_entity, name=None):
"""Init with the API key and model id."""
super().__init__()
self._url = "http://{}:{}/{}/check".format(ip, port, CLASSIFIER)
self._url_check = "http://{}:{}/{}/check".format(ip, port, CLASSIFIER)
self._url_teach = "http://{}:{}/{}/teach".format(ip, port, CLASSIFIER)
self._camera = camera_entity
if name:
self._name = name
Expand All @@ -94,28 +154,54 @@ def __init__(self, ip, port, camera_entity, name=None):

def process_image(self, image):
"""Process an image."""
response = {}
try:
response = requests.post(
self._url,
json={"base64": encode_image(image)},
timeout=TIMEOUT
).json()
except requests.exceptions.ConnectionError:
_LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER)
response['success'] = False

if response['success']:
total_faces = response['facesCount']
faces = parse_faces(response['faces'])
self._matched = get_matched_faces(faces)
self.process_faces(faces, total_faces)
response = post_image(self._url_check, image)
if response is not None:
response_json = response.json()
if response_json['success']:
total_faces = response_json['facesCount']
faces = parse_faces(response_json['faces'])
self._matched = get_matched_faces(faces)
self.process_faces(faces, total_faces)

else:
self.total_faces = None
self.faces = []
self._matched = {}

def teach(self, name, file_path):
"""Teach classifier a face name."""
if (not self.hass.config.is_allowed_path(file_path)
or not valid_file_path(file_path)):
return
with open(file_path, 'rb') as open_file:
response = requests.post(
self._url_teach,
data={ATTR_NAME: name, 'id': file_path},
files={'file': open_file})

if response.status_code == 200:
self.hass.bus.fire(
EVENT_CLASSIFIER_TEACH, {
ATTR_CLASSIFIER: CLASSIFIER,
ATTR_NAME: name,
FILE_PATH: file_path,
'success': True,
'message': None
})

elif response.status_code == 400:
_LOGGER.warning(
"%s teaching of file %s failed with message:%s",
CLASSIFIER, file_path, response.text)
self.hass.bus.fire(
EVENT_CLASSIFIER_TEACH, {
ATTR_CLASSIFIER: CLASSIFIER,
ATTR_NAME: name,
FILE_PATH: file_path,
'success': False,
'message': response.text
})

@property
def camera_entity(self):
"""Return camera entity id from process pictures."""
Expand All @@ -131,4 +217,5 @@ def device_state_attributes(self):
"""Return the classifier attributes."""
return {
'matched_faces': self._matched,
'total_matched_faces': len(self._matched),
}
13 changes: 13 additions & 0 deletions homeassistant/components/image_processing/services.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,16 @@ scan:
entity_id:
description: Name(s) of entities to scan immediately.
example: 'image_processing.alpr_garage'

# Service description for image_processing.facebox_teach_face:
# teaches the facebox classifier a named face from a local image file.
facebox_teach_face:
  description: Teach facebox a face using a file.
  fields:
    entity_id:
      description: The facebox entity to teach.
      example: 'image_processing.facebox'
    name:
      description: The name of the face to teach.
      example: 'my_name'
    file_path:
      description: The path to the image file.
      example: '/images/my_image.jpg'
Loading