Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions homeassistant/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@
# Name
ATTR_NAME = 'name'

# Data for a SERVICE_EXECUTED event
ATTR_SERVICE_CALL_ID = 'service_call_id'

# Contains one string or a list of strings, each being an entity id
ATTR_ENTITY_ID = 'entity_id'

Expand Down
20 changes: 15 additions & 5 deletions homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from homeassistant.const import (
ATTR_DOMAIN, ATTR_FRIENDLY_NAME, ATTR_NOW, ATTR_SERVICE,
ATTR_SERVICE_DATA, EVENT_CALL_SERVICE,
ATTR_SERVICE_CALL_ID, ATTR_SERVICE_DATA, EVENT_CALL_SERVICE,
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED,
EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE,
Expand Down Expand Up @@ -1042,10 +1042,12 @@ async def async_call(self, domain: str, service: str,
This method is a coroutine.
"""
context = context or Context()
call_id = uuid.uuid4().hex
event_data = {
ATTR_DOMAIN: domain.lower(),
ATTR_SERVICE: service.lower(),
ATTR_SERVICE_DATA: service_data,
ATTR_SERVICE_CALL_ID: call_id,
}

if not blocking:
Expand All @@ -1058,8 +1060,9 @@ async def async_call(self, domain: str, service: str,
@callback
def service_executed(event: Event) -> None:
"""Handle an executed service."""
if event.context == context:
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
fut.set_result(True)
unsub()

unsub = self._hass.bus.async_listen(
EVENT_SERVICE_EXECUTED, service_executed)
Expand All @@ -1069,14 +1072,16 @@ def service_executed(event: Event) -> None:

done, _ = await asyncio.wait([fut], timeout=SERVICE_CALL_LIMIT)
success = bool(done)
unsub()
if not success:
unsub()
return success

async def _event_to_service_call(self, event: Event) -> None:
"""Handle the SERVICE_CALLED events from the EventBus."""
service_data = event.data.get(ATTR_SERVICE_DATA) or {}
domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore
service = event.data.get(ATTR_SERVICE).lower() # type: ignore
call_id = event.data.get(ATTR_SERVICE_CALL_ID)

if not self.has_service(domain, service):
if event.origin == EventOrigin.local:
Expand All @@ -1088,12 +1093,17 @@ async def _event_to_service_call(self, event: Event) -> None:

def fire_service_executed() -> None:
"""Fire service executed event."""
if not call_id:
return

data = {ATTR_SERVICE_CALL_ID: call_id}

if (service_handler.is_coroutinefunction or
service_handler.is_callback):
self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, {},
self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, data,
EventOrigin.local, event.context)
else:
self._hass.bus.fire(EVENT_SERVICE_EXECUTED, {},
self._hass.bus.fire(EVENT_SERVICE_EXECUTED, data,
EventOrigin.local, event.context)

try:
Expand Down
29 changes: 27 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
from homeassistant.const import (
__version__, EVENT_STATE_CHANGED, ATTR_FRIENDLY_NAME, CONF_UNIT_SYSTEM,
ATTR_NOW, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP,
EVENT_HOMEASSISTANT_CLOSE, EVENT_SERVICE_REGISTERED, EVENT_SERVICE_REMOVED)
EVENT_HOMEASSISTANT_CLOSE, EVENT_SERVICE_REGISTERED, EVENT_SERVICE_REMOVED,
EVENT_SERVICE_EXECUTED)

from tests.common import get_test_home_assistant
from tests.common import get_test_home_assistant, async_mock_service

PST = pytz.timezone('America/Los_Angeles')

Expand Down Expand Up @@ -969,3 +970,27 @@ def test_track_task_functions(loop):
assert hass._track_task
finally:
yield from hass.async_stop()


async def test_service_executed_with_subservices(hass):
"""Test we block correctly till all services done."""
calls = async_mock_service(hass, 'test', 'inner')

async def handle_outer(call):
"""Handle outer service call."""
calls.append(call)
call1 = hass.services.async_call('test', 'inner', blocking=True,
context=call.context)
call2 = hass.services.async_call('test', 'inner', blocking=True,
context=call.context)
await asyncio.wait([call1, call2])
calls.append(call)

hass.services.async_register('test', 'outer', handle_outer)

await hass.services.async_call('test', 'outer', blocking=True)

assert len(calls) == 4
assert [call.service for call in calls] == [
'outer', 'inner', 'inner', 'outer']
assert len(hass.bus.async_listeners().get(EVENT_SERVICE_EXECUTED, [])) == 0