Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions homeassistant/components/automation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import voluptuous as vol

from homeassistant.setup import async_prepare_setup_platform
from homeassistant.core import CoreState
from homeassistant.core import CoreState, Context
from homeassistant.loader import bind_hass
from homeassistant.const import (
ATTR_ENTITY_ID, CONF_PLATFORM, STATE_ON, SERVICE_TURN_ON, SERVICE_TURN_OFF,
Expand Down Expand Up @@ -280,15 +280,21 @@ async def async_trigger(self, variables, skip_condition=False,

This method is a coroutine.
"""
if skip_condition or self._cond_func(variables):
self.async_set_context(context)
self.hass.bus.async_fire(EVENT_AUTOMATION_TRIGGERED, {
ATTR_NAME: self._name,
ATTR_ENTITY_ID: self.entity_id,
}, context=context)
await self._async_action(self.entity_id, variables, context)
self._last_triggered = utcnow()
await self.async_update_ha_state()
if not skip_condition and not self._cond_func(variables):
return

# Create a new context referring to the old context.
parent_id = None if context is None else context.id
trigger_context = Context(parent_id=parent_id)

self.async_set_context(trigger_context)
self.hass.bus.async_fire(EVENT_AUTOMATION_TRIGGERED, {
ATTR_NAME: self._name,
ATTR_ENTITY_ID: self.entity_id,
}, context=trigger_context)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall use context=context here?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we want the automation triggered to be with the new context, that way we will be able to attach automation metadata to the context.

await self._async_action(self.entity_id, variables, trigger_context)
self._last_triggered = utcnow()
await self.async_update_ha_state()

async def async_will_remove_from_hass(self):
"""Remove listeners when removing automation from HASS."""
Expand Down
9 changes: 9 additions & 0 deletions homeassistant/components/recorder/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,15 @@ def _apply_update(engine, new_version, old_version):
_create_index(engine, "states", "ix_states_context_user_id")
elif new_version == 7:
_create_index(engine, "states", "ix_states_entity_id")
elif new_version == 8:
# Pending migration, want to group a few.
pass
# _add_columns(engine, "events", [
# 'context_parent_id CHARACTER(36)',
# ])
# _add_columns(engine, "states", [
# 'context_parent_id CHARACTER(36)',
# ])
else:
raise ValueError("No schema migration defined for version {}"
.format(new_version))
Expand Down
18 changes: 12 additions & 6 deletions homeassistant/components/recorder/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,20 @@ class Events(Base): # type: ignore
created = Column(DateTime(timezone=True), default=datetime.utcnow)
context_id = Column(String(36), index=True)
context_user_id = Column(String(36), index=True)
# context_parent_id = Column(String(36), index=True)

@staticmethod
def from_event(event):
"""Create an event database object from a native event."""
return Events(event_type=event.event_type,
event_data=json.dumps(event.data, cls=JSONEncoder),
origin=str(event.origin),
time_fired=event.time_fired,
context_id=event.context.id,
context_user_id=event.context.user_id)
return Events(
event_type=event.event_type,
event_data=json.dumps(event.data, cls=JSONEncoder),
origin=str(event.origin),
time_fired=event.time_fired,
context_id=event.context.id,
context_user_id=event.context.user_id,
# context_parent_id=event.context.parent_id,
)

def to_native(self):
"""Convert to a natve HA Event."""
Expand Down Expand Up @@ -81,6 +85,7 @@ class States(Base): # type: ignore
created = Column(DateTime(timezone=True), default=datetime.utcnow)
context_id = Column(String(36), index=True)
context_user_id = Column(String(36), index=True)
# context_parent_id = Column(String(36), index=True)

__table_args__ = (
# Used for fetching the state of entities at a specific time
Expand All @@ -99,6 +104,7 @@ def from_event(event):
entity_id=entity_id,
context_id=event.context.id,
context_user_id=event.context.user_id,
# context_parent_id=event.context.parent_id,
)

# State got deleted
Expand Down
5 changes: 5 additions & 0 deletions homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,10 @@ class Context:
type=str,
default=None,
)
parent_id = attr.ib(
type=Optional[str],
default=None
)
id = attr.ib(
type=str,
default=attr.Factory(lambda: uuid.uuid4().hex),
Expand All @@ -418,6 +422,7 @@ def as_dict(self) -> dict:
"""Return a dictionary representation of the context."""
return {
'id': self.id,
'parent_id': self.parent_id,
'user_id': self.user_id,
}

Expand Down
2 changes: 1 addition & 1 deletion tests/components/automation/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def test_if_fires_on_event(hass, calls):
hass.bus.async_fire('test_event', context=context)
await hass.async_block_till_done()
assert 1 == len(calls)
assert calls[0].context is context
assert calls[0].context.parent_id == context.id

await common.async_turn_off(hass)
await hass.async_block_till_done()
Expand Down
4 changes: 2 additions & 2 deletions tests/components/automation/test_geo_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async def test_if_fires_on_zone_enter(hass, calls):
await hass.async_block_till_done()

assert 1 == len(calls)
assert calls[0].context is context
assert calls[0].context.parent_id == context.id
assert 'geo_location - geo_location.entity - hello - hello - test' == \
calls[0].data['some']

Expand Down Expand Up @@ -221,7 +221,7 @@ async def test_if_fires_on_zone_appear(hass, calls):
await hass.async_block_till_done()

assert 1 == len(calls)
assert calls[0].context is context
assert calls[0].context.parent_id == context.id
assert 'geo_location - geo_location.entity - - hello - test' == \
calls[0].data['some']

Expand Down
39 changes: 24 additions & 15 deletions tests/components/automation/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,38 +369,47 @@ async def test_shared_context(hass, calls):
})

context = Context()
automation_mock = Mock()
first_automation_listener = Mock()
event_mock = Mock()

hass.bus.async_listen('test_event2', automation_mock)
hass.bus.async_listen('test_event2', first_automation_listener)
hass.bus.async_listen(EVENT_AUTOMATION_TRIGGERED, event_mock)
hass.bus.async_fire('test_event', context=context)
await hass.async_block_till_done()

# Ensure events was fired
assert automation_mock.call_count == 1
assert first_automation_listener.call_count == 1
assert event_mock.call_count == 2

# Ensure context carries through the event
args, kwargs = automation_mock.call_args
assert args[0].context == context
# Verify automation triggered evenet for 'hello' automation
args, kwargs = event_mock.call_args_list[0]
first_trigger_context = args[0].context
assert first_trigger_context.parent_id == context.id
# Ensure event data has all attributes set
assert args[0].data.get(ATTR_NAME) is not None
assert args[0].data.get(ATTR_ENTITY_ID) is not None

for call in event_mock.call_args_list:
args, kwargs = call
assert args[0].context == context
# Ensure event data has all attributes set
assert args[0].data.get(ATTR_NAME) is not None
assert args[0].data.get(ATTR_ENTITY_ID) is not None
# Ensure context set correctly for event fired by 'hello' automation
args, kwargs = first_automation_listener.call_args
assert args[0].context is first_trigger_context

# Ensure the automation state shares the same context
# Ensure the 'hello' automation state has the right context
state = hass.states.get('automation.hello')
assert state is not None
assert state.context == context
assert state.context is first_trigger_context

# Verify automation triggered evenet for 'bye' automation
args, kwargs = event_mock.call_args_list[1]
second_trigger_context = args[0].context
assert second_trigger_context.parent_id == first_trigger_context.id
# Ensure event data has all attributes set
assert args[0].data.get(ATTR_NAME) is not None
assert args[0].data.get(ATTR_ENTITY_ID) is not None

# Ensure the service call from the second automation
# shares the same context
assert len(calls) == 1
assert calls[0].context == context
assert calls[0].context is second_trigger_context


async def test_services(hass, calls):
Expand Down
4 changes: 2 additions & 2 deletions tests/components/automation/test_numeric_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def test_if_fires_on_entity_change_below(hass, calls):
hass.states.async_set('test.entity', 9, context=context)
await hass.async_block_till_done()
assert 1 == len(calls)
assert calls[0].context is context
assert calls[0].context.parent_id == context.id

# Set above 12 so the automation will fire again
hass.states.async_set('test.entity', 12)
Expand Down Expand Up @@ -134,7 +134,7 @@ async def test_if_not_fires_on_entity_change_below_to_below(hass, calls):
hass.states.async_set('test.entity', 9, context=context)
await hass.async_block_till_done()
assert 1 == len(calls)
assert calls[0].context is context
assert calls[0].context.parent_id == context.id

# already below so should not fire again
hass.states.async_set('test.entity', 5)
Expand Down
2 changes: 1 addition & 1 deletion tests/components/automation/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ async def test_if_fires_on_entity_change(hass, calls):
hass.states.async_set('test.entity', 'world', context=context)
await hass.async_block_till_done()
assert 1 == len(calls)
assert calls[0].context is context
assert calls[0].context.parent_id == context.id
assert 'state - test.entity - hello - world - None' == \
calls[0].data['some']

Expand Down
2 changes: 1 addition & 1 deletion tests/components/automation/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ async def test_if_fires_on_change_with_template_advanced(hass, calls):
hass.states.async_set('test.entity', 'world', context=context)
await hass.async_block_till_done()
assert 1 == len(calls)
assert calls[0].context is context
assert calls[0].context.parent_id == context.id
assert 'template - test.entity - hello - world' == \
calls[0].data['some']

Expand Down
2 changes: 1 addition & 1 deletion tests/components/automation/test_zone.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def test_if_fires_on_zone_enter(hass, calls):
await hass.async_block_till_done()

assert 1 == len(calls)
assert calls[0].context is context
assert calls[0].context.parent_id == context.id
assert 'zone - test.entity - hello - hello - test' == \
calls[0].data['some']

Expand Down
14 changes: 14 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def test_as_dict(self):
'time_fired': now,
'context': {
'id': event.context.id,
'parent_id': None,
'user_id': event.context.user_id,
},
}
Expand Down Expand Up @@ -1061,3 +1062,16 @@ def callback(event):
assert len(calls) == 1
assert calls[0].data['number'] == 23
assert calls[0].context is context


def test_context():
"""Test context init."""
c = ha.Context()
assert c.user_id is None
assert c.parent_id is None
assert c.id is not None

c = ha.Context(23, 100)
assert c.user_id == 23
assert c.parent_id == 100
assert c.id is not None