Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 71 additions & 38 deletions homeassistant/components/websocket_api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@

from collections.abc import Callable
import datetime as dt
from functools import lru_cache
from functools import lru_cache, partial
import json
from typing import Any, cast

import voluptuous as vol

from homeassistant.auth.models import User
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ
from homeassistant.const import (
EVENT_STATE_CHANGED,
Expand Down Expand Up @@ -88,6 +89,32 @@ def pong_message(iden: int) -> dict[str, Any]:
return {"id": iden, "type": "pong"}


def _forward_events_check_permissions(
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
user: User,
msg_id: int,
event: Event,
) -> None:
"""Forward state changed events to websocket."""
# We have to lookup the permissions again because the user might have
# changed since the subscription was created.
permissions = user.permissions
if not permissions.access_all_entities(
POLICY_READ
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
return
send_message(messages.cached_event_message(msg_id, event))


def _forward_events_unconditional(
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
msg_id: int,
event: Event,
) -> None:
"""Forward events to websocket."""
send_message(messages.cached_event_message(msg_id, event))


@callback
@decorators.websocket_command(
{
Expand All @@ -109,26 +136,18 @@ def handle_subscribe_events(
raise Unauthorized

if event_type == EVENT_STATE_CHANGED:
user = connection.user

@callback
def forward_events(event: Event) -> None:
"""Forward state changed events to websocket."""
# We have to lookup the permissions again because the user might have
# changed since the subscription was created.
permissions = user.permissions
if not permissions.access_all_entities(
POLICY_READ
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
return
connection.send_message(messages.cached_event_message(msg["id"], event))

forward_events = callback(
partial(
_forward_events_check_permissions,
connection.send_message,
connection.user,
msg["id"],
)
)
else:

@callback
def forward_events(event: Event) -> None:
"""Forward events to websocket."""
connection.send_message(messages.cached_event_message(msg["id"], event))
forward_events = callback(
partial(_forward_events_unconditional, connection.send_message, msg["id"])
)

connection.subscriptions[msg["id"]] = hass.bus.async_listen(
event_type, forward_events, run_immediately=True
Expand Down Expand Up @@ -280,6 +299,27 @@ def _send_handle_get_states_response(
connection.send_message(construct_result_message(msg_id, f"[{joined_states}]"))


def _forward_entity_changes(
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None],
entity_ids: set[str],
user: User,
msg_id: int,
event: Event,
) -> None:
"""Forward entity state changed events to websocket."""
entity_id = event.data["entity_id"]
if entity_ids and entity_id not in entity_ids:
return
# We have to lookup the permissions again because the user might have
# changed since the subscription was created.
permissions = user.permissions
if not permissions.access_all_entities(
POLICY_READ
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
return
send_message(messages.cached_state_diff_message(msg_id, event))


@callback
@decorators.websocket_command(
{
Expand All @@ -292,29 +332,22 @@ def handle_subscribe_entities(
) -> None:
"""Handle subscribe entities command."""
entity_ids = set(msg.get("entity_ids", []))
user = connection.user

@callback
def forward_entity_changes(event: Event) -> None:
"""Forward entity state changed events to websocket."""
entity_id = event.data["entity_id"]
if entity_ids and entity_id not in entity_ids:
return
# We have to lookup the permissions again because the user might have
# changed since the subscription was created.
permissions = user.permissions
if not permissions.access_all_entities(
POLICY_READ
) and not permissions.check_entity(event.data["entity_id"], POLICY_READ):
return
connection.send_message(messages.cached_state_diff_message(msg["id"], event))

# We must never await between sending the states and listening for
# state changed events or we will introduce a race condition
# where some states are missed
states = _async_get_allowed_states(hass, connection)
connection.subscriptions[msg["id"]] = hass.bus.async_listen(
EVENT_STATE_CHANGED, forward_entity_changes, run_immediately=True
EVENT_STATE_CHANGED,
callback(
partial(
_forward_entity_changes,
connection.send_message,
entity_ids,
connection.user,
msg["id"],
)
),
run_immediately=True,
)
connection.send_result(msg["id"])

Expand Down
14 changes: 14 additions & 0 deletions homeassistant/components/websocket_api/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@
class ActiveConnection:
"""Handle an active websocket client connection."""

__slots__ = (
"logger",
"hass",
"send_message",
"user",
"refresh_token_id",
"subscriptions",
"last_id",
"can_coalesce",
"supported_features",
"handlers",
"binary_handlers",
)

def __init__(
self,
logger: WebSocketAdapter,
Expand Down
20 changes: 18 additions & 2 deletions homeassistant/components/websocket_api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,21 @@ def process(self, msg: str, kwargs: Any) -> tuple[str, Any]:
class WebSocketHandler:
"""Handle an active websocket client connection."""

__slots__ = (
"_hass",
"_request",
"_wsock",
"_handle_task",
"_writer_task",
"_closing",
"_authenticated",
"_logger",
"_peak_checker_unsub",
"_connection",
"_message_queue",
"_ready_future",
)

def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
"""Initialize an active connection."""
self._hass = hass
Expand Down Expand Up @@ -201,8 +216,9 @@ def _send_message(self, message: str | dict[str, Any] | Callable[[], str]) -> No
return

message_queue.append(message)
if self._ready_future and not self._ready_future.done():
self._ready_future.set_result(None)
ready_future = self._ready_future
if ready_future and not ready_future.done():
ready_future.set_result(None)

peak_checker_active = self._peak_checker_unsub is not None

Expand Down