diff --git a/.github/workflows/builder.yml b/.github/workflows/builder.yml index 8c8c9978aa2..624ca0d89f2 100644 --- a/.github/workflows/builder.yml +++ b/.github/workflows/builder.yml @@ -284,9 +284,10 @@ jobs: --privileged \ --security-opt seccomp=unconfined \ --security-opt apparmor=unconfined \ - -v /run/docker.sock:/run/docker.sock \ - -v /run/dbus:/run/dbus \ - -v /tmp/supervisor/data:/data \ + -v /run/docker.sock:/run/docker.sock:rw \ + -v /run/dbus:/run/dbus:ro \ + -v /run/supervisor:/run/os:rw \ + -v /tmp/supervisor/data:/data:rw,slave \ -v /etc/machine-id:/etc/machine-id:ro \ -e SUPERVISOR_SHARE="/tmp/supervisor/data" \ -e SUPERVISOR_NAME=hassio_supervisor \ diff --git a/supervisor/api/proxy.py b/supervisor/api/proxy.py index b5a04214d62..a505b0196f2 100644 --- a/supervisor/api/proxy.py +++ b/supervisor/api/proxy.py @@ -7,7 +7,6 @@ import aiohttp from aiohttp import WSCloseCode, WSMessageTypeError, web -from aiohttp.client_exceptions import ClientConnectorError from aiohttp.client_ws import ClientWebSocketResponse from aiohttp.hdrs import AUTHORIZATION, CONTENT_TYPE from aiohttp.http_websocket import WSMsgType @@ -179,57 +178,16 @@ async def api(self, request: web.Request): async def _websocket_client(self) -> ClientWebSocketResponse: """Initialize a WebSocket API connection.""" - url = f"{self.sys_homeassistant.api_url}/api/websocket" - try: - client = await self.sys_websession.ws_connect( - url, heartbeat=30, ssl=False, max_msg_size=MAX_MESSAGE_SIZE_FROM_CORE - ) - - # Handle authentication - data = await client.receive_json() - - if data.get("type") == "auth_ok": - return client - - if data.get("type") != "auth_required": - # Invalid protocol - raise APIError( - f"Got unexpected response from Home Assistant WebSocket: {data}", - _LOGGER.error, - ) - - # Auth session - await self.sys_homeassistant.api.ensure_access_token() - await client.send_json( - { - "type": "auth", - "access_token": self.sys_homeassistant.api.access_token, - }, - dumps=json_dumps, + ws_client = await self.sys_homeassistant.api.connect_websocket( + max_msg_size=MAX_MESSAGE_SIZE_FROM_CORE ) - - data = await client.receive_json() - - if data.get("type") == "auth_ok": - return client - - # Renew the Token is invalid - if ( - data.get("type") == "invalid_auth" - and self.sys_homeassistant.refresh_token - ): - self.sys_homeassistant.api.access_token = None - return await self._websocket_client() - - raise HomeAssistantAuthError() - - except (RuntimeError, ValueError, TypeError, ClientConnectorError) as err: - _LOGGER.error("Client error on WebSocket API %s.", err) - except HomeAssistantAuthError: - _LOGGER.error("Failed authentication to Home Assistant WebSocket") - - raise APIError() + return ws_client.client + except HomeAssistantAPIError as err: + raise APIError( + f"Error connecting to Home Assistant WebSocket: {err}", + _LOGGER.error, + ) from err async def _proxy_message( self, diff --git a/supervisor/const.py b/supervisor/const.py index 996f32458bd..8a36356e95b 100644 --- a/supervisor/const.py +++ b/supervisor/const.py @@ -39,9 +39,10 @@ FILE_SUFFIX_CONFIGURATION = [".yaml", ".yml", ".json"] MACHINE_ID = Path("/etc/machine-id") +RUN_SUPERVISOR_STATE = Path("/run/supervisor") +SOCKET_CORE = Path("/run/os/core.sock") SOCKET_DBUS = Path("/run/dbus/system_bus_socket") SOCKET_DOCKER = Path("/run/docker.sock") -RUN_SUPERVISOR_STATE = Path("/run/supervisor") SYSTEMD_JOURNAL_PERSISTENT = Path("/var/log/journal") SYSTEMD_JOURNAL_VOLATILE = Path("/run/log/journal") diff --git a/supervisor/core.py b/supervisor/core.py index 1594e679c8f..9cfec29fca1 100644 --- a/supervisor/core.py +++ b/supervisor/core.py @@ -338,6 +338,7 @@ async def stop(self) -> None: self.sys_create_task(coro) for coro in ( self.sys_websession.close(), + self.sys_homeassistant.api.close(), self.sys_ingress.unload(), self.sys_hardware.unload(), self.sys_dbus.unload(), diff --git a/supervisor/docker/const.py b/supervisor/docker/const.py index e043cea1c1a..711bb8c0605 100644 --- a/supervisor/docker/const.py +++ b/supervisor/docker/const.py @@ -140,6 +140,7 @@ def to_dict(self) -> dict[str, str | int]: } +ENV_CORE_API_SOCKET = "SUPERVISOR_CORE_API_SOCKET" ENV_DUPLICATE_LOG_FILE = "HA_DUPLICATE_LOG_FILE" ENV_TIME = "TZ" ENV_TOKEN = "SUPERVISOR_TOKEN" @@ -169,6 +170,12 @@ def to_dict(self) -> dict[str, str | int]: target=MACHINE_ID.as_posix(), read_only=True, ) +MOUNT_CORE_RUN = DockerMount( + type=MountType.BIND, + source="/run/supervisor", + target="/run/supervisor", + read_only=False, +) MOUNT_UDEV = DockerMount( type=MountType.BIND, source="/run/udev", target="/run/udev", read_only=True ) diff --git a/supervisor/docker/homeassistant.py b/supervisor/docker/homeassistant.py index c82af5a7b63..734fa4848a6 100644 --- a/supervisor/docker/homeassistant.py +++ b/supervisor/docker/homeassistant.py @@ -13,10 +13,12 @@ from ..jobs.const import JobConcurrency from ..jobs.decorator import Job from .const import ( + ENV_CORE_API_SOCKET, ENV_DUPLICATE_LOG_FILE, ENV_TIME, ENV_TOKEN, ENV_TOKEN_OLD, + MOUNT_CORE_RUN, MOUNT_DBUS, MOUNT_DEV, MOUNT_MACHINE_ID, @@ -162,6 +164,9 @@ def mounts(self) -> list[DockerMount]: if self.sys_machine_id: mounts.append(MOUNT_MACHINE_ID) + if self.sys_homeassistant.api.supports_unix_socket: + mounts.append(MOUNT_CORE_RUN) + return mounts @Job( @@ -180,6 +185,8 @@ async def run(self, *, restore_job_id: str | None = None) -> None: } if restore_job_id: environment[ENV_RESTORE_JOB_ID] = restore_job_id + if self.sys_homeassistant.api.supports_unix_socket: + environment[ENV_CORE_API_SOCKET] = "/run/supervisor/core.sock" if self.sys_homeassistant.duplicate_log_file: environment[ENV_DUPLICATE_LOG_FILE] = "1" await self._run( diff --git a/supervisor/docker/interface.py b/supervisor/docker/interface.py index ef590bfc9ca..6b8c9221773 100644 --- a/supervisor/docker/interface.py +++ b/supervisor/docker/interface.py @@ -115,6 +115,11 @@ def timeout(self) -> int: def name(self) -> str: """Return name of Docker container.""" + @property + def attached(self) -> bool: + """Return True if container/image metadata has been loaded.""" + return self._meta is not None + @property def meta_config(self) -> dict[str, Any]: """Return meta data of configuration for container/image.""" diff --git a/supervisor/homeassistant/api.py b/supervisor/homeassistant/api.py index 3d3c5d7e51a..0fe07516353 100644 --- a/supervisor/homeassistant/api.py +++ b/supervisor/homeassistant/api.py @@ -13,13 +13,20 @@ from awesomeversion import AwesomeVersion from multidict import MultiMapping +from ..const import SOCKET_CORE from ..coresys import CoreSys, CoreSysAttributes +from ..docker.const import ENV_CORE_API_SOCKET, ContainerState +from ..docker.monitor import DockerContainerStateEvent from ..exceptions import HomeAssistantAPIError, HomeAssistantAuthError from ..utils import version_is_new_enough from .const import LANDINGPAGE +from .websocket import WSClient _LOGGER: logging.Logger = logging.getLogger(__name__) +CORE_UNIX_SOCKET_MIN_VERSION: AwesomeVersion = AwesomeVersion( + "2026.4.0.dev202603250907" +) GET_CORE_STATE_MIN_VERSION: AwesomeVersion = AwesomeVersion("2023.8.0.dev20230720") @@ -39,11 +46,101 @@ def __init__(self, coresys: CoreSys): self.coresys: CoreSys = coresys # We don't persist access tokens. Instead we fetch new ones when needed - self.access_token: str | None = None + self._access_token: str | None = None self._access_token_expires: datetime | None = None self._token_lock: asyncio.Lock = asyncio.Lock() + self._unix_session: aiohttp.ClientSession | None = None + self._core_connected: bool = False - async def ensure_access_token(self) -> None: + @property + def supports_unix_socket(self) -> bool: + """Return True if the installed Core version supports Unix socket communication. + + Used to decide whether to configure the env var when starting Core. + """ + return ( + self.sys_homeassistant.version is not None + and self.sys_homeassistant.version != LANDINGPAGE + and version_is_new_enough( + self.sys_homeassistant.version, CORE_UNIX_SOCKET_MIN_VERSION + ) + ) + + @property + def use_unix_socket(self) -> bool: + """Return True if the running Core container is configured for Unix socket. + + Checks both version support and that the container was actually started + with the SUPERVISOR_CORE_API_SOCKET env var. This prevents failures + during Supervisor upgrades where Core is still running with a container + started by the old Supervisor. + + Requires container metadata to be available (via attach() or run()). + Callers should ensure the container is running before using this. + """ + if not self.supports_unix_socket: + return False + instance = self.sys_homeassistant.core.instance + if not instance.attached: + raise HomeAssistantAPIError( + "Cannot determine Core connection mode: container metadata not available" + ) + return any( + env.startswith(f"{ENV_CORE_API_SOCKET}=") + for env in instance.meta_config.get("Env", []) + ) + + @property + def session(self) -> aiohttp.ClientSession: + """Return session for Core communication. + + Uses a Unix socket session when the installed Core version supports it, + otherwise falls back to the default TCP websession. If the socket does + not exist yet (e.g. during Core startup), requests will fail with a + connection error handled by the caller. + """ + if not self.use_unix_socket: + return self.sys_websession + + if self._unix_session is None or self._unix_session.closed: + self._unix_session = aiohttp.ClientSession( + connector=aiohttp.UnixConnector(path=str(SOCKET_CORE)) + ) + return self._unix_session + + @property + def api_url(self) -> str: + """Return API base url for internal Supervisor to Core communication.""" + if self.use_unix_socket: + return "http://localhost" + return self.sys_homeassistant.api_url + + @property + def ws_url(self) -> str: + """Return WebSocket url for internal Supervisor to Core communication.""" + if self.use_unix_socket: + return "ws://localhost/api/websocket" + return self.sys_homeassistant.ws_url + + async def container_state_changed(self, event: DockerContainerStateEvent) -> None: + """Process Core container state changes.""" + if event.name != self.sys_homeassistant.core.instance.name: + return + if event.state not in (ContainerState.STOPPED, ContainerState.FAILED): + return + + self._core_connected = False + if self._unix_session and not self._unix_session.closed: + await self._unix_session.close() + self._unix_session = None + + async def close(self) -> None: + """Close the Unix socket session.""" + if self._unix_session and not self._unix_session.closed: + await self._unix_session.close() + self._unix_session = None + + async def _ensure_access_token(self) -> None: """Ensure there is a valid access token. Raises: @@ -55,7 +152,7 @@ async def ensure_access_token(self) -> None: # Fast path check without lock (avoid unnecessary locking # for the majority of calls). if ( - self.access_token + self._access_token and self._access_token_expires and self._access_token_expires > datetime.now(tz=UTC) ): @@ -64,7 +161,7 @@ async def ensure_access_token(self) -> None: async with self._token_lock: # Double-check after acquiring lock (avoid race condition) if ( - self.access_token + self._access_token and self._access_token_expires and self._access_token_expires > datetime.now(tz=UTC) ): @@ -86,11 +183,50 @@ async def ensure_access_token(self) -> None: _LOGGER.info("Updated Home Assistant API token") tokens = await resp.json() - self.access_token = tokens["access_token"] + self._access_token = tokens["access_token"] self._access_token_expires = datetime.now(tz=UTC) + timedelta( seconds=tokens["expires_in"] ) + async def connect_websocket( + self, *, max_msg_size: int = 4 * 1024 * 1024 + ) -> WSClient: + """Connect a WebSocket to Core, handling auth as appropriate. + + For Unix socket connections, no authentication is needed. + For TCP connections, handles token management with one retry + on auth failure. + + Raises: + HomeAssistantAPIError: On connection or auth failure. + + """ + if not await self.sys_homeassistant.core.instance.is_running(): + raise HomeAssistantAPIError("Core container is not running", _LOGGER.debug) + + if self.use_unix_socket: + return await WSClient.connect( + self.session, self.ws_url, max_msg_size=max_msg_size + ) + + for attempt in (1, 2): + try: + await self._ensure_access_token() + assert self._access_token + return await WSClient.connect_with_auth( + self.session, + self.ws_url, + self._access_token, + max_msg_size=max_msg_size, + ) + except HomeAssistantAPIError: + self._access_token = None + if attempt == 2: + raise + + # Unreachable, but satisfies type checker + raise RuntimeError("Unreachable") + @asynccontextmanager async def make_request( self, @@ -103,15 +239,16 @@ async def make_request( params: MultiMapping[str] | None = None, headers: dict[str, str] | None = None, ) -> AsyncIterator[aiohttp.ClientResponse]: - """Async context manager to make authenticated requests to Home Assistant API. + """Async context manager to make requests to Home Assistant Core API. - This context manager handles authentication token management automatically, - including token refresh on 401 responses. It yields the HTTP response - for the caller to handle. + This context manager handles transport and authentication automatically. + For Unix socket connections, requests are made directly without auth. + For TCP connections, it manages access tokens and retries once on 401. + It yields the HTTP response for the caller to handle. Error Handling: - HTTP error status codes (4xx, 5xx) are preserved in the response - - Authentication is handled transparently with one retry on 401 + - Authentication is handled transparently (TCP only) - Network/connection failures raise HomeAssistantAPIError - No logging is performed - callers should handle logging as needed @@ -133,19 +270,22 @@ async def make_request( network errors, timeouts, or connection failures """ - url = f"{self.sys_homeassistant.api_url}/{path}" + if not await self.sys_homeassistant.core.instance.is_running(): + raise HomeAssistantAPIError("Core container is not running", _LOGGER.debug) + + url = f"{self.api_url}/{path}" headers = headers or {} client_timeout = aiohttp.ClientTimeout(total=timeout) - # Passthrough content type if content_type is not None: headers[hdrs.CONTENT_TYPE] = content_type for _ in (1, 2): try: - await self.ensure_access_token() - headers[hdrs.AUTHORIZATION] = f"Bearer {self.access_token}" - async with self.sys_websession.request( + if not self.use_unix_socket: + await self._ensure_access_token() + headers[hdrs.AUTHORIZATION] = f"Bearer {self._access_token}" + async with self.session.request( method, url, data=data, @@ -155,9 +295,8 @@ async def make_request( params=params, ssl=False, ) as resp: - # Access token expired - if resp.status == 401: - self.access_token = None + if resp.status == 401 and not self.use_unix_socket: + self._access_token = None continue yield resp return @@ -184,7 +323,10 @@ async def get_config(self) -> dict[str, Any]: async def get_core_state(self) -> dict[str, Any]: """Return Home Assistant core state.""" - return await self._get_json("api/core/state") + state = await self._get_json("api/core/state") + if state is None or not isinstance(state, dict): + raise HomeAssistantAPIError("No state received from Home Assistant API") + return state async def get_api_state(self) -> APIState | None: """Return state of Home Assistant Core or None.""" @@ -206,14 +348,23 @@ async def get_api_state(self) -> APIState | None: data = await self.get_core_state() else: data = await self.get_config() + + if not self._core_connected: + self._core_connected = True + transport = ( + f"Unix socket {SOCKET_CORE}" + if self.use_unix_socket + else f"TCP {self.sys_homeassistant.api_url}" + ) + _LOGGER.info("Connected to Core via %s", transport) + # Older versions of home assistant does not expose the state - if data: - state = data.get("state", "RUNNING") - # Recorder state was added in HA Core 2024.8 - recorder_state = data.get("recorder_state", {}) - migrating = recorder_state.get("migration_in_progress", False) - live_migration = recorder_state.get("migration_is_live", False) - return APIState(state, migrating and not live_migration) + state = data.get("state", "RUNNING") + # Recorder state was added in HA Core 2024.8 + recorder_state = data.get("recorder_state", {}) + migrating = recorder_state.get("migration_in_progress", False) + live_migration = recorder_state.get("migration_is_live", False) + return APIState(state, migrating and not live_migration) except HomeAssistantAPIError as err: _LOGGER.debug("Can't connect to Home Assistant API: %s", err) diff --git a/supervisor/homeassistant/module.py b/supervisor/homeassistant/module.py index 7ff52c406a2..8df2b71c8fe 100644 --- a/supervisor/homeassistant/module.py +++ b/supervisor/homeassistant/module.py @@ -318,6 +318,10 @@ async def load(self) -> None: ) # Register for events + self.sys_bus.register_event( + BusEvent.DOCKER_CONTAINER_STATE_CHANGE, + self._api.container_state_changed, + ) self.sys_bus.register_event(BusEvent.HARDWARE_NEW_DEVICE, self._hardware_events) self.sys_bus.register_event( BusEvent.HARDWARE_REMOVE_DEVICE, self._hardware_events diff --git a/supervisor/homeassistant/websocket.py b/supervisor/homeassistant/websocket.py index eab4d506116..79cc781143c 100644 --- a/supervisor/homeassistant/websocket.py +++ b/supervisor/homeassistant/websocket.py @@ -3,9 +3,8 @@ from __future__ import annotations import asyncio -from contextlib import suppress import logging -from typing import Any, TypeVar, cast +from typing import Any, TypeVar import aiohttp from aiohttp.http_websocket import WSMsgType @@ -45,14 +44,14 @@ def __init__( ): """Initialise the WS client.""" self.ha_version = ha_version - self._client = client + self.client = client self._message_id: int = 0 self._futures: dict[int, asyncio.Future[T]] = {} # type: ignore @property def connected(self) -> bool: """Return if we're currently connected.""" - return self._client is not None and not self._client.closed + return self.client is not None and not self.client.closed async def close(self) -> None: """Close down the client.""" @@ -62,8 +61,8 @@ async def close(self) -> None: HomeAssistantWSConnectionError("Connection was closed") ) - if not self._client.closed: - await self._client.close() + if not self.client.closed: + await self.client.close() async def async_send_command(self, message: dict[str, Any]) -> T: """Send a websocket message, and return the response.""" @@ -72,7 +71,7 @@ async def async_send_command(self, message: dict[str, Any]) -> T: self._futures[message["id"]] = asyncio.get_running_loop().create_future() _LOGGER.debug("Sending: %s", message) try: - await self._client.send_json(message, dumps=json_dumps) + await self.client.send_json(message, dumps=json_dumps) except ConnectionError as err: raise HomeAssistantWSConnectionError(str(err)) from err @@ -97,7 +96,7 @@ async def start_listener(self) -> None: async def _receive_json(self) -> None: """Receive json.""" - msg = await self._client.receive() + msg = await self.client.receive() _LOGGER.debug("Received: %s", msg) if msg.type == WSMsgType.CLOSE: @@ -139,27 +138,105 @@ async def _receive_json(self) -> None: ) @classmethod - async def connect_with_auth( - cls, session: aiohttp.ClientSession, url: str, token: str - ) -> WSClient: - """Create an authenticated websocket client.""" + async def _ws_connect( + cls, + session: aiohttp.ClientSession, + url: str, + *, + max_msg_size: int = 4 * 1024 * 1024, + ) -> aiohttp.ClientWebSocketResponse: + """Open a raw WebSocket connection to Core.""" try: - client = await session.ws_connect(url, ssl=False) + return await session.ws_connect(url, ssl=False, max_msg_size=max_msg_size) except aiohttp.client_exceptions.ClientConnectorError: raise HomeAssistantWSConnectionError("Can't connect") from None - hello_message = await client.receive_json() + @classmethod + async def connect( + cls, + session: aiohttp.ClientSession, + url: str, + *, + max_msg_size: int = 4 * 1024 * 1024, + ) -> WSClient: + """Connect via Unix socket (no auth exchange). - await client.send_json( - {ATTR_TYPE: WSType.AUTH, ATTR_ACCESS_TOKEN: token}, dumps=json_dumps - ) + Core authenticates the peer by the socket connection itself + and sends auth_ok immediately. + """ + client = await cls._ws_connect(session, url, max_msg_size=max_msg_size) + try: + first_message = await client.receive_json() - auth_ok_message = await client.receive_json() + if first_message[ATTR_TYPE] != "auth_ok": + raise HomeAssistantAPIError( + f"Expected auth_ok on Unix socket, got {first_message[ATTR_TYPE]}" + ) - if auth_ok_message[ATTR_TYPE] != "auth_ok": - raise HomeAssistantAPIError("AUTH NOT OK") + return cls(AwesomeVersion(first_message["ha_version"]), client) + except HomeAssistantAPIError: + await client.close() + raise + except ( + KeyError, + ValueError, + TypeError, + aiohttp.ClientError, + TimeoutError, + ) as err: + await client.close() + raise HomeAssistantAPIError( + f"Unexpected error during WebSocket handshake: {err}" + ) from err - return cls(AwesomeVersion(hello_message["ha_version"]), client) + @classmethod + async def connect_with_auth( + cls, + session: aiohttp.ClientSession, + url: str, + token: str, + *, + max_msg_size: int = 4 * 1024 * 1024, + ) -> WSClient: + """Connect via TCP with token authentication. + + Expects auth_required from Core, sends the token, then expects auth_ok. + The auth_required message also carries ha_version. + """ + client = await cls._ws_connect(session, url, max_msg_size=max_msg_size) + try: + # auth_required message also carries ha_version + first_message = await client.receive_json() + + if first_message[ATTR_TYPE] != "auth_required": + raise HomeAssistantAPIError( + f"Expected auth_required, got {first_message[ATTR_TYPE]}" + ) + + await client.send_json( + {ATTR_TYPE: WSType.AUTH, ATTR_ACCESS_TOKEN: token}, dumps=json_dumps + ) + + auth_ok_message = await client.receive_json() + + if auth_ok_message[ATTR_TYPE] != "auth_ok": + raise HomeAssistantAPIError("AUTH NOT OK") + + return cls(AwesomeVersion(first_message["ha_version"]), client) + except HomeAssistantAPIError: + await client.close() + raise + except ( + KeyError, + ValueError, + TypeError, + aiohttp.ClientError, + TimeoutError, + ) as err: + await client.close() + raise HomeAssistantAPIError( + f"Unexpected error during WebSocket handshake: {err}" + ) from err class HomeAssistantWebSocket(CoreSysAttributes): @@ -168,7 +245,7 @@ class HomeAssistantWebSocket(CoreSysAttributes): def __init__(self, coresys: CoreSys): """Initialize Home Assistant object.""" self.coresys: CoreSys = coresys - self._client: WSClient | None = None + self.client: WSClient | None = None self._lock: asyncio.Lock = asyncio.Lock() self._queue: list[dict[str, Any]] = [] @@ -183,16 +260,10 @@ async def _process_queue(self, reference: CoreState) -> None: async def _get_ws_client(self) -> WSClient: """Return a websocket client.""" async with self._lock: - if self._client is not None and self._client.connected: - return self._client - - with suppress(asyncio.TimeoutError, aiohttp.ClientError): - await self.sys_homeassistant.api.ensure_access_token() - client = await WSClient.connect_with_auth( - self.sys_websession, - self.sys_homeassistant.ws_url, - cast(str, self.sys_homeassistant.api.access_token), - ) + if self.client is not None and self.client.connected: + return self.client + + client = await self.sys_homeassistant.api.connect_websocket() self.sys_create_task(client.start_listener()) return client @@ -208,7 +279,7 @@ async def _ensure_connected(self) -> None: "WebSocket not available, system is shutting down" ) - connected = self._client and self._client.connected + connected = self.client and self.client.connected # If we are already connected, we can avoid the check_api_state call # since it makes a new socket connection and we already have one. if not connected and not await self.sys_homeassistant.api.check_api_state(): @@ -216,8 +287,8 @@ async def _ensure_connected(self) -> None: "Can't connect to Home Assistant Core WebSocket, the API is not reachable" ) - if not self._client or not self._client.connected: - self._client = await self._get_ws_client() + if not self.client or not self.client.connected: + self.client = await self._get_ws_client() async def load(self) -> None: """Set up queue processor after startup completes.""" @@ -241,16 +312,16 @@ async def _async_send_command(self, message: dict[str, Any]) -> None: _LOGGER.debug("Can't send WebSocket command: %s", err) return - # _ensure_connected guarantees self._client is set - assert self._client + # _ensure_connected guarantees self.client is set + assert self.client try: - await self._client.async_send_command(message) + await self.client.async_send_command(message) except HomeAssistantWSConnectionError as err: _LOGGER.debug("Fire-and-forget WebSocket command failed: %s", err) - if self._client: - await self._client.close() - self._client = None + if self.client: + await self.client.close() + self.client = None async def async_send_command(self, message: dict[str, Any]) -> T: """Send a command and return the response. @@ -258,14 +329,14 @@ async def async_send_command(self, message: dict[str, Any]) -> T: Raises HomeAssistantWSError on WebSocket connection or communication failure. """ await self._ensure_connected() - # _ensure_connected guarantees self._client is set - assert self._client + # _ensure_connected guarantees self.client is set + assert self.client try: - return await self._client.async_send_command(message) + return await self.client.async_send_command(message) except HomeAssistantWSConnectionError: - if self._client: - await self._client.close() - self._client = None + if self.client: + await self.client.close() + self.client = None raise def send_command(self, message: dict[str, Any]) -> None: diff --git a/supervisor/ingress.py b/supervisor/ingress.py index 859d8586613..04428824447 100644 --- a/supervisor/ingress.py +++ b/supervisor/ingress.py @@ -185,12 +185,7 @@ async def del_dynamic_port(self, addon_slug: str) -> None: await self.save_data() async def update_hass_panel(self, addon: Addon): - """Return True if Home Assistant up and running.""" - if not await self.sys_homeassistant.core.is_running(): - _LOGGER.debug("Ignoring panel update on Core") - return - - # Update UI + """Update the ingress panel registration in Home Assistant.""" method = "post" if addon.ingress_panel else "delete" try: async with self.sys_homeassistant.api.make_request( diff --git a/tests/addons/test_manager.py b/tests/addons/test_manager.py index 04b8d2a9cf4..ba0fb7289ee 100644 --- a/tests/addons/test_manager.py +++ b/tests/addons/test_manager.py @@ -246,7 +246,7 @@ async def test_addon_uninstall_removes_discovery( assert message.service == "mqtt" assert coresys.discovery.list_messages == [message] - coresys.homeassistant.api.ensure_access_token = AsyncMock() + coresys.homeassistant.api._ensure_access_token = AsyncMock() # pylint: disable=protected-access await coresys.addons.uninstall(TEST_ADDON_SLUG) await asyncio.sleep(0) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 132dbab548f..bcdac6f97c3 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -88,7 +88,7 @@ async def test_password_reset( websession: MagicMock, ): """Test password reset api.""" - coresys.homeassistant.api.access_token = "abc123" + coresys.homeassistant.api._access_token = "abc123" # pylint: disable=protected-access # pylint: disable-next=protected-access coresys.homeassistant.api._access_token_expires = datetime.now(tz=UTC) + timedelta( days=1 @@ -124,7 +124,7 @@ async def test_failed_password_reset( expected_log: str, ): """Test failed password reset.""" - coresys.homeassistant.api.access_token = "abc123" + coresys.homeassistant.api._access_token = "abc123" # pylint: disable=protected-access # pylint: disable-next=protected-access coresys.homeassistant.api._access_token_expires = datetime.now(tz=UTC) + timedelta( days=1 diff --git a/tests/api/test_discovery.py b/tests/api/test_discovery.py index fc2d4850c02..5da57fcfc3f 100644 --- a/tests/api/test_discovery.py +++ b/tests/api/test_discovery.py @@ -91,7 +91,7 @@ async def test_api_send_del_discovery( ): """Test adding and removing discovery.""" install_addon_ssh.data["discovery"] = ["test"] - coresys.homeassistant.api.ensure_access_token = AsyncMock() + coresys.homeassistant.api._ensure_access_token = AsyncMock() # pylint: disable=protected-access resp = await api_client.post("/discovery", json={"service": "test", "config": {}}) assert resp.status == 200 diff --git a/tests/conftest.py b/tests/conftest.py index 0774105ec06..8be38ef02e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -556,11 +556,11 @@ async def coresys( Path(__file__).parent.joinpath("fixtures"), "apparmor" ) - # WebSocket + # Home Assistant Core API coresys_obj.homeassistant.api.get_api_state = AsyncMock( return_value=APIState("RUNNING", False) ) - coresys_obj.homeassistant._websocket._client = AsyncMock( + coresys_obj.homeassistant._websocket.client = AsyncMock( ha_version=AwesomeVersion("2021.2.4") ) @@ -580,7 +580,7 @@ async def ha_ws_client(coresys: CoreSys) -> AsyncMock: # Set Supervisor Core state to RUNNING, otherwise WS events won't be delivered await coresys.core.set_state(CoreState.RUNNING) await asyncio.sleep(0) - client = coresys.homeassistant.websocket._client + client = coresys.homeassistant.websocket.client client.async_send_command.reset_mock() return client @@ -707,8 +707,13 @@ def supervisor_internet(coresys: CoreSys) -> Generator[AsyncMock]: @pytest.fixture def websession(coresys: CoreSys) -> Generator[MagicMock]: - """Fixture for global aiohttp SessionClient.""" + """Fixture for global aiohttp SessionClient. + + Also mocks Core container is_running to return True so that + make_request doesn't bail before reaching the websession. + """ coresys._websession = MagicMock(spec_set=ClientSession) + coresys.homeassistant.core.instance.is_running = AsyncMock(return_value=True) yield coresys._websession diff --git a/tests/docker/test_homeassistant.py b/tests/docker/test_homeassistant.py index c8a3e64d7b4..423b719a760 100644 --- a/tests/docker/test_homeassistant.py +++ b/tests/docker/test_homeassistant.py @@ -9,6 +9,7 @@ from supervisor.coresys import CoreSys from supervisor.docker.const import ( + MOUNT_CORE_RUN, DockerMount, MountBindOptions, MountType, @@ -24,7 +25,7 @@ @pytest.mark.usefixtures("tmp_supervisor_data", "path_extern") async def test_homeassistant_start(coresys: CoreSys, container: DockerContainer): """Test starting homeassistant.""" - coresys.homeassistant.version = AwesomeVersion("2023.8.1") + coresys.homeassistant.version = AwesomeVersion("2026.4.0") with ( patch.object(DockerAPI, "run", return_value=container.show.return_value) as run, @@ -51,7 +52,7 @@ async def test_homeassistant_start(coresys: CoreSys, container: DockerContainer) "TZ": ANY, "SUPERVISOR_TOKEN": ANY, "HASSIO_TOKEN": ANY, - # no "HA_DUPLICATE_LOG_FILE" + "SUPERVISOR_CORE_API_SOCKET": "/run/supervisor/core.sock", } assert run.call_args.kwargs["mounts"] == [ DEV_MOUNT, @@ -117,6 +118,7 @@ async def test_homeassistant_start(coresys: CoreSys, container: DockerContainer) target="/etc/machine-id", read_only=True, ), + MOUNT_CORE_RUN, ] assert "volumes" not in run.call_args.kwargs @@ -144,6 +146,28 @@ async def test_homeassistant_start_with_duplicate_log_file( assert env["HA_DUPLICATE_LOG_FILE"] == "1" +@pytest.mark.usefixtures("tmp_supervisor_data", "path_extern") +async def test_homeassistant_start_with_unix_socket( + coresys: CoreSys, container: DockerContainer +): + """Test starting homeassistant with unix socket env var for supported version.""" + coresys.homeassistant.version = AwesomeVersion("2026.4.0") + + with ( + patch.object(DockerAPI, "run", return_value=container.show.return_value) as run, + patch.object( + DockerHomeAssistant, "is_running", side_effect=[False, False, True] + ), + patch("supervisor.homeassistant.core.asyncio.sleep"), + ): + await coresys.homeassistant.core.start() + + run.assert_called_once() + env = run.call_args.kwargs["environment"] + assert "SUPERVISOR_CORE_API_SOCKET" in env + assert env["SUPERVISOR_CORE_API_SOCKET"] == "/run/supervisor/core.sock" + + @pytest.mark.usefixtures("tmp_supervisor_data", "path_extern") async def test_landingpage_start(coresys: CoreSys, container: DockerContainer): """Test starting landingpage.""" diff --git a/tests/homeassistant/test_api.py b/tests/homeassistant/test_api.py index 35125ced3f1..9971d3b9342 100644 --- a/tests/homeassistant/test_api.py +++ b/tests/homeassistant/test_api.py @@ -1,65 +1,39 @@ """Test Home Assistant API.""" from contextlib import asynccontextmanager -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch from aiohttp import hdrs from awesomeversion import AwesomeVersion import pytest from supervisor.coresys import CoreSys +from supervisor.docker.const import ContainerState +from supervisor.docker.monitor import DockerContainerStateEvent from supervisor.exceptions import HomeAssistantAPIError +from supervisor.homeassistant.api import APIState, HomeAssistantAPI +from supervisor.homeassistant.const import LANDINGPAGE +from tests.common import MockResponse -async def test_check_frontend_available_success(coresys: CoreSys): - """Test frontend availability check succeeds with valid HTML response.""" - coresys.homeassistant.version = AwesomeVersion("2025.8.0") - - mock_response = MagicMock() - mock_response.status = 200 - mock_response.headers = {hdrs.CONTENT_TYPE: "text/html; charset=utf-8"} - - @asynccontextmanager - async def mock_make_request(*args, **kwargs): - yield mock_response - - with patch.object( - type(coresys.homeassistant.api), "make_request", new=mock_make_request - ): - result = await coresys.homeassistant.api.check_frontend_available() - - assert result is True - +# --- check_frontend_available --- -async def test_check_frontend_available_wrong_status(coresys: CoreSys): - """Test frontend availability check fails with non-200 status.""" - coresys.homeassistant.version = AwesomeVersion("2025.8.0") - - mock_response = MagicMock() - mock_response.status = 404 - mock_response.headers = {hdrs.CONTENT_TYPE: "text/html"} - - @asynccontextmanager - async def mock_make_request(*args, **kwargs): - yield mock_response - with patch.object( - type(coresys.homeassistant.api), "make_request", new=mock_make_request - ): - result = await coresys.homeassistant.api.check_frontend_available() - - assert result is False - - -async def test_check_frontend_available_wrong_content_type( - coresys: CoreSys, caplog: pytest.LogCaptureFixture +@pytest.mark.parametrize( + ("status", "content_type", "expected"), + [ + (200, "text/html; charset=utf-8", True), + (404, "text/html", False), + (200, "application/json", False), + ], +) +async def test_check_frontend_available( + coresys: CoreSys, status: int, content_type: str, expected: bool ): - """Test frontend availability check fails with wrong content type.""" - coresys.homeassistant.version = AwesomeVersion("2025.8.0") - + """Test frontend availability based on HTTP status and content type.""" mock_response = MagicMock() - mock_response.status = 200 - mock_response.headers = {hdrs.CONTENT_TYPE: "application/json"} + mock_response.status = status + mock_response.headers = {hdrs.CONTENT_TYPE: content_type} @asynccontextmanager async def mock_make_request(*args, **kwargs): @@ -68,15 +42,11 @@ async def mock_make_request(*args, **kwargs): with patch.object( type(coresys.homeassistant.api), "make_request", new=mock_make_request ): - result = await coresys.homeassistant.api.check_frontend_available() - - assert result is False - assert "unexpected content type" in caplog.text + assert await coresys.homeassistant.api.check_frontend_available() is expected async def test_check_frontend_available_api_error(coresys: CoreSys): """Test frontend availability check handles API errors gracefully.""" - coresys.homeassistant.version = AwesomeVersion("2025.8.0") @asynccontextmanager async def mock_make_request(*args, **kwargs): @@ -86,15 +56,14 @@ async def mock_make_request(*args, **kwargs): with patch.object( type(coresys.homeassistant.api), "make_request", new=mock_make_request ): - result = await coresys.homeassistant.api.check_frontend_available() + assert await coresys.homeassistant.api.check_frontend_available() is False - assert result is False + +# --- get_config / get_core_state --- async def test_get_config_success(coresys: CoreSys): """Test get_config returns valid config dictionary.""" - coresys.homeassistant.version = AwesomeVersion("2025.8.0") - expected_config = { "latitude": 32.87336, "longitude": -117.22743, @@ -113,11 +82,7 @@ async def test_get_config_success(coresys: CoreSys): mock_response = MagicMock() mock_response.status = 200 - - async def mock_json(): - return expected_config - - mock_response.json = mock_json + mock_response.json = AsyncMock(return_value=expected_config) @asynccontextmanager async def mock_make_request(*_args, **_kwargs): @@ -126,22 +91,24 @@ async def mock_make_request(*_args, **_kwargs): with patch.object( type(coresys.homeassistant.api), "make_request", new=mock_make_request ): - result = await coresys.homeassistant.api.get_config() - - assert result == expected_config - - -async def test_get_config_returns_none(coresys: CoreSys): - """Test get_config raises error when None is returned.""" - coresys.homeassistant.version = AwesomeVersion("2025.8.0") - + assert await coresys.homeassistant.api.get_config() == expected_config + + +@pytest.mark.parametrize( + ("method", "bad_response", "match"), + [ + ("get_config", None, "No config received"), + ("get_config", ["not", "a", "dict"], "No config received"), + ("get_core_state", None, "No state received"), + ], +) +async def test_get_json_validation( + coresys: CoreSys, method: str, bad_response, match: str +): + """Test get_config/get_core_state raise on invalid responses.""" mock_response = MagicMock() mock_response.status = 200 - - async def mock_json(): - return None - - mock_response.json = mock_json + mock_response.json = AsyncMock(return_value=bad_response) @asynccontextmanager async def mock_make_request(*_args, **_kwargs): @@ -151,24 +118,14 @@ async def mock_make_request(*_args, **_kwargs): patch.object( type(coresys.homeassistant.api), "make_request", new=mock_make_request ), - pytest.raises( - HomeAssistantAPIError, match="No config received from Home Assistant API" - ), + pytest.raises(HomeAssistantAPIError, match=match), ): - await coresys.homeassistant.api.get_config() - + await getattr(coresys.homeassistant.api, method)() -async def test_get_config_returns_non_dict(coresys: CoreSys): - """Test get_config raises error when non-dict is returned.""" - coresys.homeassistant.version = AwesomeVersion("2025.8.0") - mock_response = MagicMock() - mock_response.status = 200 - - async def mock_json(): - return ["not", "a", "dict"] - - mock_response.json = mock_json +async def test_get_config_api_error(coresys: CoreSys): + """Test get_config propagates API errors.""" + mock_response = MagicMock(status=500) @asynccontextmanager async def mock_make_request(*_args, **_kwargs): @@ -178,30 +135,301 @@ async def mock_make_request(*_args, **_kwargs): patch.object( type(coresys.homeassistant.api), "make_request", new=mock_make_request ), - pytest.raises( - HomeAssistantAPIError, match="No config received from Home Assistant API" - ), + pytest.raises(HomeAssistantAPIError, match="500"), ): await coresys.homeassistant.api.get_config() -async def test_get_config_api_error(coresys: CoreSys): - """Test get_config propagates API errors from underlying _get_json call.""" +# --- supports_unix_socket / use_unix_socket --- + + +@pytest.mark.parametrize( + ("version", "expected"), + [ + ("2026.4.0", True), + ("2024.1.0", False), + (LANDINGPAGE, False), + ], +) +async def test_supports_unix_socket(coresys: CoreSys, version: str, expected: bool): + """Test supports_unix_socket based on Core version.""" + coresys.homeassistant.version = AwesomeVersion(version) + assert coresys.homeassistant.api.supports_unix_socket is expected + + +@pytest.mark.parametrize( + ("version", "env", "expected"), + [ + ("2024.1.0", [], False), + ("2026.4.0", ["SUPERVISOR_CORE_API_SOCKET=/run/supervisor/core.sock"], True), + ("2026.4.0", ["TZ=UTC", "SUPERVISOR_TOKEN=abc"], False), + ], +) +async def test_use_unix_socket( + coresys: CoreSys, version: str, env: list[str], expected: bool +): + """Test use_unix_socket based on version and container env.""" + coresys.homeassistant.version = AwesomeVersion(version) + # pylint: disable-next=protected-access + coresys.homeassistant.core.instance._meta = {"Config": {"Env": env}} + assert coresys.homeassistant.api.use_unix_socket is expected + + +# --- api_url / ws_url --- + + +@pytest.mark.parametrize( + ("use_unix", "expected_api_url", "expected_ws_url"), + [ + (True, "http://localhost", "ws://localhost/api/websocket"), + (False, "http://172.30.32.1:8123", "ws://172.30.32.1:8123/api/websocket"), + ], +) +async def test_api_and_ws_urls( + coresys: CoreSys, use_unix: bool, expected_api_url: str, expected_ws_url: str +): + """Test api_url and ws_url for Unix socket and TCP transports.""" + with patch.object(type(coresys.homeassistant.api), "use_unix_socket", use_unix): + assert coresys.homeassistant.api.api_url == expected_api_url + assert coresys.homeassistant.api.ws_url == expected_ws_url + + +# --- connection lifecycle --- + + +@pytest.fixture +def real_get_api_state(coresys: CoreSys): + """Restore real get_api_state (coresys fixture mocks it).""" + api = coresys.homeassistant.api + api.get_api_state = type(api).get_api_state.__get__(api) + return api + + +async def test_connected_log_after_container_restart( + coresys: CoreSys, + real_get_api_state: HomeAssistantAPI, + caplog: pytest.LogCaptureFixture, +): + """Test 'Connected to Core' log reappears after container stop and reconnect.""" + api = coresys.homeassistant.api coresys.homeassistant.version = AwesomeVersion("2025.8.0") + api.get_core_state = AsyncMock( + return_value={"state": "RUNNING", "recorder_state": {}} + ) + + # First connection logs + with patch.object(type(api), "use_unix_socket", False): + await api.get_api_state() + assert "Connected to Core via TCP" in caplog.text + + # Container stops + caplog.clear() + await api.container_state_changed( + DockerContainerStateEvent( + name="homeassistant", + state=ContainerState.STOPPED, + id="abc123", + time=1234567890, + ) + ) + + # Reconnect logs again + with patch.object(type(api), "use_unix_socket", False): + await api.get_api_state() + assert "Connected to Core via TCP" in caplog.text + + +async def test_container_state_changed_ignores_other_containers( + coresys: CoreSys, + real_get_api_state: HomeAssistantAPI, + caplog: pytest.LogCaptureFixture, +): + """Test container_state_changed ignores events from other containers.""" + api = coresys.homeassistant.api + coresys.homeassistant.version = AwesomeVersion("2025.8.0") + api.get_core_state = AsyncMock( + return_value={"state": "RUNNING", "recorder_state": {}} + ) + + # First connection + with patch.object(type(api), "use_unix_socket", False): + await api.get_api_state() + assert "Connected to Core via TCP" in caplog.text + + # Other container stops — should not reset + caplog.clear() + await api.container_state_changed( + DockerContainerStateEvent( + name="addon_local_ssh", + state=ContainerState.STOPPED, + id="abc123", + time=1234567890, + ) + ) + + with patch.object(type(api), "use_unix_socket", False): + await api.get_api_state() + # Should NOT log again since connection state wasn't reset + assert "Connected to Core" not in caplog.text + + +# --- get_api_state / check_api_state --- + + +@pytest.mark.parametrize( + ("version", "core_state_response", "expected_state", "expected_check"), + [ + (LANDINGPAGE, None, None, False), + (None, None, None, False), + ( + "2025.8.0", + {"state": "RUNNING", "recorder_state": {}}, + APIState("RUNNING", False), + True, + ), + ( + "2025.8.0", + {"state": "NOT_RUNNING", "recorder_state": {}}, + APIState("NOT_RUNNING", False), + False, + ), + ( + "2025.8.0", + HomeAssistantAPIError("Connection failed"), + None, + False, + ), + ], +) +async def test_get_api_state( + coresys: CoreSys, + real_get_api_state: HomeAssistantAPI, + version: str | None, + core_state_response: dict | Exception | None, + expected_state: APIState | None, + expected_check: bool, +): + """Test get_api_state and check_api_state for various scenarios.""" + coresys.homeassistant.version = ( + AwesomeVersion(version) if version and version != LANDINGPAGE else version + ) + if isinstance(core_state_response, Exception): + coresys.homeassistant.api.get_core_state = AsyncMock( + side_effect=core_state_response + ) + elif core_state_response is not None: + coresys.homeassistant.api.get_core_state = AsyncMock( + return_value=core_state_response + ) - mock_response = MagicMock() - mock_response.status = 500 + with patch.object(type(coresys.homeassistant.api), "use_unix_socket", False): + assert await coresys.homeassistant.api.get_api_state() == expected_state + assert await coresys.homeassistant.api.check_api_state() is expected_check + + +# --- make_request --- + + +async def test_make_request_not_running(coresys: CoreSys): + """Test make_request raises when Core container is not running.""" + coresys.homeassistant.core.instance.is_running = AsyncMock(return_value=False) + + with pytest.raises(HomeAssistantAPIError, match="not running"): + async with coresys.homeassistant.api.make_request("get", "api/test"): + pass + + +@pytest.mark.usefixtures("websession") +async def test_make_request_tcp_with_token_fetch(coresys: CoreSys): + """Test make_request fetches token via /auth/token and makes the request.""" + api = coresys.homeassistant.api + + # Mock /auth/token POST + token_resp = MockResponse() + token_resp.json = AsyncMock( + return_value={"access_token": "test_token", "expires_in": 1800} + ) + coresys.websession.post = MagicMock(return_value=token_resp) + + # Mock the actual API request + api_resp = MagicMock(status=200) @asynccontextmanager - async def mock_make_request(*_args, **_kwargs): - yield mock_response + async def mock_request(*_args, **_kwargs): + yield api_resp + + coresys.websession.request = mock_request + + with patch.object(type(api), "use_unix_socket", False): + async with api.make_request("get", "api/test") as resp: + assert resp.status == 200 + + # Verify token was fetched + coresys.websession.post.assert_called_once() + + +@pytest.mark.usefixtures("websession") +async def test_make_request_tcp_timeout(coresys: CoreSys): + """Test make_request wraps TimeoutError.""" + api = coresys.homeassistant.api + coresys.websession.request = MagicMock(side_effect=TimeoutError("timed out")) with ( - patch.object( - type(coresys.homeassistant.api), "make_request", new=mock_make_request - ), - pytest.raises( - HomeAssistantAPIError, match="Home Assistant Core API return 500" - ), + patch.object(type(api), "use_unix_socket", False), + patch.object(api, "_ensure_access_token", new_callable=AsyncMock), + pytest.raises(HomeAssistantAPIError, match="timed out"), ): - await coresys.homeassistant.api.get_config() + async with api.make_request("get", "api/test"): + pass + + +# --- connect_websocket --- + + +async def test_connect_websocket_unix(coresys: CoreSys): + """Test connect_websocket uses WSClient.connect for Unix socket.""" + coresys.homeassistant.core.instance.is_running = AsyncMock(return_value=True) + mock_ws_client = MagicMock() + with ( + patch.object(type(coresys.homeassistant.api), "use_unix_socket", True), + patch( + "supervisor.homeassistant.api.WSClient.connect", + new_callable=AsyncMock, + return_value=mock_ws_client, + ) as mock_connect, + ): + result = await coresys.homeassistant.api.connect_websocket() + + assert result is mock_ws_client + mock_connect.assert_called_once() + + +@pytest.mark.usefixtures("websession") +async def test_connect_websocket_tcp(coresys: CoreSys): + """Test connect_websocket fetches token and connects with auth for TCP.""" + api = coresys.homeassistant.api + mock_ws_client = MagicMock() + + # Mock the /auth/token endpoint to return a valid token + token_resp = MockResponse() + token_resp.json = AsyncMock( + return_value={"access_token": "fresh_token", "expires_in": 1800} + ) + coresys.websession.post = MagicMock(return_value=token_resp) + + with ( + patch.object(type(api), "use_unix_socket", False), + patch( + "supervisor.homeassistant.api.WSClient.connect_with_auth", + new_callable=AsyncMock, + return_value=mock_ws_client, + ) as mock_connect, + ): + result = await api.connect_websocket() + + assert result is mock_ws_client + # Verify token was fetched + coresys.websession.post.assert_called_once() + # Verify connect_with_auth was called with the fresh token + mock_connect.assert_called_once() + assert mock_connect.call_args.args[2] == "fresh_token" diff --git a/tests/homeassistant/test_module.py b/tests/homeassistant/test_module.py index dec68cb8704..a2e636f4b8e 100644 --- a/tests/homeassistant/test_module.py +++ b/tests/homeassistant/test_module.py @@ -87,8 +87,7 @@ async def test_write_pulse_error(coresys: CoreSys, caplog: pytest.LogCaptureFixt async def test_begin_backup_ws_error(coresys: CoreSys): """Test WS error when beginning backup.""" - # pylint: disable-next=protected-access - coresys.homeassistant.websocket._client.async_send_command.side_effect = ( + coresys.homeassistant.websocket.client.async_send_command.side_effect = ( HomeAssistantWSConnectionError("Connection was closed") ) with ( @@ -103,8 +102,7 @@ async def test_begin_backup_ws_error(coresys: CoreSys): async def test_end_backup_ws_error(coresys: CoreSys, caplog: pytest.LogCaptureFixture): """Test WS error when ending backup.""" - # pylint: disable-next=protected-access - coresys.homeassistant.websocket._client.async_send_command.side_effect = ( + coresys.homeassistant.websocket.client.async_send_command.side_effect = ( HomeAssistantWSConnectionError("Connection was closed") ) with patch.object(HomeAssistantWebSocket, "_ensure_connected", return_value=None): diff --git a/tests/homeassistant/test_websocket.py b/tests/homeassistant/test_websocket.py index acc58209582..2ee8b2d33cd 100644 --- a/tests/homeassistant/test_websocket.py +++ b/tests/homeassistant/test_websocket.py @@ -2,14 +2,16 @@ # pylint: disable=import-error import asyncio -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch +import aiohttp import pytest from supervisor.const import CoreState from supervisor.coresys import CoreSys -from supervisor.exceptions import HomeAssistantWSConnectionError +from supervisor.exceptions import HomeAssistantAPIError, HomeAssistantWSConnectionError from supervisor.homeassistant.const import WSEvent, WSType +from supervisor.homeassistant.websocket import WSClient async def test_send_command(coresys: CoreSys, ha_ws_client: AsyncMock): @@ -106,3 +108,153 @@ async def test_send_command_during_shutdown(coresys: CoreSys, ha_ws_client: Asyn await coresys.homeassistant.websocket.async_send_command({"type": "test"}) ha_ws_client.async_send_command.assert_not_called() + + +# --- WSClient --- + + +def _mock_ws_client(messages: list[dict]) -> MagicMock: + """Create a mock aiohttp WebSocket client that returns messages in sequence.""" + client = AsyncMock(spec=aiohttp.ClientWebSocketResponse) + client.receive_json = AsyncMock(side_effect=messages) + client.send_json = AsyncMock() + client.close = AsyncMock() + client.closed = False + return client + + +async def test_ws_connect_error(): + """Test _ws_connect wraps ClientConnectorError.""" + session = AsyncMock() + session.ws_connect = AsyncMock( + side_effect=aiohttp.ClientConnectorError( + MagicMock(), OSError("Connection refused") + ) + ) + + with pytest.raises(HomeAssistantWSConnectionError, match="Can't connect"): + await WSClient._ws_connect(session, "ws://localhost/api/websocket") + + +async def test_connect_unix_success(): + """Test WSClient.connect succeeds with auth_ok.""" + session = AsyncMock() + ws = _mock_ws_client([{"type": "auth_ok", "ha_version": "2026.4.0"}]) + session.ws_connect = AsyncMock(return_value=ws) + + client = await WSClient.connect(session, "ws://localhost/api/websocket") + assert client.ha_version == "2026.4.0" + assert client.connected is True + ws.close.assert_not_called() + + +async def test_connect_unix_unexpected_message(): + """Test WSClient.connect raises and closes on unexpected message.""" + session = AsyncMock() + ws = _mock_ws_client([{"type": "auth_required", "ha_version": "2026.4.0"}]) + session.ws_connect = AsyncMock(return_value=ws) + + with pytest.raises(HomeAssistantAPIError, match="Expected auth_ok"): + await WSClient.connect(session, "ws://localhost/api/websocket") + ws.close.assert_called_once() + + +async def test_connect_unix_bad_json(): + """Test WSClient.connect wraps ValueError from bad JSON.""" + session = AsyncMock() + ws = AsyncMock(spec=aiohttp.ClientWebSocketResponse) + ws.receive_json = AsyncMock(side_effect=ValueError("bad json")) + ws.close = AsyncMock() + session.ws_connect = AsyncMock(return_value=ws) + + with pytest.raises(HomeAssistantAPIError, match="Unexpected error"): + await WSClient.connect(session, "ws://localhost/api/websocket") + ws.close.assert_called_once() + + +async def test_connect_with_auth_success(): + """Test WSClient.connect_with_auth succeeds with auth handshake.""" + session = AsyncMock() + ws = _mock_ws_client( + [ + {"type": "auth_required", "ha_version": "2026.4.0"}, + {"type": "auth_ok", "ha_version": "2026.4.0"}, + ] + ) + session.ws_connect = AsyncMock(return_value=ws) + + client = await WSClient.connect_with_auth( + session, "ws://localhost/api/websocket", "test_token" + ) + assert client.ha_version == "2026.4.0" + ws.send_json.assert_called_once() + ws.close.assert_not_called() + + +async def test_connect_with_auth_unexpected_first_message(): + """Test connect_with_auth raises on unexpected first message.""" + session = AsyncMock() + ws = _mock_ws_client([{"type": "auth_ok", "ha_version": "2026.4.0"}]) + session.ws_connect = AsyncMock(return_value=ws) + + with pytest.raises(HomeAssistantAPIError, match="Expected auth_required"): + await WSClient.connect_with_auth( + session, "ws://localhost/api/websocket", "test_token" + ) + ws.close.assert_called_once() + + +async def test_connect_with_auth_rejected(): + """Test connect_with_auth raises on auth rejection.""" + session = AsyncMock() + ws = _mock_ws_client( + [ + {"type": "auth_required", "ha_version": "2026.4.0"}, + {"type": "auth_invalid", "message": "Invalid password"}, + ] + ) + session.ws_connect = AsyncMock(return_value=ws) + + with pytest.raises(HomeAssistantAPIError, match="AUTH NOT OK"): + await WSClient.connect_with_auth( + session, "ws://localhost/api/websocket", "bad_token" + ) + ws.close.assert_called_once() + + +async def test_connect_with_auth_missing_key(): + """Test connect_with_auth wraps KeyError from missing keys.""" + session = AsyncMock() + ws = _mock_ws_client([{"no_type_key": "oops"}]) + session.ws_connect = AsyncMock(return_value=ws) + + with pytest.raises(HomeAssistantAPIError, match="Unexpected error"): + await WSClient.connect_with_auth( + session, "ws://localhost/api/websocket", "token" + ) + ws.close.assert_called_once() + + +async def test_ws_client_close(): + """Test WSClient.close cancels pending futures and closes connection.""" + ws = AsyncMock(spec=aiohttp.ClientWebSocketResponse) + ws.closed = False + ws.close = AsyncMock() + + client = WSClient.__new__(WSClient) + client.ha_version = "2026.4.0" + client.client = ws + client._message_id = 0 + client._futures = {} + + # Add a pending future + loop = asyncio.get_running_loop() + future = loop.create_future() + client._futures[1] = future + + await client.close() + + assert future.done() + with pytest.raises(HomeAssistantWSConnectionError): + future.result() + ws.close.assert_called_once()