Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 6 additions & 44 deletions homeassistant/components/fan/device_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,64 +3,26 @@

import voluptuous as vol

from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_DEVICE_ID,
CONF_DOMAIN,
CONF_ENTITY_ID,
CONF_TYPE,
SERVICE_TURN_OFF,
SERVICE_TURN_ON,
)
from homeassistant.components.device_automation import toggle_entity
from homeassistant.const import CONF_DOMAIN
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import entity_registry
import homeassistant.helpers.config_validation as cv

from . import DOMAIN

ACTION_TYPES = {"turn_on", "turn_off"}

ACTION_SCHEMA = cv.DEVICE_ACTION_BASE_SCHEMA.extend(
{
vol.Required(CONF_TYPE): vol.In(ACTION_TYPES),
vol.Required(CONF_ENTITY_ID): cv.entity_domain(DOMAIN),
}
)
ACTION_SCHEMA = toggle_entity.ACTION_SCHEMA.extend({vol.Required(CONF_DOMAIN): DOMAIN})


async def async_get_actions(
hass: HomeAssistant, device_id: str
) -> list[dict[str, str]]:
"""List device actions for Fan devices."""
registry = await entity_registry.async_get_registry(hass)
actions = []

# Get all the integrations entities for this device
for entry in entity_registry.async_entries_for_device(registry, device_id):
if entry.domain != DOMAIN:
continue

base_action = {
CONF_DEVICE_ID: device_id,
CONF_DOMAIN: DOMAIN,
CONF_ENTITY_ID: entry.entity_id,
}
actions += [{**base_action, CONF_TYPE: action} for action in ACTION_TYPES]

return actions
return await toggle_entity.async_get_actions(hass, device_id, DOMAIN)


async def async_call_action_from_config(
hass: HomeAssistant, config: dict, variables: dict, context: Context | None
) -> None:
"""Execute a device action."""
service_data = {ATTR_ENTITY_ID: config[CONF_ENTITY_ID]}

if config[CONF_TYPE] == "turn_on":
service = SERVICE_TURN_ON
elif config[CONF_TYPE] == "turn_off":
service = SERVICE_TURN_OFF

await hass.services.async_call(
DOMAIN, service, service_data, blocking=True, context=context
await toggle_entity.async_call_action_from_config(
hass, config, variables, context, DOMAIN
)
25 changes: 23 additions & 2 deletions tests/components/fan/test_device_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def test_get_actions(hass, device_reg, entity_reg):
"entity_id": f"{DOMAIN}.test_5678",
"metadata": {"secondary": False},
}
for action in ["turn_on", "turn_off"]
for action in ["turn_on", "turn_off", "toggle"]
]
actions = await async_get_device_automations(
hass, DeviceAutomationType.ACTION, device_entry.id
Expand Down Expand Up @@ -98,7 +98,7 @@ async def test_get_actions_hidden_auxiliary(
"entity_id": f"{DOMAIN}.test_5678",
"metadata": {"secondary": True},
}
for action in ["turn_on", "turn_off"]
for action in ["turn_on", "turn_off", "toggle"]
]
actions = await async_get_device_automations(
hass, DeviceAutomationType.ACTION, device_entry.id
Expand Down Expand Up @@ -137,19 +137,40 @@ async def test_action(hass):
"type": "turn_on",
},
},
{
"trigger": {
"platform": "event",
"event_type": "test_event_toggle",
},
"action": {
"domain": DOMAIN,
"device_id": "abcdefgh",
"entity_id": "fan.entity",
"type": "toggle",
},
},
]
},
)

turn_off_calls = async_mock_service(hass, "fan", "turn_off")
turn_on_calls = async_mock_service(hass, "fan", "turn_on")
toggle_calls = async_mock_service(hass, "fan", "toggle")

hass.bus.async_fire("test_event_turn_off")
await hass.async_block_till_done()
assert len(turn_off_calls) == 1
assert len(turn_on_calls) == 0
assert len(toggle_calls) == 0

hass.bus.async_fire("test_event_turn_on")
await hass.async_block_till_done()
assert len(turn_off_calls) == 1
assert len(turn_on_calls) == 1
assert len(toggle_calls) == 0

hass.bus.async_fire("test_event_toggle")
await hass.async_block_till_done()
assert len(turn_off_calls) == 1
assert len(turn_on_calls) == 1
assert len(toggle_calls) == 1