Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 57 additions & 45 deletions homeassistant/components/image_processing/facebox.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,38 @@
import requests
import voluptuous as vol

from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_NAME)
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.core import split_entity_id
import homeassistant.helpers.config_validation as cv
from homeassistant.components.image_processing import (
PLATFORM_SCHEMA, ImageProcessingFaceEntity, ATTR_CONFIDENCE, CONF_SOURCE,
CONF_ENTITY_ID, CONF_NAME, DOMAIN)
from homeassistant.const import (CONF_IP_ADDRESS, CONF_PORT)
from homeassistant.const import (
CONF_IP_ADDRESS, CONF_PORT, CONF_PASSWORD, CONF_USERNAME,
HTTP_BAD_REQUEST, HTTP_UNAUTHORIZED)

_LOGGER = logging.getLogger(__name__)

ATTR_BOUNDING_BOX = 'bounding_box'

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this as a platform variable

ATTR_CLASSIFIER = 'classifier'

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really I see this in const.py

ATTR_IMAGE_ID = 'image_id'
ATTR_ID = 'id'
ATTR_MATCHED = 'matched'
ATTR_NAME = 'name'

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're using ATTR_NAME in the platform services, ie for home assistant APIs, so that we can import from const.py. For the 'name' string that goes in calls to the facebox api, we should use another constant or no constant at all.

Please import ATTR_NAME from const.py and rename the constant here to something else, eg FACEBOX_NAME.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variables are getting really confusing, happy to just use name if preferred

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point is that we shouldn't use the same constant when interfacing with two separate APIs. One API might change and then we might change the constant to go along with the changed API, and then we would break the usage of the other API.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally agree

CLASSIFIER = 'facebox'
DATA_FACEBOX = 'facebox_classifiers'
EVENT_CLASSIFIER_TEACH = 'image_processing.teach_classifier'
FILE_PATH = 'file_path'
NOTIFICATION_ID = 'facebox_notification'
NOTIFICATION_TITLE = 'facebox teach'
SERVICE_TEACH_FACE = 'facebox_teach_face'
TIMEOUT = 9


PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
vol.Required(CONF_IP_ADDRESS): cv.string,
vol.Required(CONF_PORT): cv.port,
vol.Optional(CONF_USERNAME): cv.string,
vol.Optional(CONF_PASSWORD): cv.string,
})

SERVICE_TEACH_SCHEMA = vol.Schema({
Expand Down Expand Up @@ -75,15 +81,44 @@ def parse_faces(api_faces):
return known_faces


def post_image(url, image):
def post_image(url, username, password, image):
"""Post an image to the classifier."""
try:
response = requests.post(
url,
auth=requests.auth.HTTPBasicAuth(username, password),
json={"base64": encode_image(image)},
timeout=TIMEOUT
)
return response
if response.status_code == HTTP_UNAUTHORIZED:
_LOGGER.error("AuthenticationError on %s", CLASSIFIER)
return None
else:
return response

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we return something here, the other places that this function can exit at are lacking a return statement. If a function returns something, all exits need to return something. Consistency is good.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

except requests.exceptions.ConnectionError:
_LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok


def teach_file(url, username, password, name, file_path):
"""Teach the classifier a name associated with a file."""
try:
with open(file_path, 'rb') as open_file:
response = requests.post(
url,
auth=requests.auth.HTTPBasicAuth(username, password),

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if username or password is None?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are by default and its OK

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't send in auth if username and password are None, that triggers a deprecation warning.

kwargs = {}

if username:
    kwargs['auth'] = requests.auth.HTTPBasicAuth(username, password)

response = request.post(url, **kwargs)

data={ATTR_NAME: name, ATTR_ID: file_path},
files={'file': open_file},
timeout=TIMEOUT

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Be careful with the default timeout as this is sending files

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will remove

)
if response.status_code == HTTP_UNAUTHORIZED:
_LOGGER.error("AuthenticationError on %s", CLASSIFIER)
return None
elif response.status_code == HTTP_BAD_REQUEST:
_LOGGER.error("%s teaching of file %s failed with message:%s",
CLASSIFIER, file_path, response.text)
return None
else:
return response

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See about consistent return.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

except requests.exceptions.ConnectionError:
_LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER)

Expand All @@ -104,12 +139,15 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
if DATA_FACEBOX not in hass.data:
hass.data[DATA_FACEBOX] = []

ip = config[CONF_IP_ADDRESS]
port = config[CONF_PORT]
username = config.get(CONF_USERNAME)
password = config.get(CONF_PASSWORD)

entities = []
for camera in config[CONF_SOURCE]:
facebox = FaceClassifyEntity(
config[CONF_IP_ADDRESS],
config[CONF_PORT],
camera[CONF_ENTITY_ID],
ip, port, username, password, camera[CONF_ENTITY_ID],
camera.get(CONF_NAME))
entities.append(facebox)
hass.data[DATA_FACEBOX].append(facebox)
Expand All @@ -129,33 +167,33 @@ def service_handle(service):
classifier.teach(name, file_path)

hass.services.register(
DOMAIN,
SERVICE_TEACH_FACE,
service_handle,
DOMAIN, SERVICE_TEACH_FACE, service_handle,
schema=SERVICE_TEACH_SCHEMA)


class FaceClassifyEntity(ImageProcessingFaceEntity):
"""Perform a face classification."""

def __init__(self, ip, port, camera_entity, name=None):
def __init__(self, ip, port, username, password, camera_entity, name=None):
"""Init with the API key and model id."""
super().__init__()
self._url_check = "http://{}:{}/{}/check".format(ip, port, CLASSIFIER)
self._url_teach = "http://{}:{}/{}/teach".format(ip, port, CLASSIFIER)
self._username = username
self._password = password
self._camera = camera_entity
if name:
self._name = name
else:
camera_name = split_entity_id(camera_entity)[1]
self._name = "{} {}".format(
CLASSIFIER, camera_name)
self._name = "{} {}".format(CLASSIFIER, camera_name)
self._matched = {}

def process_image(self, image):
"""Process an image."""
response = post_image(self._url_check, image)
if response is not None:
response = post_image(
self._url_check, self._username, self._password, image)
if response:
response_json = response.json()
if response_json['success']:
total_faces = response_json['facesCount']
Expand All @@ -173,34 +211,8 @@ def teach(self, name, file_path):
if (not self.hass.config.is_allowed_path(file_path)
or not valid_file_path(file_path)):
return
with open(file_path, 'rb') as open_file:
response = requests.post(
self._url_teach,
data={ATTR_NAME: name, 'id': file_path},
files={'file': open_file})

if response.status_code == 200:
self.hass.bus.fire(
EVENT_CLASSIFIER_TEACH, {
ATTR_CLASSIFIER: CLASSIFIER,
ATTR_NAME: name,
FILE_PATH: file_path,
'success': True,
'message': None
})

elif response.status_code == 400:
_LOGGER.warning(
"%s teaching of file %s failed with message:%s",
CLASSIFIER, file_path, response.text)
self.hass.bus.fire(
EVENT_CLASSIFIER_TEACH, {
ATTR_CLASSIFIER: CLASSIFIER,
ATTR_NAME: name,
FILE_PATH: file_path,
'success': False,
'message': response.text
})
teach_file(
self._url_teach, self._username, self._password, name, file_path)

@property
def camera_entity(self):
Expand Down
41 changes: 15 additions & 26 deletions tests/components/image_processing/test_facebox.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from homeassistant.core import callback
from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_NAME, CONF_FRIENDLY_NAME,
CONF_IP_ADDRESS, CONF_PORT, STATE_UNKNOWN)
ATTR_ENTITY_ID, ATTR_NAME, CONF_FRIENDLY_NAME, CONF_PASSWORD,
CONF_USERNAME, CONF_IP_ADDRESS, CONF_PORT, STATE_UNKNOWN)
from homeassistant.setup import async_setup_component
import homeassistant.components.image_processing as ip
import homeassistant.components.image_processing.facebox as fb
Expand All @@ -33,6 +33,8 @@
"faces": [MOCK_FACE]}

