-
-
Notifications
You must be signed in to change notification settings - Fork 37.7k
Bayesian Binary Sensor #8810
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bayesian Binary Sensor #8810
Changes from 8 commits
1db8d87
ec0deb4
ad4a609
7b3bf06
39465c1
51e94de
843a8f9
ffcc838
a679a9b
ce28829
29ca298
066514d
a44bc31
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,217 @@ | ||
| """ | ||
| Use Bayesian Inference to trigger a binary sensor. | ||
|
|
||
| For more details about this platform, please refer to the documentation at | ||
| https://home-assistant.io/components/binary_sensor.bayesian/ | ||
| """ | ||
| import asyncio | ||
| import logging | ||
| from collections import OrderedDict | ||
|
|
||
| import voluptuous as vol | ||
|
|
||
| import homeassistant.helpers.config_validation as cv | ||
| from homeassistant.components.binary_sensor import (BinarySensorDevice, | ||
| PLATFORM_SCHEMA) | ||
| from homeassistant.const import (CONF_ABOVE, CONF_BELOW, CONF_DEVICE_CLASS, | ||
| CONF_ENTITY_ID, CONF_NAME, CONF_PLATFORM, | ||
| CONF_STATE, STATE_UNKNOWN) | ||
| from homeassistant.core import callback | ||
| from homeassistant.helpers import condition | ||
| from homeassistant.helpers.event import async_track_state_change | ||
|
|
||
| _LOGGER = logging.getLogger(__name__) | ||
|
|
||
| CONF_OBSERVATIONS = 'observations' | ||
| CONF_PRIOR = 'prior' | ||
| CONF_PROBABILITY_THRESHOLD = 'probability_threshold' | ||
| CONF_P_GIVEN_F = 'prob_given_false' | ||
| CONF_P_GIVEN_T = 'prob_given_true' | ||
| CONF_TO_STATE = 'to_state' | ||
|
|
||
| DEFAULT_NAME = 'BayesianBinary' | ||
|
|
||
| NUMERIC_STATE_SCHEMA = vol.Schema( | ||
| { | ||
| CONF_PLATFORM: 'numeric_state', | ||
| CONF_ENTITY_ID: cv.entity_id, | ||
| vol.Optional(CONF_ABOVE): vol.Coerce(float), | ||
| vol.Optional(CONF_BELOW): vol.Coerce(float), | ||
| CONF_P_GIVEN_T: vol.Coerce(float), | ||
| vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float) | ||
| }, | ||
| required=True) | ||
|
|
||
| STATE_SCHEMA = vol.Schema( | ||
| { | ||
| CONF_PLATFORM: CONF_STATE, | ||
| CONF_ENTITY_ID: cv.entity_id, | ||
| CONF_TO_STATE: cv.string, | ||
| CONF_P_GIVEN_T: vol.Coerce(float), | ||
| vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float) | ||
| }, | ||
| required=True) | ||
|
|
||
| PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ | ||
| vol.Optional(CONF_NAME, default=DEFAULT_NAME): | ||
| cv.string, | ||
| vol.Optional(CONF_DEVICE_CLASS): | ||
| cv.string, | ||
| vol.Required(CONF_OBSERVATIONS): | ||
| vol.Schema( | ||
| vol.All(cv.ensure_list, [vol.Any(NUMERIC_STATE_SCHEMA, | ||
| STATE_SCHEMA)]) | ||
| ), | ||
| vol.Required(CONF_PRIOR): | ||
| vol.Coerce(float), | ||
| vol.Optional(CONF_PROBABILITY_THRESHOLD): | ||
| vol.Coerce(float), | ||
| }) | ||
|
|
||
|
|
||
| def update_probability(prior, prob_true, prob_false): | ||
| """Update probability using Bayes' rule.""" | ||
| numerator = prob_true * prior | ||
| denominator = numerator + prob_false * (1 - prior) | ||
|
|
||
| probability = numerator / denominator | ||
|
|
||
| return probability | ||
|
|
||
|
|
||
| @asyncio.coroutine | ||
| def async_setup_platform(hass, config, async_add_devices, discovery_info=None): | ||
| """Set up the Threshold sensor.""" | ||
| name = config.get(CONF_NAME) | ||
| observations = config.get(CONF_OBSERVATIONS) | ||
| prior = config.get(CONF_PRIOR) | ||
| probability_threshold = config.get(CONF_PROBABILITY_THRESHOLD, 0.5) | ||
| device_class = config.get(CONF_DEVICE_CLASS) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| async_add_devices([ | ||
| BayesianBinarySensor(name, prior, observations, probability_threshold, | ||
| device_class) | ||
| ], True) | ||
|
|
||
|
|
||
| class BayesianBinarySensor(BinarySensorDevice): | ||
| """Representation of a Bayesian sensor.""" | ||
|
|
||
| def __init__(self, name, prior, observations, probability_threshold, | ||
| device_class): | ||
| """Initialize the Bayesian sensor.""" | ||
| self._name = name | ||
| self._observations = observations | ||
| self._probability_threshold = probability_threshold | ||
| self._device_class = device_class | ||
| self._deviation = False | ||
| self.prior = prior | ||
| self.probability = prior | ||
|
|
||
| self.current_obs = OrderedDict({}) | ||
|
|
||
| self.entity_obs = {obs['entity_id']: obs for obs in self._observations} | ||
|
|
||
| self.watchers = { | ||
| 'numeric_state': self._process_numeric_state, | ||
| 'state': self._process_state | ||
| } | ||
|
|
||
| @asyncio.coroutine | ||
| def async_added_to_hass(self): | ||
| """Call when entity about to be added to hass.""" | ||
| @callback | ||
| # pylint: disable=invalid-name | ||
| def async_threshold_sensor_state_listener(entity, old_state, | ||
| new_state): | ||
| """Handle sensor state changes.""" | ||
| if new_state.state == STATE_UNKNOWN: | ||
| return | ||
|
|
||
| entity_obs = self.entity_obs[entity] | ||
| platform = entity_obs['platform'] | ||
|
|
||
| self.watchers[platform](entity_obs) | ||
|
|
||
| prior = self.prior | ||
| print(self.current_obs.values()) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove debug print. |
||
| for obs in self.current_obs.values(): | ||
| prior = update_probability(prior, obs['prob_true'], | ||
| obs['prob_false']) | ||
|
|
||
| self.probability = prior | ||
|
|
||
| self.hass.async_add_job(self.async_update_ha_state, True) | ||
|
|
||
| for obs in self._observations: | ||
| entity_id = obs['entity_id'] | ||
| async_track_state_change(self.hass, entity_id, | ||
| async_threshold_sensor_state_listener) | ||
|
|
||
| def _update_current_obs(self, entity_observation, should_trigger): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add a function documentation. Add decorator "@callback" if they need run inside async loop or add "Async friendly" to function documentation if that dosn't matters.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I basically adapted the As a side note, the I appreciate the feedback, and hope to address it the best way possible.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. I make a PR to remove that from Yeah, the function can run async/sync. So but "Async friendly" to this function like https://github.com/home-assistant/home-assistant/blob/c56f99baafac33483dd13699993d7da4ee5d7efd/homeassistant/helpers/location.py#L11-L14. A async/sync problem is only visible on real world and not every time inside tests. It is very hard to find a problem like this.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 thanks for the clarification. |
||
| entity = entity_observation['entity_id'] | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove blank line. |
||
| if should_trigger: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The code between the |
||
| prob_true = entity_observation['prob_given_true'] | ||
| prob_false = entity_observation.get('prob_given_false', | ||
| 1 - prob_true) | ||
|
|
||
| self.current_obs[entity] = { | ||
| 'prob_true': prob_true, | ||
| 'prob_false': prob_false | ||
| } | ||
|
|
||
| else: | ||
| self.current_obs.pop(entity, None) | ||
|
|
||
| def _process_numeric_state(self, entity_observation): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add a function documentation. Add decorator "@callback" if they need run inside async loop or add "Async friendly" to function documentation if that dosn't matters. |
||
| entity = entity_observation['entity_id'] | ||
|
|
||
| should_trigger = condition.async_numeric_state( | ||
| self.hass, entity, | ||
| entity_observation.get('below'), | ||
| entity_observation.get('above'), None, entity_observation) | ||
|
|
||
| self._update_current_obs(entity_observation, should_trigger) | ||
|
|
||
| def _process_state(self, entity_observation): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add a function documentation. Add decorator "@callback" if they need run inside async loop or add "Async friendly" to function documentation if that dosn't matters. |
||
| entity = entity_observation['entity_id'] | ||
|
|
||
| should_trigger = condition.state(self.hass, entity, | ||
| entity_observation.get('to_state')) | ||
|
|
||
| self._update_current_obs(entity_observation, should_trigger) | ||
|
|
||
| @property | ||
| def name(self): | ||
| """Return the name of the sensor.""" | ||
| return self._name | ||
|
|
||
| @property | ||
| def is_on(self): | ||
| """Return true if sensor is on.""" | ||
| return self._deviation | ||
|
|
||
| @property | ||
| def should_poll(self): | ||
| """No polling needed.""" | ||
| return False | ||
|
|
||
| @property | ||
| def device_class(self): | ||
| """Return the sensor class of the sensor.""" | ||
| return self._device_class | ||
|
|
||
| @property | ||
| def device_state_attributes(self): | ||
| """Return the state attributes of the sensor.""" | ||
| return { | ||
| 'observations': [val for val in self.current_obs.values()], | ||
| 'probability': self.probability, | ||
| 'probability_threshold': self._probability_threshold | ||
| } | ||
|
|
||
| @asyncio.coroutine | ||
| def async_update(self): | ||
| """Get the latest data and update the states.""" | ||
| self._deviation = bool(self.probability > self._probability_threshold) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,176 @@ | ||
| """The test for the bayesian sensor platform.""" | ||
| import unittest | ||
|
|
||
| from homeassistant.setup import setup_component | ||
| from homeassistant.components.binary_sensor import bayesian | ||
|
|
||
| from tests.common import get_test_home_assistant | ||
|
|
||
|
|
||
| class TestBayesianBinarySensor(unittest.TestCase): | ||
| """Test the threshold sensor.""" | ||
|
|
||
| def setup_method(self, method): | ||
| """Set up things to be run when tests are started.""" | ||
| self.hass = get_test_home_assistant() | ||
|
|
||
| def teardown_method(self, method): | ||
| """Stop everything that was started.""" | ||
| self.hass.stop() | ||
|
|
||
| def test_sensor_numeric_state(self): | ||
| """Test sensor on numeric state platform observations.""" | ||
| config = { | ||
| 'binary_sensor': { | ||
| 'platform': | ||
| 'bayesian', | ||
| 'name': | ||
| 'Test_Binary', | ||
| 'observations': [{ | ||
| 'platform': 'numeric_state', | ||
| 'entity_id': 'sensor.test_monitored', | ||
| 'below': 10, | ||
| 'above': 5, | ||
| 'prob_given_true': 0.6 | ||
| }, { | ||
| 'platform': 'numeric_state', | ||
| 'entity_id': 'sensor.test_monitored1', | ||
| 'below': 7, | ||
| 'above': 5, | ||
| 'prob_given_true': 0.9, | ||
| 'prob_given_false': 0.1 | ||
| }], | ||
| 'prior': | ||
| 0.2, | ||
| } | ||
| } | ||
|
|
||
| assert setup_component(self.hass, 'binary_sensor', config) | ||
|
|
||
| self.hass.states.set('sensor.test_monitored', 4) | ||
| self.hass.block_till_done() | ||
|
|
||
| state = self.hass.states.get('binary_sensor.test_binary') | ||
|
|
||
| self.assertEqual([], state.attributes.get('observations')) | ||
| self.assertEqual(0.2, state.attributes.get('probability')) | ||
|
|
||
| assert state.state == 'off' | ||
|
|
||
| self.hass.states.set('sensor.test_monitored', 6) | ||
| self.hass.block_till_done() | ||
| self.hass.states.set('sensor.test_monitored', 4) | ||
| self.hass.block_till_done() | ||
| self.hass.states.set('sensor.test_monitored', 6) | ||
| self.hass.states.set('sensor.test_monitored1', 6) | ||
| self.hass.block_till_done() | ||
|
|
||
| state = self.hass.states.get('binary_sensor.test_binary') | ||
| self.assertEqual([{ | ||
| 'prob_false': 0.4, | ||
| 'prob_true': 0.6 | ||
| }, { | ||
| 'prob_false': 0.1, | ||
| 'prob_true': 0.9 | ||
| }], state.attributes.get('observations')) | ||
| self.assertAlmostEqual(0.7714285714285715, | ||
| state.attributes.get('probability')) | ||
|
|
||
| assert state.state == 'on' | ||
|
|
||
| self.hass.states.set('sensor.test_monitored', 6) | ||
| self.hass.states.set('sensor.test_monitored1', 0) | ||
| self.hass.block_till_done() | ||
| self.hass.states.set('sensor.test_monitored', 4) | ||
| self.hass.block_till_done() | ||
|
|
||
| state = self.hass.states.get('binary_sensor.test_binary') | ||
| self.assertEqual(0.2, state.attributes.get('probability')) | ||
|
|
||
| assert state.state == 'off' | ||
|
|
||
| self.hass.states.set('sensor.test_monitored', 15) | ||
| self.hass.block_till_done() | ||
|
|
||
| state = self.hass.states.get('binary_sensor.test_binary') | ||
|
|
||
| assert state.state == 'off' | ||
|
|
||
| def test_sensor_state(self): | ||
| """Test sensor on state platform observations.""" | ||
| config = { | ||
| 'binary_sensor': { | ||
| 'name': | ||
| 'Test_Binary', | ||
| 'platform': | ||
| 'bayesian', | ||
| 'observations': [{ | ||
| 'platform': 'state', | ||
| 'entity_id': 'sensor.test_monitored', | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. continuation line under-indented for visual indent |
||
| 'to_state': 'off', | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. continuation line under-indented for visual indent |
||
| 'prob_given_true': 0.8, | ||
| 'prob_given_false': 0.4 | ||
| }], | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. closing bracket does not match visual indentation |
||
| 'prior': | ||
| 0.2, | ||
| 'probability_threshold': | ||
| 0.32, | ||
| } | ||
| } | ||
|
|
||
| assert setup_component(self.hass, 'binary_sensor', config) | ||
|
|
||
| self.hass.states.set('sensor.test_monitored', 'on') | ||
|
|
||
| state = self.hass.states.get('binary_sensor.test_binary') | ||
|
|
||
| self.assertEqual([], state.attributes.get('observations')) | ||
| self.assertEqual(0.2, state.attributes.get('probability')) | ||
|
|
||
| assert state.state == 'off' | ||
|
|
||
| self.hass.states.set('sensor.test_monitored', 'off') | ||
| self.hass.block_till_done() | ||
| self.hass.states.set('sensor.test_monitored', 'on') | ||
| self.hass.block_till_done() | ||
| self.hass.states.set('sensor.test_monitored', 'off') | ||
| self.hass.block_till_done() | ||
|
|
||
| state = self.hass.states.get('binary_sensor.test_binary') | ||
| self.assertEqual([{ | ||
| 'prob_true': 0.8, | ||
| 'prob_false': 0.4 | ||
| }], state.attributes.get('observations')) | ||
| self.assertAlmostEqual(0.33333333, state.attributes.get('probability')) | ||
|
|
||
| assert state.state == 'on' | ||
|
|
||
| self.hass.states.set('sensor.test_monitored', 'off') | ||
| self.hass.block_till_done() | ||
| self.hass.states.set('sensor.test_monitored', 'on') | ||
| self.hass.block_till_done() | ||
|
|
||
| state = self.hass.states.get('binary_sensor.test_binary') | ||
| self.assertAlmostEqual(0.2, state.attributes.get('probability')) | ||
|
|
||
| assert state.state == 'off' | ||
|
|
||
| def test_probability_updates(self): | ||
| """Test probability update function.""" | ||
| prob_true = [0.3, 0.6, 0.8] | ||
| prob_false = [0.7, 0.4, 0.2] | ||
| prior = 0.5 | ||
|
|
||
| for pt, pf in zip(prob_true, prob_false): | ||
| prior = bayesian.update_probability(prior, pt, pf) | ||
|
|
||
| self.assertAlmostEqual(0.720000, prior) | ||
|
|
||
| prob_true = [0.8, 0.3, 0.9] | ||
| prob_false = [0.6, 0.4, 0.2] | ||
| prior = 0.7 | ||
|
|
||
| for pt, pf in zip(prob_true, prob_false): | ||
| prior = bayesian.update_probability(prior, pt, pf) | ||
|
|
||
| self.assertAlmostEqual(0.9130434782608695, prior) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
'homeassistant.const.CONF_STATE' imported but unused