diff --git a/homeassistant/components/shelly/__init__.py b/homeassistant/components/shelly/__init__.py index be87e2556eb80..eb47821b88e64 100644 --- a/homeassistant/components/shelly/__init__.py +++ b/homeassistant/components/shelly/__init__.py @@ -8,15 +8,17 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( + ATTR_AREA_ID, ATTR_DEVICE_ID, CONF_HOST, CONF_PASSWORD, CONF_USERNAME, EVENT_HOMEASSISTANT_STOP, ) -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import HomeAssistant, ServiceCall, callback from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.helpers import aiohttp_client, device_registry, update_coordinator +from homeassistant.helpers.service import async_extract_referenced_entity_ids from .const import ( AIOSHELLY_DEVICE_TIMEOUT_SEC, @@ -25,6 +27,7 @@ ATTR_DEVICE, BATTERY_DEVICES_WITH_PERMANENT_CONNECTION, COAP, + CONF_SLEEP_PERIOD, DATA_CONFIG_ENTRY, DEVICE, DOMAIN, @@ -33,6 +36,7 @@ POLLING_TIMEOUT_SEC, REST, REST_SENSORS_UPDATE_INTERVAL, + SERVICE_OTA_UPDATE, SLEEP_PERIOD_MULTIPLIER, UPDATE_PERIOD_MULTIPLIER, ) @@ -72,13 +76,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): False, ) - dev_reg = await device_registry.async_get_registry(hass) + dev_reg = device_registry.async_get(hass) identifier = (DOMAIN, entry.unique_id) device_entry = dev_reg.async_get_device(identifiers={identifier}, connections=set()) if device_entry and entry.entry_id not in device_entry.config_entries: device_entry = None - sleep_period = entry.data.get("sleep_period") + sleep_period = entry.data.get(CONF_SLEEP_PERIOD) @callback def _async_device_online(_): @@ -87,7 +91,7 @@ def _async_device_online(_): if sleep_period is None: data = {**entry.data} - data["sleep_period"] = get_device_sleep_period(device.settings) + data[CONF_SLEEP_PERIOD] = get_device_sleep_period(device.settings) data["model"] = device.settings["device"]["type"] hass.config_entries.async_update_entry(entry, data=data) @@ -116,6 +120,8 @@ def _async_device_online(_): _LOGGER.debug("Setting up offline device %s", entry.title) await async_device_setup(hass, entry, device) + await async_services_setup(hass, dev_reg) + return True @@ -130,7 +136,7 @@ async def async_device_setup( platforms = SLEEPING_PLATFORMS - if not entry.data.get("sleep_period"): + if not entry.data.get(CONF_SLEEP_PERIOD): hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id][ REST ] = ShellyDeviceRestWrapper(hass, device) @@ -142,13 +148,48 @@ async def async_device_setup( ) +async def async_services_setup( + hass: HomeAssistant, dev_reg: device_registry.DeviceRegistry +): + """Set up services.""" + + async def async_service_ota_update(call: ServiceCall): + """Trigger OTA update.""" + if not (call.data.get(ATTR_DEVICE_ID) or call.data.get(ATTR_AREA_ID)): + _LOGGER.warning("OTA update service: no target selected") + return + + selected_ids = await async_extract_referenced_entity_ids(hass, call) + for device_id in selected_ids.referenced_devices: + device = dev_reg.async_get(device_id) + if DOMAIN not in next(iter(device.identifiers)): + continue + entry_id = next(iter(device.config_entries)) + entry_data = hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry_id] + device_wrapper: ShellyDeviceWrapper = entry_data[COAP] + if device_wrapper.is_ota_pending: + _LOGGER.warning( + "There is already an ota update scheduled for device %s", + device.name, + ) + continue + + await device_wrapper.async_trigger_ota_update( + beta=call.data.get("beta"), + url=call.data.get("url"), + force=call.data.get("force"), + ) + + hass.services.async_register(DOMAIN, SERVICE_OTA_UPDATE, async_service_ota_update) + + class ShellyDeviceWrapper(update_coordinator.DataUpdateCoordinator): """Wrapper for a Shelly device with Home Assistant specific functions.""" def __init__(self, hass, entry, device: aioshelly.Device): """Initialize the Shelly device wrapper.""" self.device_id = None - sleep_period = entry.data["sleep_period"] + sleep_period = entry.data[CONF_SLEEP_PERIOD] if sleep_period: update_interval = SLEEP_PERIOD_MULTIPLIER * sleep_period @@ -172,6 +213,8 @@ def __init__(self, hass, entry, device: aioshelly.Device): self._async_device_updates_handler ) self._last_input_events_count = {} + self._ota_update_pending = False + self._ota_update_params = {} hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._handle_ha_stop) @@ -181,6 +224,9 @@ def _async_device_updates_handler(self): if not self.device.initialized: return + if self._ota_update_pending: + self.async_trigger_ota_update() + # Check for input events for block in self.device.blocks: if ( @@ -220,7 +266,7 @@ def _async_device_updates_handler(self): async def _async_update_data(self): """Fetch data.""" - if self.entry.data.get("sleep_period"): + if self.entry.data.get(CONF_SLEEP_PERIOD): # Sleeping device, no point polling it, just mark it unavailable raise update_coordinator.UpdateFailed("Sleeping device did not update") @@ -241,9 +287,14 @@ def mac(self): """Mac address of the device.""" return self.entry.unique_id + @property + def is_ota_pending(self): + """Return if ota update is scheduled for device.""" + return self._ota_update_pending + async def async_setup(self): """Set up the wrapper.""" - dev_reg = await device_registry.async_get_registry(self.hass) + dev_reg = device_registry.async_get(self.hass) sw_version = self.device.settings["fw"] if self.device.initialized else "" entry = dev_reg.async_get_or_create( config_entry_id=self.entry.entry_id, @@ -258,6 +309,81 @@ async def async_setup(self): self.device_id = entry.id self.device.subscribe_updates(self.async_set_updated_data) + async def async_trigger_ota_update(self, beta=False, url=None, force=False): + """Trigger an ota update.""" + if self.entry.data.get(CONF_SLEEP_PERIOD) and not self._ota_update_pending: + self._ota_update_pending = True + self._ota_update_params = { + "beta": beta, + "force": force, + "url": url, + } + _LOGGER.info("OTA update scheduled for sleeping device %s", self.name) + return + + def _reset_pending_ota(): + """Reset OTA update scheduler for sleeping device.""" + if self._ota_update_pending: + _LOGGER.debug( + "Reset OTA update scheduler for sleeping device %s", self.name + ) + self._ota_update_pending = False + self._ota_update_params = {} + + if not self._ota_update_pending: + await self.async_refresh() + else: + beta = self._ota_update_params["beta"] + force = self._ota_update_params["force"] + url = self._ota_update_params["url"] + + update_data = self.device.status["update"] + _LOGGER.debug("OTA update service - update_data: %s", update_data) + + if not update_data["has_update"] and not beta and not url and not force: + _LOGGER.info("No OTA update for %s available", self.name) + _reset_pending_ota() + return + + if beta and not update_data.get("beta_version"): + _LOGGER.info("No beta OTA update for %s available", self.name) + _reset_pending_ota() + return + + if update_data["status"] == "updating": + _LOGGER.warning("OTA update already in progress for %s", self.name) + _reset_pending_ota() + return + + new_version = update_data["new_version"] + if beta: + new_version = update_data["beta_version"] + if url: + new_version = url + + _LOGGER.info( + "Trigger OTA update for device %s from '%s' to '%s'", + self.name, + update_data["old_version"], + new_version, + ) + + resp = None + try: + async with async_timeout.timeout(AIOSHELLY_DEVICE_TIMEOUT_SEC): + resp = await self.device.trigger_ota_update( + beta=beta, + url=url, + ) + except OSError as err: + _LOGGER.exception("Error while trigger ota update: %s", err) + except Exception as err: # pylint: disable=broad-except + _LOGGER.exception("Error while ota update: %s", err) + + _LOGGER.debug("OTA update response: %s", resp) + _reset_pending_ota() + return + def shutdown(self): """Shutdown the wrapper.""" self.device.shutdown() @@ -318,7 +444,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry): platforms = SLEEPING_PLATFORMS - if not entry.data.get("sleep_period"): + if not entry.data.get(CONF_SLEEP_PERIOD): hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id][REST] = None platforms = PLATFORMS diff --git a/homeassistant/components/shelly/binary_sensor.py b/homeassistant/components/shelly/binary_sensor.py index 385b3b30c36dd..937dccb0cb464 100644 --- a/homeassistant/components/shelly/binary_sensor.py +++ b/homeassistant/components/shelly/binary_sensor.py @@ -13,6 +13,7 @@ BinarySensorEntity, ) +from .const import CONF_SLEEP_PERIOD from .entity import ( BlockAttributeDescription, RestAttributeDescription, @@ -105,7 +106,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities): """Set up sensors for device.""" - if config_entry.data["sleep_period"]: + if config_entry.data[CONF_SLEEP_PERIOD]: await async_setup_entry_attribute_entities( hass, config_entry, diff --git a/homeassistant/components/shelly/config_flow.py b/homeassistant/components/shelly/config_flow.py index 73c231086eff5..b9852353ce645 100644 --- a/homeassistant/components/shelly/config_flow.py +++ b/homeassistant/components/shelly/config_flow.py @@ -16,7 +16,7 @@ ) from homeassistant.helpers import aiohttp_client -from .const import AIOSHELLY_DEVICE_TIMEOUT_SEC, DOMAIN +from .const import AIOSHELLY_DEVICE_TIMEOUT_SEC, CONF_SLEEP_PERIOD, DOMAIN from .utils import get_coap_context, get_device_sleep_period _LOGGER = logging.getLogger(__name__) @@ -50,7 +50,7 @@ async def validate_input(hass: core.HomeAssistant, host, data): return { "title": device.settings["name"], "hostname": device.settings["device"]["hostname"], - "sleep_period": get_device_sleep_period(device.settings), + CONF_SLEEP_PERIOD: get_device_sleep_period(device.settings), "model": device.settings["device"]["type"], } @@ -97,7 +97,7 @@ async def async_step_user(self, user_input=None): title=device_info["title"] or device_info["hostname"], data={ **user_input, - "sleep_period": device_info["sleep_period"], + CONF_SLEEP_PERIOD: device_info[CONF_SLEEP_PERIOD], "model": device_info["model"], }, ) @@ -128,7 +128,7 @@ async def async_step_credentials(self, user_input=None): data={ **user_input, CONF_HOST: self.host, - "sleep_period": device_info["sleep_period"], + CONF_SLEEP_PERIOD: device_info[CONF_SLEEP_PERIOD], "model": device_info["model"], }, ) @@ -180,8 +180,8 @@ async def async_step_confirm_discovery(self, user_input=None): return self.async_create_entry( title=self.device_info["title"] or self.device_info["hostname"], data={ - "host": self.host, - "sleep_period": self.device_info["sleep_period"], + CONF_HOST: self.host, + CONF_SLEEP_PERIOD: self.device_info[CONF_SLEEP_PERIOD], "model": self.device_info["model"], }, ) diff --git a/homeassistant/components/shelly/const.py b/homeassistant/components/shelly/const.py index 4fda656e7b446..e806cc84ddc45 100644 --- a/homeassistant/components/shelly/const.py +++ b/homeassistant/components/shelly/const.py @@ -5,6 +5,8 @@ DEVICE = "device" DOMAIN = "shelly" REST = "rest" +SERVICE_OTA_UPDATE = "ota_update" +SERVICES = [SERVICE_OTA_UPDATE] # Used in "_async_update_data" as timeout for polling data from devices. POLLING_TIMEOUT_SEC = 18 @@ -42,6 +44,7 @@ ATTR_CLICK_TYPE = "click_type" ATTR_CHANNEL = "channel" ATTR_DEVICE = "device" +CONF_SLEEP_PERIOD = "sleep_period" CONF_SUBTYPE = "subtype" BASIC_INPUTS_EVENTS_TYPES = { diff --git a/homeassistant/components/shelly/sensor.py b/homeassistant/components/shelly/sensor.py index b6d3bc2dbff3e..4484bd99f2613 100644 --- a/homeassistant/components/shelly/sensor.py +++ b/homeassistant/components/shelly/sensor.py @@ -13,7 +13,7 @@ VOLT, ) -from .const import SHAIR_MAX_WORK_HOURS +from .const import CONF_SLEEP_PERIOD, SHAIR_MAX_WORK_HOURS from .entity import ( BlockAttributeDescription, RestAttributeDescription, @@ -194,7 +194,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities): """Set up sensors for device.""" - if config_entry.data["sleep_period"]: + if config_entry.data[CONF_SLEEP_PERIOD]: await async_setup_entry_attribute_entities( hass, config_entry, async_add_entities, SENSORS, ShellySleepingSensor ) diff --git a/homeassistant/components/shelly/services.yaml b/homeassistant/components/shelly/services.yaml new file mode 100644 index 0000000000000..df1d89be04b9e --- /dev/null +++ b/homeassistant/components/shelly/services.yaml @@ -0,0 +1,35 @@ +# shelly service descriptions. + +ota_update: + name: OTA Update + description: Trigger an over-the-air (OTA) update. + target: + device: + integration: shelly + entity: + integration: none + fields: + url: + name: Firmware url + description: Run firmware update from specified URL + required: false + example: http://api.shelly.cloud/firmware/rc/SHPLG-S.zip + advanced: true + selector: + text: + beta: + name: Beta + description: Run firmware update from beta URL (if available) + required: false + default: false + example: true + selector: + boolean: + force: + name: Force + description: Force firmware update + required: false + default: false + example: true + selector: + boolean: diff --git a/tests/components/shelly/test_init.py b/tests/components/shelly/test_init.py new file mode 100644 index 0000000000000..6445a067a12a7 --- /dev/null +++ b/tests/components/shelly/test_init.py @@ -0,0 +1,10 @@ +"""Tests for the Shelly integration init.""" +from homeassistant.components.shelly import async_services_setup +from homeassistant.components.shelly.const import DOMAIN, SERVICES + + +async def test_services_registered(hass, device_reg): + """Test if all services are registered.""" + await async_services_setup(hass, device_reg) + for service in SERVICES: + assert hass.services.has_service(DOMAIN, service)