Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial orjson support take 2 #72847

Closed
wants to merge 15 commits into from
2 changes: 1 addition & 1 deletion homeassistant/components/history/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
)
from homeassistant.components.recorder.util import session_scope
from homeassistant.components.websocket_api import messages
from homeassistant.components.websocket_api.const import JSON_DUMP
from homeassistant.core import HomeAssistant
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entityfilter import INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA
from homeassistant.helpers.json import JSON_DUMP
from homeassistant.helpers.typing import ConfigType
import homeassistant.util.dt as dt_util

Expand Down
7 changes: 3 additions & 4 deletions homeassistant/components/http/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import asyncio
from collections.abc import Awaitable, Callable
from http import HTTPStatus
import json
import logging
from typing import Any

Expand All @@ -21,7 +20,7 @@
from homeassistant import exceptions
from homeassistant.const import CONTENT_TYPE_JSON
from homeassistant.core import Context, is_callback
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.json import JSON_ENCODE_EXCEPTIONS, json_bytes

from .const import KEY_AUTHENTICATED, KEY_HASS

Expand Down Expand Up @@ -53,8 +52,8 @@ def json(
) -> web.Response:
"""Return a JSON response."""
try:
msg = json.dumps(result, cls=JSONEncoder, allow_nan=False).encode("UTF-8")
except (ValueError, TypeError) as err:
msg = json_bytes(result)
except JSON_ENCODE_EXCEPTIONS as err:
_LOGGER.error("Unable to serialize to JSON: %s\n%s", err, result)
raise HTTPInternalServerError from err
response = web.Response(
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/logbook/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from homeassistant.components.recorder import get_instance
from homeassistant.components.websocket_api import messages
from homeassistant.components.websocket_api.connection import ActiveConnection
from homeassistant.components.websocket_api.const import JSON_DUMP
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback
from homeassistant.helpers.entityfilter import EntityFilter
from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.helpers.json import JSON_DUMP
import homeassistant.util.dt as dt_util

from .const import LOGBOOK_ENTITIES_FILTER
Expand Down
10 changes: 3 additions & 7 deletions homeassistant/components/recorder/const.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
"""Recorder constants."""

from functools import partial
import json
from typing import Final

from homeassistant.backports.enum import StrEnum
from homeassistant.const import ATTR_ATTRIBUTION, ATTR_RESTORED, ATTR_SUPPORTED_FEATURES
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.json import ( # noqa: F401 pylint: disable=unused-import
JSON_DUMP,
)

DATA_INSTANCE = "recorder_instance"
SQLITE_URL_PREFIX = "sqlite://"
Expand All @@ -27,8 +25,6 @@

DB_WORKER_PREFIX = "DbWorker"

JSON_DUMP: Final = partial(json.dumps, cls=JSONEncoder, separators=(",", ":"))

ALL_DOMAIN_EXCLUDE_ATTRS = {ATTR_ATTRIBUTION, ATTR_RESTORED, ATTR_SUPPORTED_FEATURES}

ATTR_KEEP_DAYS = "keep_days"
Expand Down
15 changes: 9 additions & 6 deletions homeassistant/components/recorder/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
async_track_time_interval,
async_track_utc_time_change,
)
from homeassistant.helpers.json import JSON_ENCODE_EXCEPTIONS
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
import homeassistant.util.dt as dt_util

Expand Down Expand Up @@ -754,19 +755,20 @@ def _process_non_state_changed_event_into_session(self, event: Event) -> None:
return

try:
shared_data = EventData.shared_data_from_event(event)
except (TypeError, ValueError) as ex:
shared_data_bytes = EventData.shared_data_bytes_from_event(event)
except JSON_ENCODE_EXCEPTIONS as ex:
_LOGGER.warning("Event is not JSON serializable: %s: %s", event, ex)
return