MOCK_NAME = 'mock_name'
MOCK_USERNAME = 'mock_username'
MOCK_PASSWORD = 'mock_password'

# Faces data after parsing.
PARSED_FACES = [{ATTR_NAME: 'John Lennon',
Expand Down Expand Up @@ -114,6 +116,17 @@ async def test_setup_platform(hass):
assert hass.states.get(VALID_ENTITY_ID)


async def test_setup_platform_with_auth(hass):
"""Setup platform with one entity and auth."""

valid_config_auth = VALID_CONFIG.copy()
valid_config_auth[ip.DOMAIN][CONF_USERNAME] = MOCK_USERNAME
valid_config_auth[ip.DOMAIN][CONF_PASSWORD] = MOCK_PASSWORD

await async_setup_component(hass, ip.DOMAIN, valid_config_auth)
assert hass.states.get(VALID_ENTITY_ID)


async def test_process_image(hass, mock_image):
"""Test processing of an image."""
await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG)
Expand Down Expand Up @@ -183,16 +196,6 @@ async def test_teach_service(hass, mock_image, mock_isfile, mock_open_file):
await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG)
assert hass.states.get(VALID_ENTITY_ID)

teach_events = []

@callback
def mock_teach_event(event):
"""Mock event."""
teach_events.append(event)

hass.bus.async_listen(
'image_processing.teach_classifier', mock_teach_event)

# Patch out 'is_allowed_path' as the mock files aren't allowed
hass.config.is_allowed_path = Mock(return_value=True)

Expand All @@ -206,13 +209,6 @@ def mock_teach_event(event):
ip.DOMAIN, fb.SERVICE_TEACH_FACE, service_data=data)
await hass.async_block_till_done()

assert len(teach_events) == 1
assert teach_events[0].data[fb.ATTR_CLASSIFIER] == fb.CLASSIFIER
assert teach_events[0].data[ATTR_NAME] == MOCK_NAME
assert teach_events[0].data[fb.FILE_PATH] == MOCK_FILE_PATH
assert teach_events[0].data['success']
assert not teach_events[0].data['message']

# Now test the failed teaching.
with requests_mock.Mocker() as mock_req:
url = "http://{}:{}/facebox/teach".format(MOCK_IP, MOCK_PORT)
Expand All @@ -225,13 +221,6 @@ def mock_teach_event(event):
service_data=data)
await hass.async_block_till_done()

assert len(teach_events) == 2
assert teach_events[1].data[fb.ATTR_CLASSIFIER] == fb.CLASSIFIER
assert teach_events[1].data[ATTR_NAME] == MOCK_NAME
assert teach_events[1].data[fb.FILE_PATH] == MOCK_FILE_PATH
assert not teach_events[1].data['success']
assert teach_events[1].data['message'] == MOCK_ERROR


async def test_setup_platform_with_name(hass):
"""Setup platform with one entity and a name."""
Expand Down