Skip to content

Commit

Permalink
Merge branch '2.0.x' of github.com:RasaHQ/rasa into 2.0.x
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Oct 7, 2020
2 parents 16d0b9e + 18668ff commit 3326211
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 6 deletions.
16 changes: 10 additions & 6 deletions rasa/core/training/training.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Text, List, TYPE_CHECKING, Dict, Set
from collections import defaultdict

from rasa.shared.core.events import ActionExecuted
from rasa.shared.core.events import ActionExecuted, UserUttered
from rasa.shared.core.events import SlotSet, ActiveLoop
from rasa.shared.core.constants import SLOTS, ACTIVE_LOOP

Expand All @@ -24,14 +24,18 @@ def _find_events_after_actions(
"""
events_after_actions = defaultdict(set)

for t in trackers:
tracker = t.init_copy()
for event in t.events:
tracker.update(event)
for tracker in trackers:
action_name = None
for event in tracker.events:
if isinstance(event, ActionExecuted):
action_name = event.action_name or event.action_text
continue
if isinstance(event, UserUttered):
# UserUttered can contain entities that might set some slots, reset
# action_name so that these slots are not attributed to action_listen
action_name = None
continue

action_name = tracker.latest_action_name
if action_name:
events_after_actions[action_name].add(event)

Expand Down
52 changes: 52 additions & 0 deletions tests/core/policies/test_rule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,58 @@ def test_incomplete_rules_due_to_slots():
policy.train([complete_rule, fixed_incomplete_rule], domain, RegexInterpreter())


def test_no_incomplete_rules_due_to_slots_after_listen():
some_action = "some_action"
some_slot = "some_slot"
domain = Domain.from_yaml(
f"""
intents:
- {GREET_INTENT_NAME}
actions:
- {some_action}
entities:
- {some_slot}
slots:
{some_slot}:
type: text
"""
)
policy = RulePolicy()
complete_rule = TrackerWithCachedStates.from_events(
"complete_rule",
domain=domain,
slots=domain.slots,
evts=[
ActionExecuted(RULE_SNIPPET_ACTION_NAME),
ActionExecuted(ACTION_LISTEN_NAME),
UserUttered(
intent={"name": GREET_INTENT_NAME},
entities=[{"entity": some_slot, "value": "bla"}],
),
SlotSet(some_slot, "bla"),
ActionExecuted(some_action),
ActionExecuted(ACTION_LISTEN_NAME),
],
is_rule_tracker=True,
)
potentially_incomplete_rule = TrackerWithCachedStates.from_events(
"potentially_incomplete_rule",
domain=domain,
slots=domain.slots,
evts=[
ActionExecuted(RULE_SNIPPET_ACTION_NAME),
ActionExecuted(ACTION_LISTEN_NAME),
UserUttered(intent={"name": GREET_INTENT_NAME}),
ActionExecuted(some_action),
ActionExecuted(ACTION_LISTEN_NAME),
],
is_rule_tracker=True,
)
policy.train(
[complete_rule, potentially_incomplete_rule], domain, RegexInterpreter()
)


def test_incomplete_rules_due_to_loops():
some_form = "some_form"
domain = Domain.from_yaml(
Expand Down

0 comments on commit 3326211

Please sign in to comment.