Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 53 additions & 45 deletions homeassistant/components/shelly/button.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Button for Shelly."""
from __future__ import annotations

from typing import cast
from collections.abc import Callable
from dataclasses import dataclass
from typing import Final, cast

from homeassistant.components.button import ButtonEntity
from homeassistant.components.button import ButtonEntity, ButtonEntityDescription
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ENTITY_CATEGORY_CONFIG
from homeassistant.core import HomeAssistant
Expand All @@ -17,6 +19,44 @@
from .utils import get_block_device_name, get_device_entry_gen, get_rpc_device_name


@dataclass
class ShellyButtonDescriptionMixin:
"""Mixin to describe a Button entity."""

press_action: Callable


@dataclass
class ShellyButtonDescription(ButtonEntityDescription, ShellyButtonDescriptionMixin):
"""Class to describe a Button entity."""


BUTTONS: Final = [
ShellyButtonDescription(
key="ota_update",
name="OTA Update",
icon="mdi:package-up",
entity_category=ENTITY_CATEGORY_CONFIG,
press_action=lambda wrapper: wrapper.async_trigger_ota_update(),
),
ShellyButtonDescription(
key="ota_update_beta",
name="OTA Update Beta",
icon="mdi:flask-outline",
entity_registry_enabled_default=False,
entity_category=ENTITY_CATEGORY_CONFIG,
press_action=lambda wrapper: wrapper.async_trigger_ota_update(beta=True),
),
ShellyButtonDescription(
key="reboot",
name="Reboot",
icon="mdi:restart",
entity_category=ENTITY_CATEGORY_CONFIG,
press_action=lambda wrapper: wrapper.device.trigger_reboot(),
),
]


async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
Expand All @@ -36,66 +76,34 @@ async def async_setup_entry(
wrapper = cast(BlockDeviceWrapper, block_wrapper)

if wrapper is not None:
async_add_entities(
[
ShellyOtaUpdateStableButton(wrapper, config_entry),
ShellyOtaUpdateBetaButton(wrapper, config_entry),
]
)
async_add_entities([ShellyButton(wrapper, button) for button in BUTTONS])


class ShellyOtaUpdateBaseButton(ButtonEntity):
class ShellyButton(ButtonEntity):
"""Defines a Shelly OTA update base button."""

_attr_entity_category = ENTITY_CATEGORY_CONFIG
entity_description: ShellyButtonDescription

def __init__(
self,
wrapper: RpcDeviceWrapper | BlockDeviceWrapper,
entry: ConfigEntry,
name: str,
beta_channel: bool,
icon: str,
description: ShellyButtonDescription,
) -> None:
"""Initialize Shelly OTA update button."""
self._attr_device_info = DeviceInfo(
connections={(CONNECTION_NETWORK_MAC, wrapper.mac)}
)
self.entity_description = description
self.wrapper = wrapper

if isinstance(wrapper, RpcDeviceWrapper):
device_name = get_rpc_device_name(wrapper.device)
else:
device_name = get_block_device_name(wrapper.device)

self._attr_name = f"{device_name} {name}"
self._attr_name = f"{device_name} {description.name}"
self._attr_unique_id = slugify(self._attr_name)
self._attr_icon = icon

self.beta_channel = beta_channel
self.entry = entry
self.wrapper = wrapper
self._attr_device_info = DeviceInfo(
connections={(CONNECTION_NETWORK_MAC, wrapper.mac)}
)

async def async_press(self) -> None:
"""Triggers the OTA update service."""
await self.wrapper.async_trigger_ota_update(beta=self.beta_channel)


class ShellyOtaUpdateStableButton(ShellyOtaUpdateBaseButton):
"""Defines a Shelly OTA update stable channel button."""

def __init__(
self, wrapper: RpcDeviceWrapper | BlockDeviceWrapper, entry: ConfigEntry
) -> None:
"""Initialize Shelly OTA update button."""
super().__init__(wrapper, entry, "OTA Update", False, "mdi:package-up")


class ShellyOtaUpdateBetaButton(ShellyOtaUpdateBaseButton):
"""Defines a Shelly OTA update beta channel button."""

def __init__(
self, wrapper: RpcDeviceWrapper | BlockDeviceWrapper, entry: ConfigEntry
) -> None:
"""Initialize Shelly OTA update button."""
super().__init__(wrapper, entry, "OTA Update Beta", True, "mdi:flask-outline")
self._attr_entity_registry_enabled_default = False
await self.entity_description.press_action(self.wrapper)
2 changes: 2 additions & 0 deletions tests/components/shelly/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ async def coap_wrapper(hass):
firmware_version="some fw string",
update=AsyncMock(),
trigger_ota_update=AsyncMock(),
trigger_reboot=AsyncMock(),
initialized=True,
)

Expand Down Expand Up @@ -173,6 +174,7 @@ async def rpc_wrapper(hass):
firmware_version="some fw string",
update=AsyncMock(),
trigger_ota_update=AsyncMock(),
trigger_reboot=AsyncMock(),
initialized=True,
shutdown=AsyncMock(),
)
Expand Down
86 changes: 76 additions & 10 deletions tests/components/shelly/test_button.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for Shelly button platform."""
from homeassistant.components.button import DOMAIN as BUTTON_DOMAIN
from homeassistant.components.button.const import SERVICE_PRESS
from homeassistant.components.shelly.const import DOMAIN
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_registry import async_get
Expand All @@ -10,6 +11,14 @@ async def test_block_button(hass: HomeAssistant, coap_wrapper):
"""Test block device OTA button."""
assert coap_wrapper

