Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 108 additions & 68 deletions homeassistant/components/bayesian/binary_sensor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Use Bayesian Inference to trigger a binary sensor."""
from collections import OrderedDict
from itertools import chain

import voluptuous as vol

Expand Down Expand Up @@ -88,10 +87,10 @@
)


def update_probability(prior, prob_true, prob_false):
def update_probability(prior, prob_given_true, prob_given_false):
"""Update probability using Bayes' rule."""
numerator = prob_true * prior
denominator = numerator + prob_false * (1 - prior)
numerator = prob_given_true * prior
denominator = numerator + prob_given_false * (1 - prior)
probability = numerator / denominator
return probability

Expand Down Expand Up @@ -127,84 +126,124 @@ def __init__(self, name, prior, observations, probability_threshold, device_clas
self.prior = prior
self.probability = prior

self.current_obs = OrderedDict({})
self.entity_obs_dict = []
self.current_observations = OrderedDict({})

for obs in self._observations:
if "entity_id" in obs:
self.entity_obs_dict.append([obs.get("entity_id")])
if "value_template" in obs:
self.entity_obs_dict.append(
list(obs.get(CONF_VALUE_TEMPLATE).extract_entities())
)
self.observations_by_entity = self._build_observations_by_entity()

to_observe = set()
for obs in self._observations:
if "entity_id" in obs:
to_observe.update(set([obs.get("entity_id")]))
if "value_template" in obs:
to_observe.update(set(obs.get(CONF_VALUE_TEMPLATE).extract_entities()))
self.entity_obs = {key: [] for key in to_observe}

for ind, obs in enumerate(self._observations):
obs["id"] = ind
if "entity_id" in obs:
self.entity_obs[obs["entity_id"]].append(obs)
if "value_template" in obs:
for ent in obs.get(CONF_VALUE_TEMPLATE).extract_entities():
self.entity_obs[ent].append(obs)

self.watchers = {
self.observation_handlers = {
"numeric_state": self._process_numeric_state,
"state": self._process_state,
"template": self._process_template,
}

async def async_added_to_hass(self):
"""Call when entity about to be added."""
"""
Call when entity about to be added.

@callback
def async_threshold_sensor_state_listener(entity, old_state, new_state):
"""Handle sensor state changes."""
if new_state.state == STATE_UNKNOWN:
return
All relevant update logic for instance attributes occurs within this closure.
Other methods in this class are designed to avoid directly modifying instance
attributes, by instead focusing on returning relevant data back to this method.

entity_obs_list = self.entity_obs[entity]
The goal of this method is to ensure that `self.current_observations` and `self.probability`
are set on a best-effort basis when this entity is register with hass.

for entity_obs in entity_obs_list:
platform = entity_obs["platform"]
In addition, this method must register the state listener defined within, which
will be called any time a relevant entity changes its state.
"""

self.watchers[platform](entity_obs)
@callback
def async_threshold_sensor_state_listener(entity, _old_state, new_state):
"""
Handle sensor state changes.

prior = self.prior
for obs in self.current_obs.values():
prior = update_probability(prior, obs["prob_true"], obs["prob_false"])
self.probability = prior
When a state changes, we must update our list of current observations,
then calculate the new probability.
"""
if new_state.state == STATE_UNKNOWN:
return

self.current_observations.update(self._record_entity_observations(entity))
self.probability = self._calculate_new_probability()

self.hass.async_add_job(self.async_update_ha_state, True)

self.current_observations.update(self._initialize_current_observations())
self.probability = self._calculate_new_probability()
async_track_state_change(
self.hass, self.entity_obs, async_threshold_sensor_state_listener
self.hass,
self.observations_by_entity,
async_threshold_sensor_state_listener,
)

def _update_current_obs(self, entity_observation, should_trigger):
"""Update current observation."""
obs_id = entity_observation["id"]
def _initialize_current_observations(self):
local_observations = OrderedDict({})
for entity in self.observations_by_entity:
local_observations.update(self._record_entity_observations(entity))
return local_observations

def _record_entity_observations(self, entity):
local_observations = OrderedDict({})
entity_obs_list = self.observations_by_entity[entity]

for entity_obs in entity_obs_list:
platform = entity_obs["platform"]

if should_trigger:
prob_true = entity_observation["prob_given_true"]
prob_false = entity_observation.get("prob_given_false", 1 - prob_true)
should_trigger = self.observation_handlers[platform](entity_obs)

self.current_obs[obs_id] = {
"prob_true": prob_true,
"prob_false": prob_false,
}
if should_trigger:
obs_entry = {"entity_id": entity, **entity_obs}
else:
obs_entry = None

else:
self.current_obs.pop(obs_id, None)
local_observations[entity_obs["id"]] = obs_entry

return local_observations

def _calculate_new_probability(self):
prior = self.prior

for obs in self.current_observations.values():
if obs is not None:
prior = update_probability(
prior,
obs["prob_given_true"],
obs.get("prob_given_false", 1 - obs["prob_given_true"]),
)

return prior

def _build_observations_by_entity(self):
"""
Build and return data structure of the form below.

{
"sensor.sensor1": [{"id": 0, ...}, {"id": 1, ...}],
"sensor.sensor2": [{"id": 2, ...}],
...
}

Each "observation" must be recognized uniquely, and it should be possible
for all relevant observations to be looked up via their `entity_id`.
"""

observations_by_entity = {}
for ind, obs in enumerate(self._observations):
obs["id"] = ind

if "entity_id" in obs:
entity_ids = [obs["entity_id"]]
elif "value_template" in obs:
entity_ids = obs.get(CONF_VALUE_TEMPLATE).extract_entities()

for e_id in entity_ids:
obs_list = observations_by_entity.get(e_id, [])
obs_list.append(obs)
observations_by_entity[e_id] = obs_list

return observations_by_entity

def _process_numeric_state(self, entity_observation):
"""Add entity to current_obs if numeric state conditions are met."""
"""Return True if numeric condition is met."""
entity = entity_observation["entity_id"]

should_trigger = condition.async_numeric_state(
Expand All @@ -215,27 +254,26 @@ def _process_numeric_state(self, entity_observation):
None,
entity_observation,
)

self._update_current_obs(entity_observation, should_trigger)
return should_trigger

def _process_state(self, entity_observation):
"""Add entity to current observations if state conditions are met."""
"""Return True if state conditions are met."""
entity = entity_observation["entity_id"]

should_trigger = condition.state(
self.hass, entity, entity_observation.get("to_state")
)

self._update_current_obs(entity_observation, should_trigger)
return should_trigger

def _process_template(self, entity_observation):
"""Add entity to current_obs if template is true."""
"""Return True if template condition is True."""
template = entity_observation.get(CONF_VALUE_TEMPLATE)
template.hass = self.hass
should_trigger = condition.async_template(
self.hass, template, entity_observation
)
self._update_current_obs(entity_observation, should_trigger)
return should_trigger

@property
def name(self):
Expand All @@ -260,13 +298,15 @@ def device_class(self):
@property
def device_state_attributes(self):
"""Return the state attributes of the sensor."""
print(self.current_observations)
print(self.observations_by_entity)
Comment on lines +301 to +302
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here are two print statements.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nevermind, just saw removed 5 hours ago in #33916

return {
ATTR_OBSERVATIONS: list(self.current_obs.values()),
ATTR_OBSERVATIONS: list(self.current_observations.values()),
ATTR_OCCURRED_OBSERVATION_ENTITIES: list(
set(
chain.from_iterable(
self.entity_obs_dict[obs] for obs in self.current_obs.keys()
)
obs.get("entity_id")
for obs in self.current_observations.values()
if obs is not None
)
),
ATTR_PROBABILITY: round(self.probability, 2),
Expand Down
Loading