Skip to content

Commit

Permalink
pass flag instead of determining end-to-end utterance on the fly.
Browse files Browse the repository at this point in the history
  • Loading branch information
wochinge committed Nov 23, 2020
1 parent 9c59a56 commit 868a715
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 23 deletions.
38 changes: 28 additions & 10 deletions rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import aiohttp

import rasa.core
from rasa.core.policies.policy import PolicyPrediction

from rasa.shared.core import events
from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
Expand Down Expand Up @@ -78,14 +79,15 @@ def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["A
]


def action_for_index(
index: int, domain: Domain, action_endpoint: Optional[EndpointConfig]
def action_for_prediction(
prediction: PolicyPrediction,
domain: Domain,
action_endpoint: Optional[EndpointConfig],
) -> "Action":
"""Get an action based on its index in the list of available actions.
"""Gets an instantiated `Action` based on the winning policy's prediction.
Args:
index: The index of the action. This is usually used by `Policy`s as they
predict the action index instead of the name.
prediction: The prediction for the next action.
domain: The `Domain` of the current model. The domain contains the actions
provided by the user + the default actions.
action_endpoint: Can be used to run `custom_actions`
Expand All @@ -95,17 +97,26 @@ def action_for_index(
The instantiated `Action` or `None` if no `Action` was found for the given
index.
"""
index = prediction.max_confidence_index
if domain.num_actions <= index or index < 0:
raise IndexError(
f"Cannot access action at index {index}. "
f"Domain has {domain.num_actions} actions."
)

return action_for_name(domain.action_names[index], domain, action_endpoint)
return action_for_name(
domain.action_names[index],
domain,
action_endpoint,
prediction.is_end_to_end_prediction,
)


def action_for_name(
action_name: Text, domain: Domain, action_endpoint: Optional[EndpointConfig]
action_name: Text,
domain: Domain,
action_endpoint: Optional[EndpointConfig],
is_end_to_end_prediction: bool = False,
) -> "Action":
"""Create an `Action` object based on the name of the `Action`.
Expand All @@ -115,6 +126,7 @@ def action_for_name(
provided by the user + the default actions.
action_endpoint: Can be used to run `custom_actions`
(e.g. using the `rasa-sdk`).
is_end_to_end_prediction: `True` if it is an end-to-end prediction.
Returns:
The instantiated `Action` or `None` if no `Action` was found for the given
Expand All @@ -124,7 +136,9 @@ def action_for_name(
if action_name not in domain.action_names:
domain.raise_action_not_found_exception(action_name)

return action_from_name(action_name, domain, action_endpoint)
return action_from_name(
action_name, domain, action_endpoint, is_end_to_end_prediction
)


def is_retrieval_action(action_name: Text, retrieval_intents: List[Text]) -> bool:
Expand All @@ -148,14 +162,18 @@ def is_retrieval_action(action_name: Text, retrieval_intents: List[Text]) -> boo


def action_from_name(
name: Text, domain: Domain, action_endpoint: Optional[EndpointConfig]
name: Text,
domain: Domain,
action_endpoint: Optional[EndpointConfig],
is_end_to_end_prediction: bool = False,
) -> "Action":
"""Retrieves an action by its name.
Args:
name: The name of the action.
domain: The current model domain.
action_endpoint: The endpoint to execute custom actions.
is_end_to_end_prediction: `True` if it is an end-to-end prediction.
Returns:
The instantiated action.
Expand All @@ -170,7 +188,7 @@ def action_from_name(
):
return ActionRetrieveResponse(name)

if name.startswith(UTTER_PREFIX) or name in domain.action_texts:
if name.startswith(UTTER_PREFIX) or is_end_to_end_prediction:
return ActionUtterTemplate(name)

is_form = name in domain.form_names
Expand Down
4 changes: 2 additions & 2 deletions rasa/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,8 @@ def predict_next_action(
"""
prediction = self._get_next_action_probabilities(tracker)

action = rasa.core.actions.action.action_for_index(
prediction.max_confidence_index, self.domain, self.action_endpoint
action = rasa.core.actions.action.action_for_prediction(
prediction, self.domain, self.action_endpoint
)

logger.debug(
Expand Down
16 changes: 8 additions & 8 deletions tests/core/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ def test_fallback_mapping_restart():
tracker, domain, RegexInterpreter()
)
index_of_mapping_policy = 1
next_action = rasa.core.actions.action.action_for_index(
prediction.max_confidence_index, domain, None
next_action = rasa.core.actions.action.action_for_prediction(
prediction, domain, None
)

assert (
Expand Down Expand Up @@ -222,8 +222,8 @@ def test_mapping_wins_over_form(events: List[Event]):
tracker, domain, RegexInterpreter()
)

next_action = rasa.core.actions.action.action_for_index(
prediction.max_confidence_index, domain, None
next_action = rasa.core.actions.action.action_for_prediction(
prediction, domain, None
)

index_of_mapping_policy = 0
Expand Down Expand Up @@ -265,8 +265,8 @@ def test_form_wins_over_everything_else(ensemble: SimplePolicyEnsemble):
tracker, domain, RegexInterpreter()
)

next_action = rasa.core.actions.action.action_for_index(
prediction.max_confidence_index, domain, None
next_action = rasa.core.actions.action.action_for_prediction(
prediction, domain, None
)

index_of_form_policy = 0
Expand All @@ -291,8 +291,8 @@ def test_fallback_wins_over_mapping():
tracker, domain, RegexInterpreter()
)
index_of_fallback_policy = 0
next_action = rasa.core.actions.action.action_for_index(
prediction.max_confidence_index, domain, None
next_action = rasa.core.actions.action.action_for_prediction(
prediction, domain, None
)

assert (
Expand Down
4 changes: 1 addition & 3 deletions tests/core/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,7 @@ def test_missing_classes_filled_correctly(
new_tracker = DialogueStateTracker(DEFAULT_SENDER_ID, default_domain.slots)
for e in tr.applied_events():
if isinstance(e, ActionExecuted):
new_action = rasa.core.actions.action.action_for_index(
np.random.choice(classes), default_domain, action_endpoint=None
).name()
new_action = default_domain.action_names[np.random.choice(classes)]
new_tracker.update(ActionExecuted(new_action))
else:
new_tracker.update(e)
Expand Down

0 comments on commit 868a715

Please sign in to comment.