Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .strict-typing
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ homeassistant.components.camera.*
homeassistant.components.canary.*
homeassistant.components.clickatell.*
homeassistant.components.clicksend.*
homeassistant.components.configurator.*
homeassistant.components.cover.*
homeassistant.components.cpuspeed.*
homeassistant.components.crownstone.*
Expand Down
42 changes: 25 additions & 17 deletions homeassistant/components/configurator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from contextlib import suppress
from datetime import datetime
import functools as ft
from typing import Any, cast
from typing import Any

from homeassistant.const import ATTR_ENTITY_PICTURE, ATTR_FRIENDLY_NAME
from homeassistant.core import HomeAssistant, ServiceCall, callback as async_callback
Expand Down Expand Up @@ -80,7 +80,7 @@ def async_request_config(
if DATA_REQUESTS not in hass.data:
hass.data[DATA_REQUESTS] = {}

hass.data[DATA_REQUESTS][request_id] = instance
_get_requests(hass)[request_id] = instance

return request_id

Expand All @@ -98,10 +98,10 @@ def request_config(hass: HomeAssistant, *args: Any, **kwargs: Any) -> str:

@bind_hass
@async_callback
def async_notify_errors(hass, request_id, error):
def async_notify_errors(hass: HomeAssistant, request_id: str, error: str) -> None:
"""Add errors to a config request."""
with suppress(KeyError): # If request_id does not exist
hass.data[DATA_REQUESTS][request_id].async_notify_errors(request_id, error)
_get_requests(hass)[request_id].async_notify_errors(request_id, error)


@bind_hass
Expand All @@ -117,7 +117,7 @@ def notify_errors(hass: HomeAssistant, request_id: str, error: str) -> None:
def async_request_done(hass: HomeAssistant, request_id: str) -> None:
"""Mark a configuration request as done."""
with suppress(KeyError): # If request_id does not exist
hass.data[DATA_REQUESTS].pop(request_id).async_request_done(request_id)
_get_requests(hass).pop(request_id).async_request_done(request_id)


@bind_hass
Expand All @@ -133,10 +133,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True


def _get_requests(hass: HomeAssistant) -> dict[str, Configurator]:
"""Return typed configurator_requests data."""
return hass.data[DATA_REQUESTS] # type: ignore[no-any-return]


class Configurator:
"""The class to keep track of current configuration requests."""

def __init__(self, hass):
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the configurator."""
self.hass = hass
self._cur_id = 0
Expand Down Expand Up @@ -190,22 +195,23 @@ def async_request_config(
return request_id

@async_callback
def async_notify_errors(self, request_id, error):
def async_notify_errors(self, request_id: str, error: str) -> None:
"""Update the state with errors."""
if not self._validate_request_id(request_id):
return

entity_id = self._requests[request_id][0]

state = self.hass.states.get(entity_id)
if (state := self.hass.states.get(entity_id)) is None:
return

new_data = dict(state.attributes)
new_data[ATTR_ERRORS] = error

self.hass.states.async_set(entity_id, STATE_CONFIGURE, new_data)

@async_callback
def async_request_done(self, request_id):
def async_request_done(self, request_id: str) -> None:
"""Remove the configuration request."""
if not self._validate_request_id(request_id):
return
Expand All @@ -219,30 +225,32 @@ def async_request_done(self, request_id):
self.hass.states.async_set(entity_id, STATE_CONFIGURED)

@async_callback
def deferred_remove(now: datetime):
def deferred_remove(now: datetime) -> None:
"""Remove the request state."""
self.hass.states.async_remove(entity_id)

async_call_later(self.hass, 1, deferred_remove)

async def async_handle_service_call(self, call: ServiceCall) -> None:
"""Handle a configure service call."""
request_id = call.data.get(ATTR_CONFIGURE_ID)
request_id: str | None = call.data.get(ATTR_CONFIGURE_ID)

if not self._validate_request_id(request_id):
if not request_id or not self._validate_request_id(request_id):
return

_, _, callback = self._requests[cast(str, request_id)]
_, _, callback = self._requests[request_id]

# field validation goes here?
if callback:
await self.hass.async_add_job(callback, call.data.get(ATTR_FIELDS, {}))
if callback and (
job := self.hass.async_add_job(callback, call.data.get(ATTR_FIELDS, {}))
):
await job

def _generate_unique_id(self):
def _generate_unique_id(self) -> str:
"""Generate a unique configurator ID."""
self._cur_id += 1
return f"{id(self)}-{self._cur_id}"

def _validate_request_id(self, request_id):
def _validate_request_id(self, request_id: str) -> bool:
"""Validate that the request belongs to this instance."""
return request_id in self._requests
10 changes: 10 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,16 @@ disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true

[mypy-homeassistant.components.configurator.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true

[mypy-homeassistant.components.cover.*]
check_untyped_defs = true
disallow_incomplete_defs = true
Expand Down
19 changes: 12 additions & 7 deletions tests/components/configurator/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

import homeassistant.components.configurator as configurator
from homeassistant.const import ATTR_FRIENDLY_NAME
from homeassistant.core import HomeAssistant
import homeassistant.util.dt as dt_util

from tests.common import async_fire_time_changed


async def test_request_least_info(hass):
async def test_request_least_info(hass: HomeAssistant) -> None:
"""Test request config with least amount of data."""
request_id = configurator.async_request_config(hass, "Test Request", lambda _: None)

Expand All @@ -27,7 +28,7 @@ async def test_request_least_info(hass):
assert state.attributes.get(configurator.ATTR_CONFIGURE_ID) == request_id


async def test_request_all_info(hass):
async def test_request_all_info(hass: HomeAssistant) -> None:
"""Test request config with all possible info."""
exp_attr = {
ATTR_FRIENDLY_NAME: "Test Request",
Expand Down Expand Up @@ -61,7 +62,7 @@ async def test_request_all_info(hass):
assert state.attributes == exp_attr


async def test_callback_called_on_configure(hass):
async def test_callback_called_on_configure(hass: HomeAssistant) -> None:
"""Test if our callback gets called when configure service called."""
calls = []
request_id = configurator.async_request_config(
Expand All @@ -78,7 +79,7 @@ async def test_callback_called_on_configure(hass):
assert len(calls) == 1, "Callback not called"


async def test_state_change_on_notify_errors(hass):
async def test_state_change_on_notify_errors(hass: HomeAssistant) -> None:
"""Test state change on notify errors."""
request_id = configurator.async_request_config(hass, "Test Request", lambda _: None)
error = "Oh no bad bad bad"
Expand All @@ -90,12 +91,14 @@ async def test_state_change_on_notify_errors(hass):
assert state.attributes.get(configurator.ATTR_ERRORS) == error


async def test_notify_errors_fail_silently_on_bad_request_id(hass):
async def test_notify_errors_fail_silently_on_bad_request_id(
hass: HomeAssistant,
) -> None:
"""Test if notify errors fails silently with a bad request id."""
configurator.async_notify_errors(hass, 2015, "Try this error")


async def test_request_done_works(hass):
async def test_request_done_works(hass: HomeAssistant) -> None:
"""Test if calling request done works."""
request_id = configurator.async_request_config(hass, "Test Request", lambda _: None)
configurator.async_request_done(hass, request_id)
Expand All @@ -105,6 +108,8 @@ async def test_request_done_works(hass):
assert len(hass.states.async_all()) == 0


async def test_request_done_fail_silently_on_bad_request_id(hass):
async def test_request_done_fail_silently_on_bad_request_id(
hass: HomeAssistant,
) -> None:
"""Test that request_done fails silently with a bad request id."""
configurator.async_request_done(hass, 2016)