-
-
Notifications
You must be signed in to change notification settings - Fork 37.2k
Add facebox auth #15439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add facebox auth #15439
Changes from all commits
e50ce87
61105e4
e6cf59e
13f70cd
53a8846
ffe1c5c
259bb13
a5d981b
c2fa465
788a8e5
ccb7247
a00389b
6aaa88a
e07af2e
08d66c9
42529d7
c0298b6
72f08fd
df5e05d
509a4b5
2cfc529
7006994
6154103
f673150
136e7c8
303ff24
871c272
ad5008c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,25 +17,29 @@ | |
| from homeassistant.components.image_processing import ( | ||
| PLATFORM_SCHEMA, ImageProcessingFaceEntity, ATTR_CONFIDENCE, CONF_SOURCE, | ||
| CONF_ENTITY_ID, CONF_NAME, DOMAIN) | ||
| from homeassistant.const import (CONF_IP_ADDRESS, CONF_PORT) | ||
| from homeassistant.const import ( | ||
| CONF_IP_ADDRESS, CONF_PORT, CONF_PASSWORD, CONF_USERNAME, | ||
| HTTP_BAD_REQUEST, HTTP_OK, HTTP_UNAUTHORIZED) | ||
|
|
||
| _LOGGER = logging.getLogger(__name__) | ||
|
|
||
| ATTR_BOUNDING_BOX = 'bounding_box' | ||
| ATTR_CLASSIFIER = 'classifier' | ||
| ATTR_IMAGE_ID = 'image_id' | ||
| ATTR_ID = 'id' | ||
| ATTR_MATCHED = 'matched' | ||
| FACEBOX_NAME = 'name' | ||
| CLASSIFIER = 'facebox' | ||
| DATA_FACEBOX = 'facebox_classifiers' | ||
| EVENT_CLASSIFIER_TEACH = 'image_processing.teach_classifier' | ||
| FILE_PATH = 'file_path' | ||
| SERVICE_TEACH_FACE = 'facebox_teach_face' | ||
| TIMEOUT = 9 | ||
|
|
||
|
|
||
| PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ | ||
| vol.Required(CONF_IP_ADDRESS): cv.string, | ||
| vol.Required(CONF_PORT): cv.port, | ||
| vol.Optional(CONF_USERNAME): cv.string, | ||
| vol.Optional(CONF_PASSWORD): cv.string, | ||
| }) | ||
|
|
||
| SERVICE_TEACH_SCHEMA = vol.Schema({ | ||
|
|
@@ -45,6 +49,26 @@ | |
| }) | ||
|
|
||
|
|
||
| def check_box_health(url, username, password): | ||
| """Check the health of the classifier and return its id if healthy.""" | ||
| kwargs = {} | ||
| if username: | ||
| kwargs['auth'] = requests.auth.HTTPBasicAuth(username, password) | ||
| try: | ||
| response = requests.get( | ||
| url, | ||
| **kwargs | ||
| ) | ||
| if response.status_code == HTTP_UNAUTHORIZED: | ||
| _LOGGER.error("AuthenticationError on %s", CLASSIFIER) | ||
| return None | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should fail setup, ie not add the entity and return from
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would you advise that the auth should be checked before instantiaiton of
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, early is good. |
||
| if response.status_code == HTTP_OK: | ||
| return response.json()['hostname'] | ||
| except requests.exceptions.ConnectionError: | ||
| _LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER) | ||
| return None | ||
|
|
||
|
|
||
| def encode_image(image): | ||
| """base64 encode an image stream.""" | ||
| base64_img = base64.b64encode(image).decode('ascii') | ||
|
|
@@ -63,10 +87,10 @@ def parse_faces(api_faces): | |
| for entry in api_faces: | ||
| face = {} | ||
| if entry['matched']: # This data is only in matched faces. | ||
| face[ATTR_NAME] = entry['name'] | ||
| face[FACEBOX_NAME] = entry['name'] | ||
| face[ATTR_IMAGE_ID] = entry['id'] | ||
| else: # Lets be explicit. | ||
| face[ATTR_NAME] = None | ||
| face[FACEBOX_NAME] = None | ||
| face[ATTR_IMAGE_ID] = None | ||
| face[ATTR_CONFIDENCE] = round(100.0*entry['confidence'], 2) | ||
| face[ATTR_MATCHED] = entry['matched'] | ||
|
|
@@ -75,17 +99,46 @@ def parse_faces(api_faces): | |
| return known_faces | ||
|
|
||
|
|
||
| def post_image(url, image): | ||
| def post_image(url, image, username, password): | ||
| """Post an image to the classifier.""" | ||
| kwargs = {} | ||
| if username: | ||
| kwargs['auth'] = requests.auth.HTTPBasicAuth(username, password) | ||
| try: | ||
| response = requests.post( | ||
| url, | ||
| json={"base64": encode_image(image)}, | ||
| timeout=TIMEOUT | ||
| **kwargs | ||
| ) | ||
| if response.status_code == HTTP_UNAUTHORIZED: | ||
| _LOGGER.error("AuthenticationError on %s", CLASSIFIER) | ||
| return None | ||
| return response | ||
| except requests.exceptions.ConnectionError: | ||
| _LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER) | ||
| return None | ||
|
|
||
|
|
||
| def teach_file(url, name, file_path, username, password): | ||
| """Teach the classifier a name associated with a file.""" | ||
| kwargs = {} | ||
| if username: | ||
| kwargs['auth'] = requests.auth.HTTPBasicAuth(username, password) | ||
| try: | ||
| with open(file_path, 'rb') as open_file: | ||
| response = requests.post( | ||
| url, | ||
| data={FACEBOX_NAME: name, ATTR_ID: file_path}, | ||
| files={'file': open_file}, | ||
| **kwargs | ||
| ) | ||
| if response.status_code == HTTP_UNAUTHORIZED: | ||
| _LOGGER.error("AuthenticationError on %s", CLASSIFIER) | ||
| elif response.status_code == HTTP_BAD_REQUEST: | ||
| _LOGGER.error("%s teaching of file %s failed with message:%s", | ||
| CLASSIFIER, file_path, response.text) | ||
| except requests.exceptions.ConnectionError: | ||
| _LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER) | ||
|
|
||
|
|
||
| def valid_file_path(file_path): | ||
|
|
@@ -104,13 +157,20 @@ def setup_platform(hass, config, add_devices, discovery_info=None): | |
| if DATA_FACEBOX not in hass.data: | ||
| hass.data[DATA_FACEBOX] = [] | ||
|
|
||
| ip_address = config[CONF_IP_ADDRESS] | ||
| port = config[CONF_PORT] | ||
| username = config.get(CONF_USERNAME) | ||
| password = config.get(CONF_PASSWORD) | ||
| url_health = "http://{}:{}/healthz".format(ip_address, port) | ||
| hostname = check_box_health(url_health, username, password) | ||
| if hostname is None: | ||
| return | ||
|
|
||
| entities = [] | ||
| for camera in config[CONF_SOURCE]: | ||
| facebox = FaceClassifyEntity( | ||
| config[CONF_IP_ADDRESS], | ||
| config[CONF_PORT], | ||
| camera[CONF_ENTITY_ID], | ||
| camera.get(CONF_NAME)) | ||
| ip_address, port, username, password, hostname, | ||
| camera[CONF_ENTITY_ID], camera.get(CONF_NAME)) | ||
| entities.append(facebox) | ||
| hass.data[DATA_FACEBOX].append(facebox) | ||
| add_devices(entities) | ||
|
|
@@ -129,33 +189,37 @@ def service_handle(service): | |
| classifier.teach(name, file_path) | ||
|
|
||
| hass.services.register( | ||
| DOMAIN, | ||
| SERVICE_TEACH_FACE, | ||
| service_handle, | ||
| DOMAIN, SERVICE_TEACH_FACE, service_handle, | ||
| schema=SERVICE_TEACH_SCHEMA) | ||
|
|
||
|
|
||
| class FaceClassifyEntity(ImageProcessingFaceEntity): | ||
| """Perform a face classification.""" | ||
|
|
||
| def __init__(self, ip, port, camera_entity, name=None): | ||
| def __init__(self, ip_address, port, username, password, hostname, | ||
| camera_entity, name=None): | ||
| """Init with the API key and model id.""" | ||
| super().__init__() | ||
| self._url_check = "http://{}:{}/{}/check".format(ip, port, CLASSIFIER) | ||
| self._url_teach = "http://{}:{}/{}/teach".format(ip, port, CLASSIFIER) | ||
| self._url_check = "http://{}:{}/{}/check".format( | ||
| ip_address, port, CLASSIFIER) | ||
| self._url_teach = "http://{}:{}/{}/teach".format( | ||
| ip_address, port, CLASSIFIER) | ||
| self._username = username | ||
| self._password = password | ||
| self._hostname = hostname | ||
| self._camera = camera_entity | ||
| if name: | ||
| self._name = name | ||
| else: | ||
| camera_name = split_entity_id(camera_entity)[1] | ||
| self._name = "{} {}".format( | ||
| CLASSIFIER, camera_name) | ||
| self._name = "{} {}".format(CLASSIFIER, camera_name) | ||
| self._matched = {} | ||
|
|
||
| def process_image(self, image): | ||
| """Process an image.""" | ||
| response = post_image(self._url_check, image) | ||
| if response is not None: | ||
| response = post_image( | ||
| self._url_check, image, self._username, self._password) | ||
| if response: | ||
| response_json = response.json() | ||
| if response_json['success']: | ||
| total_faces = response_json['facesCount'] | ||
|
|
@@ -173,34 +237,8 @@ def teach(self, name, file_path): | |
| if (not self.hass.config.is_allowed_path(file_path) | ||
| or not valid_file_path(file_path)): | ||
| return | ||
| with open(file_path, 'rb') as open_file: | ||
| response = requests.post( | ||
| self._url_teach, | ||
| data={ATTR_NAME: name, 'id': file_path}, | ||
| files={'file': open_file}) | ||
|
|
||
| if response.status_code == 200: | ||
| self.hass.bus.fire( | ||
| EVENT_CLASSIFIER_TEACH, { | ||
| ATTR_CLASSIFIER: CLASSIFIER, | ||
| ATTR_NAME: name, | ||
| FILE_PATH: file_path, | ||
| 'success': True, | ||
| 'message': None | ||
| }) | ||
|
|
||
| elif response.status_code == 400: | ||
| _LOGGER.warning( | ||
| "%s teaching of file %s failed with message:%s", | ||
| CLASSIFIER, file_path, response.text) | ||
| self.hass.bus.fire( | ||
| EVENT_CLASSIFIER_TEACH, { | ||
| ATTR_CLASSIFIER: CLASSIFIER, | ||
| ATTR_NAME: name, | ||
| FILE_PATH: file_path, | ||
| 'success': False, | ||
| 'message': response.text | ||
| }) | ||
| teach_file( | ||
| self._url_teach, name, file_path, self._username, self._password) | ||
|
|
||
| @property | ||
| def camera_entity(self): | ||
|
|
@@ -218,4 +256,5 @@ def device_state_attributes(self): | |
| return { | ||
| 'matched_faces': self._matched, | ||
| 'total_matched_faces': len(self._matched), | ||
| 'hostname': self._hostname | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really I see this in
const.py