diff --git a/pytradfri/api/aiocoap_api.py b/pytradfri/api/aiocoap_api.py index 6fa1ef81..08203843 100644 --- a/pytradfri/api/aiocoap_api.py +++ b/pytradfri/api/aiocoap_api.py @@ -3,12 +3,14 @@ import json import logging import socket +import time from aiocoap import Message, Context from aiocoap.error import RequestTimedOut, Error, ConstructionRenderableError from aiocoap.numbers.codes import Code from aiocoap.transports import tinydtls +from ..const import OBSERVATION_SLEEP_TIME, OBSERVATION_TIMEOUT from ..error import ClientError, ServerError, RequestTimeout from ..gateway import Gateway @@ -30,12 +32,16 @@ def _get_psk(self, host, port): class APIFactory: + last_changed = time.time() + def __init__(self, host, psk_id='pytradfri', psk=None, loop=None): self._psk = psk self._host = host self._psk_id = psk_id self._loop = loop - self._observations_err_callbacks = [] + self._is_checking = False + self._is_resetting = False + self._observations = [] self._protocol = None if self._loop is None: @@ -78,9 +84,10 @@ async def _reset_protocol(self, exc=None): await protocol.shutdown() self._protocol = None # Let any observers know the protocol has been shutdown. - for ob_error in self._observations_err_callbacks: - ob_error(exc) - self._observations_err_callbacks.clear() + while self._observations: + ob = self._observations.pop() + ob.cancel() + del ob async def shutdown(self, exc=None): """Shutdown the API events. @@ -140,17 +147,23 @@ async def _execute(self, api_command): api_method = Code.FETCH elif method == 'patch': api_method = Code.PATCH + elif method is None: + return msg = Message(code=api_method, uri=url, **kwargs) _, res = await self._get_response(msg) api_command.result = _process_output(res, parse_json) + self._loop.create_task(self._remove_timedout_observations()) return api_command.result async def request(self, api_commands): """Make a request.""" + if not api_commands: + return None + if not isinstance(api_commands, list): result = await self._execute(api_commands) return result @@ -175,6 +188,7 @@ async def _observe(self, api_command): def success_callback(res): api_command.result = _process_output(res) + APIFactory.update_last_changed() def error_callback(ex): err_callback(ex) @@ -182,7 +196,41 @@ def error_callback(ex): ob = pr.observation ob.register_callback(success_callback) ob.register_errback(error_callback) - self._observations_err_callbacks.append(ob.error) + self._observations.append(ob) + + async def _remove_timedout_observations(self): + """ + Removes dead observations from the API. An observation is considered + dead when a timeout (defined in const) is reached. + """ + if self._is_checking: + _LOGGER.debug("Already checking for observations...") + return + + self._is_checking = True + current_time = time.time() + await asyncio.sleep(OBSERVATION_SLEEP_TIME, loop=self._loop) + + if (current_time - APIFactory.get_last_changed()) > \ + (OBSERVATION_TIMEOUT + OBSERVATION_SLEEP_TIME): + _LOGGER.warning('Resetting Tradfri observations...') + + if self._is_resetting: + return + + self._is_resetting = True + + while self._observations: + ob = self._observations.pop() + for c in ob.errbacks: + c(None) + ob.cancel() + del ob + + APIFactory.update_last_changed() + self._is_resetting = False + + self._is_checking = False async def generate_psk(self, security_key): """Generate and set a psk from the security key.""" @@ -203,6 +251,14 @@ async def generate_psk(self, security_key): return self._psk + @classmethod + def update_last_changed(cls): + cls.last_changed = time.time() + + @classmethod + def get_last_changed(cls): + return cls.last_changed + def _process_output(res, parse_json=True): """Process output.""" diff --git a/pytradfri/const.py b/pytradfri/const.py index a26bae18..d9780627 100644 --- a/pytradfri/const.py +++ b/pytradfri/const.py @@ -129,3 +129,6 @@ SUPPORT_HEX_COLOR = 4 SUPPORT_RGB_COLOR = 8 SUPPORT_XY_COLOR = 16 + +OBSERVATION_SLEEP_TIME = 5 +OBSERVATION_TIMEOUT = 10 diff --git a/pytradfri/device.py b/pytradfri/device.py index b76b3f4d..d42e95f5 100644 --- a/pytradfri/device.py +++ b/pytradfri/device.py @@ -299,6 +299,21 @@ def _value_validate(self, value, rnge, identifier="Given"): raise ValueError('%s value must be between %d and %d.' % (identifier, rnge[0], rnge[1])) + def _filter_duplicates(self, values=None): + """ + Removes duplicate state changes from the input object. + """ + if values is None: + return False + + commands = {} + for k, v in values.items(): + if k == ATTR_DEVICE_STATE or k not in self.raw[0] or \ + (k in self.raw[0] and self.raw[0][k] != v): + commands[k] = v + + return len(commands) > 0 + def set_values(self, values, *, index=0): """ Set values on light control. @@ -307,7 +322,11 @@ def set_values(self, values, *, index=0): assert len(self.raw) == 1, \ 'Only devices with 1 light supported' - return Command('put', self._device.path, { + method = None + if self._filter_duplicates(values): + method = 'put' + + return Command(method, self._device.path, { ATTR_LIGHT_CONTROL: [ values ] @@ -408,6 +427,21 @@ def set_state(self, state, *, index=0): ATTR_DEVICE_STATE: int(state) }, index=index) + def _filter_duplicates(self, values=None): + """ + Removes duplicate state changes from the input object. + """ + if values is None: + return False + + commands = {} + for k, v in values.items(): + if k not in self.raw[0] or \ + (k in self.raw[0] and self.raw[0][k] != v): + commands[k] = v + + return len(commands) > 0 + def set_values(self, values, *, index=0): """ Set values on socket control. @@ -416,7 +450,11 @@ def set_values(self, values, *, index=0): assert len(self.raw) == 1, \ 'Only devices with 1 socket supported' - return Command('put', self._device.path, { + method = None + if self._filter_duplicates(values): + method = 'put' + + return Command(method, self._device.path, { ATTR_SWITCH_PLUG: [ values ]