shared_data = shared_data_bytes.decode("utf-8")
# Matching attributes found in the pending commit
if pending_event_data := self._pending_event_data.get(shared_data):
dbevent.event_data_rel = pending_event_data
# Matching attributes id found in the cache
elif data_id := self._event_data_ids.get(shared_data):
dbevent.data_id = data_id
else:
data_hash = EventData.hash_shared_data(shared_data)
data_hash = EventData.hash_shared_data_bytes(shared_data_bytes)
# Matching attributes found in the database
if data_id := self._find_shared_data_in_db(data_hash, shared_data):
self._event_data_ids[shared_data] = dbevent.data_id = data_id
Expand All @@ -785,17 +787,18 @@ def _process_state_changed_event_into_session(self, event: Event) -> None:
assert self.event_session is not None
try:
dbstate = States.from_event(event)
shared_attrs = StateAttributes.shared_attrs_from_event(
shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event(
event, self._exclude_attributes_by_domain
)
except (TypeError, ValueError) as ex:
except JSON_ENCODE_EXCEPTIONS as ex:
_LOGGER.warning(
"State is not JSON serializable: %s: %s",
event.data.get("new_state"),
ex,
)
return

shared_attrs = shared_attrs_bytes.decode("utf-8")
dbstate.attributes = None
# Matching attributes found in the pending commit
if pending_attributes := self._pending_state_attributes.get(shared_attrs):
Expand All @@ -804,7 +807,7 @@ def _process_state_changed_event_into_session(self, event: Event) -> None:
elif attributes_id := self._state_attributes_ids.get(shared_attrs):
dbstate.attributes_id = attributes_id
else:
attr_hash = StateAttributes.hash_shared_attrs(shared_attrs)
attr_hash = StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes)
# Matching attributes found in the database
if attributes_id := self._find_shared_attr_in_db(attr_hash, shared_attrs):
dbstate.attributes_id = attributes_id
Expand Down
57 changes: 29 additions & 28 deletions homeassistant/components/recorder/db_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from collections.abc import Callable
from datetime import datetime, timedelta
import json
import logging
from typing import Any, cast

import ciso8601
from fnvhash import fnv1a_32
import orjson
from sqlalchemy import (
JSON,
BigInteger,
Expand Down Expand Up @@ -39,9 +39,10 @@
MAX_LENGTH_STATE_STATE,
)
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
from homeassistant.helpers.json import JSON_DUMP, json_bytes
import homeassistant.util.dt as dt_util

from .const import ALL_DOMAIN_EXCLUDE_ATTRS, JSON_DUMP
from .const import ALL_DOMAIN_EXCLUDE_ATTRS
from .models import StatisticData, StatisticMetaData, process_timestamp

# SQLAlchemy Schema
Expand Down Expand Up @@ -124,7 +125,7 @@ def literal_processor(self, dialect: str) -> Callable[[Any], str]:

def process(value: Any) -> str:
"""Dump json."""
return json.dumps(value)
return JSON_DUMP(value)

return process

Expand Down Expand Up @@ -187,15 +188,15 @@ def to_native(self, validate_entity_id: bool = True) -> Event | None:
try:
return Event(
self.event_type,
json.loads(self.event_data) if self.event_data else {},
orjson.loads(self.event_data) if self.event_data else {},
EventOrigin(self.origin)
if self.origin
else EVENT_ORIGIN_ORDER[self.origin_idx],
process_timestamp(self.time_fired),
context=context,
)
except ValueError:
# When json.loads fails
# When orjson.loads fails
_LOGGER.exception("Error converting to event: %s", self)
return None

Expand Down Expand Up @@ -223,25 +224,26 @@ def __repr__(self) -> str:
@staticmethod
def from_event(event: Event) -> EventData:
"""Create object from an event."""
shared_data = JSON_DUMP(event.data)
shared_data = json_bytes(event.data)
return EventData(
shared_data=shared_data, hash=EventData.hash_shared_data(shared_data)
shared_data=shared_data.decode("utf-8"),
hash=EventData.hash_shared_data_bytes(shared_data),
)

@staticmethod
def shared_data_from_event(event: Event) -> str:
"""Create shared_attrs from an event."""
return JSON_DUMP(event.data)
def shared_data_bytes_from_event(event: Event) -> bytes:
"""Create shared_data from an event."""
return json_bytes(event.data)

@staticmethod
def hash_shared_data(shared_data: str) -> int:
def hash_shared_data_bytes(shared_data_bytes: bytes) -> int:
"""Return the hash of json encoded shared data."""
return cast(int, fnv1a_32(shared_data.encode("utf-8")))
return cast(int, fnv1a_32(shared_data_bytes))