entity_registry = async_get(hass)
entity_registry.async_get_or_create(
BUTTON_DOMAIN,
DOMAIN,
"test_name_ota_update_beta",
suggested_object_id="test_name_ota_update_beta",
disabled_by=None,
)
hass.async_create_task(
hass.config_entries.async_forward_entry_setup(coap_wrapper.entry, BUTTON_DOMAIN)
)
Expand All @@ -27,21 +36,54 @@ async def test_block_button(hass: HomeAssistant, coap_wrapper):
blocking=True,
)
await hass.async_block_till_done()
coap_wrapper.device.trigger_ota_update.assert_called_once_with(beta=False)
assert coap_wrapper.device.trigger_ota_update.call_count == 1
coap_wrapper.device.trigger_ota_update.assert_called_with(beta=False)

# beta channel button
entity_registry = async_get(hass)
entry = entity_registry.async_get("button.test_name_ota_update_beta")
state = hass.states.get("button.test_name_ota_update_beta")

assert entry
assert state is None
assert state
assert state.state == STATE_UNKNOWN

await hass.services.async_call(
BUTTON_DOMAIN,
SERVICE_PRESS,
{ATTR_ENTITY_ID: "button.test_name_ota_update_beta"},
blocking=True,
)
await hass.async_block_till_done()
assert coap_wrapper.device.trigger_ota_update.call_count == 2
coap_wrapper.device.trigger_ota_update.assert_called_with(beta=True)

# reboot button
state = hass.states.get("button.test_name_reboot")

assert state
assert state.state == STATE_UNKNOWN

await hass.services.async_call(
BUTTON_DOMAIN,
SERVICE_PRESS,
{ATTR_ENTITY_ID: "button.test_name_reboot"},
blocking=True,
)
await hass.async_block_till_done()
assert coap_wrapper.device.trigger_reboot.call_count == 1


async def test_rpc_button(hass: HomeAssistant, rpc_wrapper):
"""Test rpc device OTA button."""
assert rpc_wrapper

entity_registry = async_get(hass)
entity_registry.async_get_or_create(
BUTTON_DOMAIN,
DOMAIN,
"test_name_ota_update_beta",
suggested_object_id="test_name_ota_update_beta",
disabled_by=None,
)

hass.async_create_task(
hass.config_entries.async_forward_entry_setup(rpc_wrapper.entry, BUTTON_DOMAIN)
)
Expand All @@ -59,12 +101,36 @@ async def test_rpc_button(hass: HomeAssistant, rpc_wrapper):
blocking=True,
)
await hass.async_block_till_done()
rpc_wrapper.device.trigger_ota_update.assert_called_once_with(beta=False)
assert rpc_wrapper.device.trigger_ota_update.call_count == 1
rpc_wrapper.device.trigger_ota_update.assert_called_with(beta=False)

# beta channel button
entity_registry = async_get(hass)
entry = entity_registry.async_get("button.test_name_ota_update_beta")
state = hass.states.get("button.test_name_ota_update_beta")

assert entry
assert state is None
assert state
assert state.state == STATE_UNKNOWN

await hass.services.async_call(
BUTTON_DOMAIN,
SERVICE_PRESS,
{ATTR_ENTITY_ID: "button.test_name_ota_update_beta"},
blocking=True,
)
await hass.async_block_till_done()
assert rpc_wrapper.device.trigger_ota_update.call_count == 2
rpc_wrapper.device.trigger_ota_update.assert_called_with(beta=True)

# reboot button
state = hass.states.get("button.test_name_reboot")

assert state
assert state.state == STATE_UNKNOWN

await hass.services.async_call(
BUTTON_DOMAIN,
SERVICE_PRESS,
{ATTR_ENTITY_ID: "button.test_name_reboot"},
blocking=True,
)
await hass.async_block_till_done()
assert rpc_wrapper.device.trigger_reboot.call_count == 1