Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 53 additions & 35 deletions homeassistant/components/websocket_api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def async_register_commands(
async_reg(hass, handle_subscribe_trigger)
async_reg(hass, handle_test_condition)
async_reg(hass, handle_unsubscribe_events)
async_reg(hass, handle_validate_config)


def pong_message(iden: int) -> dict[str, Any]:
Expand Down Expand Up @@ -116,7 +117,7 @@ def forward_events(event: Event) -> None:
event_type, forward_events
)

connection.send_message(messages.result_message(msg["id"]))
connection.send_result(msg["id"])


@callback
Expand All @@ -139,7 +140,7 @@ def forward_bootstrap_integrations(message: dict[str, Any]) -> None:
hass, SIGNAL_BOOTSTRAP_INTEGRATONS, forward_bootstrap_integrations
)

connection.send_message(messages.result_message(msg["id"]))
connection.send_result(msg["id"])


@callback
Expand All @@ -157,13 +158,9 @@ def handle_unsubscribe_events(

if subscription in connection.subscriptions:
connection.subscriptions.pop(subscription)()
connection.send_message(messages.result_message(msg["id"]))
connection.send_result(msg["id"])
else:
connection.send_message(
messages.error_message(
msg["id"], const.ERR_NOT_FOUND, "Subscription not found."
)
)
connection.send_error(msg["id"], const.ERR_NOT_FOUND, "Subscription not found.")


@decorators.websocket_command(
Expand Down Expand Up @@ -196,36 +193,20 @@ async def handle_call_service(
context,
target=target,
)
connection.send_message(
messages.result_message(msg["id"], {"context": context})
)
connection.send_result(msg["id"], {"context": context})
except ServiceNotFound as err:
if err.domain == msg["domain"] and err.service == msg["service"]:
connection.send_message(
messages.error_message(
msg["id"], const.ERR_NOT_FOUND, "Service not found."
)
)
connection.send_error(msg["id"], const.ERR_NOT_FOUND, "Service not found.")
else:
connection.send_message(
messages.error_message(
msg["id"], const.ERR_HOME_ASSISTANT_ERROR, str(err)
)
)
connection.send_error(msg["id"], const.ERR_HOME_ASSISTANT_ERROR, str(err))
except vol.Invalid as err:
connection.send_message(
messages.error_message(msg["id"], const.ERR_INVALID_FORMAT, str(err))
)
connection.send_error(msg["id"], const.ERR_INVALID_FORMAT, str(err))
except HomeAssistantError as err:
connection.logger.exception(err)
connection.send_message(
messages.error_message(msg["id"], const.ERR_HOME_ASSISTANT_ERROR, str(err))
)
connection.send_error(msg["id"], const.ERR_HOME_ASSISTANT_ERROR, str(err))
except Exception as err: # pylint: disable=broad-except
connection.logger.exception(err)
connection.send_message(
messages.error_message(msg["id"], const.ERR_UNKNOWN_ERROR, str(err))
)
connection.send_error(msg["id"], const.ERR_UNKNOWN_ERROR, str(err))


@callback
Expand All @@ -244,7 +225,7 @@ def handle_get_states(
if entity_perm(state.entity_id, "read")
]

connection.send_message(messages.result_message(msg["id"], states))
connection.send_result(msg["id"], states)


@decorators.websocket_command({vol.Required("type"): "get_services"})
Expand All @@ -254,7 +235,7 @@ async def handle_get_services(
) -> None:
"""Handle get services command."""
descriptions = await async_get_all_descriptions(hass)
connection.send_message(messages.result_message(msg["id"], descriptions))
connection.send_result(msg["id"], descriptions)


@callback
Expand All @@ -263,7 +244,7 @@ def handle_get_config(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle get config command."""
connection.send_message(messages.result_message(msg["id"], hass.config.as_dict()))
connection.send_result(msg["id"], hass.config.as_dict())


@decorators.websocket_command({vol.Required("type"): "manifest/list"})
Expand Down Expand Up @@ -417,7 +398,7 @@ def handle_entity_source(
if entity_perm(entity_id, "read")
}

connection.send_message(messages.result_message(msg["id"], sources))
connection.send_result(msg["id"], sources)
return

sources = {}
Expand Down Expand Up @@ -535,7 +516,7 @@ async def handle_execute_script(
context = connection.context(msg)
script_obj = Script(hass, msg["sequence"], f"{const.DOMAIN} script", const.DOMAIN)
await script_obj.async_run(msg.get("variables"), context=context)
connection.send_message(messages.result_message(msg["id"], {"context": context}))
connection.send_result(msg["id"], {"context": context})


@decorators.websocket_command(
Expand All @@ -555,3 +536,40 @@ async def handle_fire_event(

hass.bus.async_fire(msg["event_type"], msg.get("event_data"), context=context)
connection.send_result(msg["id"], {"context": context})


@decorators.websocket_command(
{
vol.Required("type"): "validate_config",
vol.Optional("trigger"): cv.match_all,
vol.Optional("condition"): cv.match_all,
vol.Optional("action"): cv.match_all,
}
)
@decorators.async_response
async def handle_validate_config(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle validate config command."""
# Circular dep
# pylint: disable=import-outside-toplevel
from homeassistant.helpers import condition, script, trigger

result = {}

for key, schema, validator in (
("trigger", cv.TRIGGER_SCHEMA, trigger.async_validate_trigger_config),
("condition", cv.CONDITION_SCHEMA, condition.async_validate_condition_config),
("action", cv.SCRIPT_SCHEMA, script.async_validate_actions_config),
):
if key not in msg:
continue

try:
await validator(hass, schema(msg[key])) # type: ignore
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add the ignore error code when ignoring a type error.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

except vol.Invalid as err:
result[key] = {"valid": False, "error": str(err)}
else:
result[key] = {"valid": True, "error": None}

connection.send_result(msg["id"], result)
12 changes: 10 additions & 2 deletions homeassistant/helpers/config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,7 +1033,12 @@ def script_action(value: Any) -> dict:
if not isinstance(value, dict):
raise vol.Invalid("expected dictionary")

return ACTION_TYPE_SCHEMAS[determine_script_action(value)](value)
try:
action = determine_script_action(value)
except ValueError as err:
raise vol.Invalid(str(err))

return ACTION_TYPE_SCHEMAS[action](value)


SCRIPT_SCHEMA = vol.All(ensure_list, [script_action])
Expand Down Expand Up @@ -1444,7 +1449,10 @@ def determine_script_action(action: dict[str, Any]) -> str:
if CONF_VARIABLES in action:
return SCRIPT_ACTION_VARIABLES

return SCRIPT_ACTION_CALL_SERVICE
if CONF_SERVICE in action or CONF_SERVICE_TEMPLATE in action:
return SCRIPT_ACTION_CALL_SERVICE

raise ValueError("Unable to determine action")


ACTION_TYPE_SCHEMAS: dict[str, Callable[[Any], dict]] = {
Expand Down
57 changes: 57 additions & 0 deletions tests/components/websocket_api/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,3 +1286,60 @@ async def test_integration_setup_info(hass, websocket_client, hass_admin_user):
{"domain": "august", "seconds": 12.5},
{"domain": "isy994", "seconds": 12.8},
]


@pytest.mark.parametrize(
"key,config",
(
("trigger", {"platform": "event", "event_type": "hello"}),
(
"condition",
{"condition": "state", "entity_id": "hello.world", "state": "paulus"},
),
("action", {"service": "domain_test.test_service"}),
),
)
async def test_validate_config_works(websocket_client, key, config):
"""Test config validation."""
await websocket_client.send_json({"id": 7, "type": "validate_config", key: config})

msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == {key: {"valid": True, "error": None}}


@pytest.mark.parametrize(
"key,config,error",
(
(
"trigger",
{"platform": "non_existing", "event_type": "hello"},
"Invalid platform 'non_existing' specified",
),
(
"condition",
{
"condition": "non_existing",
"entity_id": "hello.world",
"state": "paulus",
},
"Unexpected value for condition: 'non_existing'. Expected and, device, not, numeric_state, or, state, sun, template, time, trigger, zone",
),
(
"action",
{"non_existing": "domain_test.test_service"},
"Unable to determine action @ data[0]",
),
),
)
async def test_validate_config_invalid(websocket_client, key, config, error):
"""Test config validation."""
await websocket_client.send_json({"id": 7, "type": "validate_config", key: config})

msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == {key: {"valid": False, "error": error}}