Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions homeassistant/components/recorder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ def _process_one_event(self, event):
if event.event_type == EVENT_STATE_CHANGED:
try:
dbstate = States.from_event(event)
dbstate_attributes = StateAttributes.from_event(event)
shared_attrs = StateAttributes.shared_attrs_from_event(event)
except (TypeError, ValueError) as ex:
_LOGGER.warning(
"State is not JSON serializable: %s: %s",
Expand All @@ -995,27 +995,33 @@ def _process_one_event(self, event):
return

dbstate.attributes = None
shared_attrs = dbstate_attributes.shared_attrs
# Matching attributes found in the pending commit
if pending_attributes := self._pending_state_attributes.get(shared_attrs):
dbstate.state_attributes = pending_attributes
# Matching attributes id found in the cache
elif attributes_id := self._state_attributes_ids.get(shared_attrs):
dbstate.attributes_id = attributes_id
# Matching attributes found in the database
elif (
attributes := self.event_session.query(StateAttributes.attributes_id)
.filter(StateAttributes.hash == dbstate_attributes.hash)
.filter(StateAttributes.shared_attrs == shared_attrs)
.first()
):
dbstate.attributes_id = attributes[0]
self._state_attributes_ids[shared_attrs] = attributes[0]
# No matching attributes found, save them in the DB
else:
dbstate.state_attributes = dbstate_attributes
self._pending_state_attributes[shared_attrs] = dbstate_attributes
self.event_session.add(dbstate_attributes)
attr_hash = StateAttributes.hash_shared_attrs(shared_attrs)
# Matching attributes found in the database
if (
attributes := self.event_session.query(
StateAttributes.attributes_id
)
.filter(StateAttributes.hash == attr_hash)
.filter(StateAttributes.shared_attrs == shared_attrs)
.first()
):
dbstate.attributes_id = attributes[0]
self._state_attributes_ids[shared_attrs] = attributes[0]
# No matching attributes found, save them in the DB
else:
dbstate_attributes = StateAttributes(
shared_attrs=shared_attrs, hash=attr_hash
)
dbstate.state_attributes = dbstate_attributes
self._pending_state_attributes[shared_attrs] = dbstate_attributes
self.event_session.add(dbstate_attributes)

if old_state := self._old_states.pop(dbstate.entity_id, None):
if old_state.state_id:
Expand Down
8 changes: 8 additions & 0 deletions homeassistant/components/recorder/const.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
"""Recorder constants."""

from functools import partial
import json
from typing import Final

from homeassistant.helpers.json import JSONEncoder

DATA_INSTANCE = "recorder_instance"
SQLITE_URL_PREFIX = "sqlite://"
DOMAIN = "recorder"
Expand All @@ -17,3 +23,5 @@
MAX_ROWS_TO_PURGE = 998

DB_WORKER_PREFIX = "DbWorker"

JSON_DUMP: Final = partial(json.dumps, cls=JSONEncoder, separators=(",", ":"))
56 changes: 30 additions & 26 deletions homeassistant/components/recorder/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import datetime, timedelta
import json
import logging
from typing import TypedDict, overload
from typing import Any, TypedDict, overload

from fnvhash import fnv1a_32
from sqlalchemy import (
Expand Down Expand Up @@ -35,9 +35,10 @@
MAX_LENGTH_STATE_STATE,
)
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
from homeassistant.helpers.json import JSONEncoder
import homeassistant.util.dt as dt_util

from .const import JSON_DUMP

# SQLAlchemy Schema
# pylint: disable=invalid-name
Base = declarative_base()
Expand Down Expand Up @@ -116,8 +117,7 @@ def from_event(event, event_data=None):
"""Create an event database object from a native event."""
return Events(
event_type=event.event_type,
event_data=event_data
or json.dumps(event.data, cls=JSONEncoder, separators=(",", ":")),
event_data=event_data or JSON_DUMP(event.data),
origin=str(event.origin.value),
time_fired=event.time_fired,
context_id=event.context.id,
Expand Down Expand Up @@ -186,15 +186,13 @@ def __repr__(self) -> str:
)

@staticmethod
def from_event(event):
def from_event(event) -> States:
"""Create object from a state_changed event."""
entity_id = event.data["entity_id"]
state = event.data.get("new_state")

dbstate = States(entity_id=entity_id)
dbstate.attributes = None
state: State | None = event.data.get("new_state")
dbstate = States(entity_id=entity_id, attributes=None)

# State got deleted
# None state means the state was removed from the state machine
if state is None:
dbstate.state = ""
dbstate.domain = split_entity_id(entity_id)[0]
Expand All @@ -208,7 +206,7 @@ def from_event(event):

return dbstate

def to_native(self, validate_entity_id=True):
def to_native(self, validate_entity_id: bool = True) -> State | None:
"""Convert to an HA state object."""
try:
return State(
Expand All @@ -221,7 +219,7 @@ def to_native(self, validate_entity_id=True):
process_timestamp(self.last_updated),
# Join the events table on event_id to get the context instead
# as it will always be there for state_changed events
context=Context(id=None),
context=Context(id=None), # type: ignore[arg-type]
validate_entity_id=validate_entity_id,
)
except ValueError:
Expand Down Expand Up @@ -251,23 +249,29 @@ def __repr__(self) -> str:
)

@staticmethod
def from_event(event):
def from_event(event: Event) -> StateAttributes:
"""Create object from a state_changed event."""
state = event.data.get("new_state")
dbstate = StateAttributes()
# State got deleted
if state is None:
dbstate.shared_attrs = "{}"
else:
dbstate.shared_attrs = json.dumps(
dict(state.attributes),
cls=JSONEncoder,
separators=(",", ":"),
)
dbstate.hash = fnv1a_32(dbstate.shared_attrs.encode("utf-8"))
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine
dbstate = StateAttributes(
shared_attrs="{}" if state is None else JSON_DUMP(state.attributes)
)
dbstate.hash = StateAttributes.hash_shared_attrs(dbstate.shared_attrs)
return dbstate

def to_native(self):
@staticmethod
def shared_attrs_from_event(event: Event) -> str:
"""Create shared_attrs from a state_changed event."""
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine
return "{}" if state is None else JSON_DUMP(state.attributes)

@staticmethod
def hash_shared_attrs(shared_attrs: str) -> int:
"""Return the hash of json encoded shared attributes."""
return fnv1a_32(shared_attrs.encode("utf-8"))

def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object."""
try:
return json.loads(self.shared_attrs)
Expand Down