def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object."""
try:
return cast(dict[str, Any], json.loads(self.shared_data))
return cast(dict[str, Any], orjson.loads(self.shared_data))
except ValueError:
_LOGGER.exception("Error converting row to event data: %s", self)
return {}
Expand Down Expand Up @@ -328,9 +330,9 @@ def to_native(self, validate_entity_id: bool = True) -> State | None:
parent_id=self.context_parent_id,
)
try:
attrs = json.loads(self.attributes) if self.attributes else {}
attrs = orjson.loads(self.attributes) if self.attributes else {}
except ValueError:
# When json.loads fails
# When orjson.loads fails
_LOGGER.exception("Error converting row to state: %s", self)
return None
if self.last_changed is None or self.last_changed == self.last_updated:
Expand Down Expand Up @@ -376,40 +378,39 @@ def from_event(event: Event) -> StateAttributes:
"""Create object from a state_changed event."""
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine
dbstate = StateAttributes(
shared_attrs="{}" if state is None else JSON_DUMP(state.attributes)
)
dbstate.hash = StateAttributes.hash_shared_attrs(dbstate.shared_attrs)
attr_bytes = b"{}" if state is None else json_bytes(state.attributes)
dbstate = StateAttributes(shared_attrs=attr_bytes.decode("utf-8"))
dbstate.hash = StateAttributes.hash_shared_attrs_bytes(attr_bytes)
return dbstate

@staticmethod
def shared_attrs_from_event(
def shared_attrs_bytes_from_event(
event: Event, exclude_attrs_by_domain: dict[str, set[str]]
) -> str:
) -> bytes:
"""Create shared_attrs from a state_changed event."""
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine
if state is None:
return "{}"
return b"{}"
domain = split_entity_id(state.entity_id)[0]
exclude_attrs = (
exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS
)
return JSON_DUMP(
return json_bytes(
{k: v for k, v in state.attributes.items() if k not in exclude_attrs}
)

@staticmethod
def hash_shared_attrs(shared_attrs: str) -> int:
"""Return the hash of json encoded shared attributes."""
return cast(int, fnv1a_32(shared_attrs.encode("utf-8")))
def hash_shared_attrs_bytes(shared_attrs_bytes: bytes) -> int:
"""Return the hash of orjson encoded shared attributes."""
return cast(int, fnv1a_32(shared_attrs_bytes))

def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object."""
try:
return cast(dict[str, Any], json.loads(self.shared_attrs))
return cast(dict[str, Any], orjson.loads(self.shared_attrs))
except ValueError:
# When json.loads fails
# When orjson.loads fails
_LOGGER.exception("Error converting row to state attributes: %s", self)
return {}

Expand Down
4 changes: 2 additions & 2 deletions homeassistant/components/recorder/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from __future__ import annotations

from datetime import datetime
import json
import logging
from typing import Any, TypedDict, overload

import orjson
from sqlalchemy.engine.row import Row

from homeassistant.components.websocket_api.const import (
Expand Down Expand Up @@ -253,7 +253,7 @@ def decode_attributes_from_row(
if not source or source == EMPTY_JSON_OBJECT:
return {}
try:
attr_cache[source] = attributes = json.loads(source)
attr_cache[source] = attributes = orjson.loads(source)
except ValueError:
_LOGGER.exception("Error converting row to state attributes: %s", source)
attr_cache[source] = attributes = {}
Expand Down
18 changes: 9 additions & 9 deletions homeassistant/components/websocket_api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
TrackTemplateResult,
async_track_template_result,
)
from homeassistant.helpers.json import ExtendedJSONEncoder
from homeassistant.helpers.json import JSON_DUMP, ExtendedJSONEncoder
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.loader import IntegrationNotFound, async_get_integration
from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations
Expand Down Expand Up @@ -241,13 +241,13 @@ def handle_get_states(
# to succeed for the UI to show.
response = messages.result_message(msg["id"], states)
try:
connection.send_message(const.JSON_DUMP(response))
connection.send_message(JSON_DUMP(response))
return
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(response, dump=const.JSON_DUMP)
find_paths_unserializable_data(response, dump=JSON_DUMP)
),
)
del response
Expand All @@ -256,13 +256,13 @@ def handle_get_states(
serialized = []
for state in states:
try:
serialized.append(const.JSON_DUMP(state))
serialized.append(JSON_DUMP(state))
except (ValueError, TypeError):
# Error is already logged above
pass

# We now have partially serialized states. Craft some JSON.
response2 = const.JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"]))
response2 = JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"]))
response2 = response2.replace('"TO_REPLACE"', ", ".join(serialized))
connection.send_message(response2)

Expand Down Expand Up @@ -315,13 +315,13 @@ def forward_entity_changes(event: Event) -> None:
# to succeed for the UI to show.
response = messages.event_message(msg["id"], data)
try:
connection.send_message(const.JSON_DUMP(response))
connection.send_message(JSON_DUMP(response))
return
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(response, dump=const.JSON_DUMP)
find_paths_unserializable_data(response, dump=JSON_DUMP)
),
)
del response
Expand All @@ -330,14 +330,14 @@ def forward_entity_changes(event: Event) -> None:
cannot_serialize: list[str] = []
for entity_id, state_dict in add_entities.items():
try:
const.JSON_DUMP(state_dict)
JSON_DUMP(state_dict)
except (ValueError, TypeError):
cannot_serialize.append(entity_id)

for entity_id in cannot_serialize:
del add_entities[entity_id]

connection.send_message(const.JSON_DUMP(messages.event_message(msg["id"], data)))
connection.send_message(JSON_DUMP(messages.event_message(msg["id"], data)))


@decorators.websocket_command({vol.Required("type"): "get_services"})
Expand Down
Loading