Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 21 additions & 33 deletions homeassistant/components/frontend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from homeassistant.config import async_hass_config_yaml
from homeassistant.const import CONF_NAME, EVENT_THEMES_UPDATED
from homeassistant.core import callback
from homeassistant.helpers import service
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.translation import async_get_translations
from homeassistant.loader import bind_hass
Expand Down Expand Up @@ -103,19 +104,6 @@

SERVICE_SET_THEME = "set_theme"
SERVICE_RELOAD_THEMES = "reload_themes"
SERVICE_SET_THEME_SCHEMA = vol.Schema({vol.Required(CONF_NAME): cv.string})
WS_TYPE_GET_PANELS = "get_panels"
SCHEMA_GET_PANELS = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{vol.Required("type"): WS_TYPE_GET_PANELS}
)
WS_TYPE_GET_THEMES = "frontend/get_themes"
SCHEMA_GET_THEMES = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{vol.Required("type"): WS_TYPE_GET_THEMES}
)
WS_TYPE_GET_TRANSLATIONS = "frontend/get_translations"
SCHEMA_GET_TRANSLATIONS = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{vol.Required("type"): WS_TYPE_GET_TRANSLATIONS, vol.Required("language"): str}
)


class Panel:
Expand Down Expand Up @@ -251,15 +239,9 @@ def _frontend_root(dev_repo_path):
async def async_setup(hass, config):
"""Set up the serving of the frontend."""
await async_setup_frontend_storage(hass)
hass.components.websocket_api.async_register_command(
WS_TYPE_GET_PANELS, websocket_get_panels, SCHEMA_GET_PANELS
)
hass.components.websocket_api.async_register_command(
WS_TYPE_GET_THEMES, websocket_get_themes, SCHEMA_GET_THEMES
)
hass.components.websocket_api.async_register_command(
WS_TYPE_GET_TRANSLATIONS, websocket_get_translations, SCHEMA_GET_TRANSLATIONS
)
hass.components.websocket_api.async_register_command(websocket_get_panels)
hass.components.websocket_api.async_register_command(websocket_get_themes)
hass.components.websocket_api.async_register_command(websocket_get_translations)
hass.http.register_view(ManifestJSONView)

conf = config.get(DOMAIN, {})
Expand Down Expand Up @@ -331,11 +313,7 @@ async def async_setup(hass, config):
def _async_setup_themes(hass, themes):
"""Set up themes data and services."""
hass.data[DATA_DEFAULT_THEME] = DEFAULT_THEME
if themes is None:
hass.data[DATA_THEMES] = {}
return

hass.data[DATA_THEMES] = themes
hass.data[DATA_THEMES] = themes or {}

@callback
def update_theme_and_fire_event():
Expand All @@ -348,9 +326,7 @@ def update_theme_and_fire_event():
"app-header-background-color",
themes[name].get(PRIMARY_COLOR, DEFAULT_THEME_COLOR),
)
hass.bus.async_fire(
EVENT_THEMES_UPDATED, {"themes": themes, "default_theme": name}
)
hass.bus.async_fire(EVENT_THEMES_UPDATED)

@callback
def set_theme(call):
Expand All @@ -373,10 +349,17 @@ async def reload_themes(_):
hass.data[DATA_DEFAULT_THEME] = DEFAULT_THEME
update_theme_and_fire_event()

hass.services.async_register(
DOMAIN, SERVICE_SET_THEME, set_theme, schema=SERVICE_SET_THEME_SCHEMA
service.async_register_admin_service(
hass,
DOMAIN,
SERVICE_SET_THEME,
set_theme,
vol.Schema({vol.Required(CONF_NAME): cv.string}),
)

service.async_register_admin_service(
hass, DOMAIN, SERVICE_RELOAD_THEMES, reload_themes
)
hass.services.async_register(DOMAIN, SERVICE_RELOAD_THEMES, reload_themes)


class IndexView(web_urldispatcher.AbstractResource):
Expand Down Expand Up @@ -498,6 +481,7 @@ def get(self, request): # pylint: disable=no-self-use


@callback
@websocket_api.websocket_command({"type": "get_panels"})
def websocket_get_panels(hass, connection, msg):
"""Handle get panels command.

Expand All @@ -514,6 +498,7 @@ def websocket_get_panels(hass, connection, msg):


@callback
@websocket_api.websocket_command({"type": "frontend/get_themes"})
def websocket_get_themes(hass, connection, msg):
"""Handle get themes command.

Expand All @@ -530,6 +515,9 @@ def websocket_get_themes(hass, connection, msg):
)


@websocket_api.websocket_command(
{"type": "frontend/get_translations", vol.Required("language"): str}
)
@websocket_api.async_response
async def websocket_get_translations(hass, connection, msg):
"""Handle get translations command.
Expand Down
16 changes: 10 additions & 6 deletions homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,10 @@ def async_add_job(

if asyncio.iscoroutine(check_target):
task = self.loop.create_task(target) # type: ignore
elif is_callback(check_target):
self.loop.call_soon(target, *args)
elif asyncio.iscoroutinefunction(check_target):
task = self.loop.create_task(target(*args))
elif is_callback(check_target):
self.loop.call_soon(target, *args)
else:
task = self.loop.run_in_executor( # type: ignore
None, target, *args
Expand Down Expand Up @@ -360,7 +360,11 @@ def async_run_job(self, target: Callable[..., None], *args: Any) -> None:
target: target to call.
args: parameters for method to call.
"""
if not asyncio.iscoroutine(target) and is_callback(target):
if (
not asyncio.iscoroutine(target)
and not asyncio.iscoroutinefunction(target)
and is_callback(target)
):
target(*args)
else:
self.async_add_job(target, *args)
Expand Down Expand Up @@ -1245,10 +1249,10 @@ async def _execute_service(
self, handler: Service, service_call: ServiceCall
) -> None:
"""Execute a service."""
if handler.is_callback:
handler.func(service_call)
elif handler.is_coroutinefunction:
if handler.is_coroutinefunction:
await handler.func(service_call)
elif handler.is_callback:
handler.func(service_call)
else:
await self._hass.async_add_executor_job(handler.func, service_call)

Expand Down
4 changes: 3 additions & 1 deletion homeassistant/helpers/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,9 @@ async def admin_handler(call):
if not user.is_admin:
raise Unauthorized(context=call.context)

await hass.async_add_job(service_func, call)
result = hass.async_add_job(service_func, call)
if result is not None:
await result

hass.services.async_register(domain, service, admin_handler, schema)

Expand Down
25 changes: 25 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,3 +1180,28 @@ def test_context():
assert c.user_id == 23
assert c.parent_id == 100
assert c.id is not None


async def test_async_functions_with_callback(hass):
"""Test we deal with async functions accidentally marked as callback."""
runs = []

@ha.callback
async def test():
runs.append(True)

await hass.async_add_job(test)
assert len(runs) == 1

hass.async_run_job(test)
await hass.async_block_till_done()
assert len(runs) == 2

@ha.callback
async def service_handler(call):
runs.append(True)

hass.services.async_register("test_domain", "test_service", service_handler)

await hass.services.async_call("test_domain", "test_service", blocking=True)
assert len(runs) == 3