Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1753,6 +1753,17 @@ def has_service(self, domain: str, service: str) -> bool:
"""
return service.lower() in self._services.get(domain.lower(), [])

def supports_response(self, domain: str, service: str) -> SupportsResponse:
"""Return whether or not the service supports response data.

This exists so that callers can return more helpful error messages given
the context. Will return NONE if the service does not exist as there is
other error handling when calling the service if it does not exist.
"""
if not (handler := self._services[domain][service]):
return SupportsResponse.NONE
return handler.supports_response

def register(
self,
domain: str,
Expand Down
23 changes: 21 additions & 2 deletions homeassistant/helpers/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
Event,
HassJob,
HomeAssistant,
SupportsResponse,
callback,
)
from homeassistant.util import slugify
Expand Down Expand Up @@ -661,20 +662,38 @@ async def _async_call_service_step(self):
self._hass, self._action, self._variables
)

# Validate response data paraters. This check ignores services that do
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parameters

# not exist which will raise an appropriate error in the service call below.
response_variable = self._action.get(CONF_RESPONSE_VARIABLE)
return_response = response_variable is not None
if self._hass.services.has_service(params[CONF_DOMAIN], params[CONF_SERVICE]):
supports_response = self._hass.services.supports_response(
params[CONF_DOMAIN], params[CONF_SERVICE]
)
if supports_response == SupportsResponse.ONLY and not return_response:
raise vol.Invalid(
f"Script requires '{CONF_RESPONSE_VARIABLE}' for response data "
f"for service call {params[CONF_DOMAIN]}.{params[CONF_SERVICE]}"
)
if supports_response == SupportsResponse.NONE and return_response:
raise vol.Invalid(
f"Script does not support '{CONF_RESPONSE_VARIABLE}' for service "
f"'{CONF_RESPONSE_VARIABLE}' which does not support response data."
)

running_script = (
params[CONF_DOMAIN] == "automation"
and params[CONF_SERVICE] == "trigger"
or params[CONF_DOMAIN] in ("python_script", "script")
)
response_variable = self._action.get(CONF_RESPONSE_VARIABLE)
trace_set_result(params=params, running_script=running_script)
response_data = await self._async_run_long_action(
self._hass.async_create_task(
self._hass.services.async_call(
**params,
blocking=True,
context=self._context,
return_response=(response_variable is not None),
return_response=return_response,
)
),
)
Expand Down
51 changes: 49 additions & 2 deletions tests/helpers/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
callback,
)
from homeassistant.exceptions import ConditionError, HomeAssistantError, ServiceNotFound
Expand Down Expand Up @@ -333,7 +334,7 @@ async def test_calling_service_template(hass: HomeAssistant) -> None:
async def test_calling_service_response_data(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test the calling of a service with return values."""
"""Test the calling of a service with response data."""
context = Context()

def mock_service(call: ServiceCall) -> ServiceResponse:
Expand All @@ -342,7 +343,9 @@ def mock_service(call: ServiceCall) -> ServiceResponse:
return {"data": "value-12345"}
return None

hass.services.async_register("test", "script", mock_service, supports_response=True)
hass.services.async_register(
"test", "script", mock_service, supports_response=SupportsResponse.OPTIONAL
)
sequence = cv.SCRIPT_SCHEMA(
[
{
Expand Down Expand Up @@ -404,6 +407,50 @@ def mock_service(call: ServiceCall) -> ServiceResponse:
)


@pytest.mark.parametrize(
("supports_response", "params", "expected_error"),
[
(
SupportsResponse.NONE,
{"response_variable": "foo"},
"does not support 'response_variable'",
),
(SupportsResponse.ONLY, {}, "requires 'response_variable'"),
],
)
async def test_service_response_data_errors(
hass: HomeAssistant,
supports_response: SupportsResponse,
params: dict[str, str],
expected_error: str,
) -> None:
"""Test the calling of a service with response data error cases."""
context = Context()

def mock_service(call: ServiceCall) -> ServiceResponse:
"""Mock service call."""
raise ValueError("Never invoked")

hass.services.async_register(
"test", "script", mock_service, supports_response=supports_response
)

sequence = cv.SCRIPT_SCHEMA(
[
{
"alias": "service step1",
"service": "test.script",
**params,
},
]
)
script_obj = script.Script(hass, sequence, "Test Name", "test_domain")

with pytest.raises(vol.Invalid, match=expected_error):
await script_obj.async_run(context=context)
await hass.async_block_till_done()


async def test_data_template_with_templated_key(hass: HomeAssistant) -> None:
"""Test the calling of a service with a data_template with a templated key."""
context = Context()
Expand Down