Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ambiguity_threshold param to FallbackClassifier #6355

Merged
merged 4 commits into from
Aug 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions rasa/core/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
OPEN_UTTERANCE_PREDICTION_KEY,
RESPONSE_SELECTOR_PROPERTY_NAME,
INTENT_RANKING_KEY,
INTENT_NAME_KEY,
)

from rasa.core.events import (
Expand Down Expand Up @@ -722,14 +723,14 @@ async def run(
tracker: "DialogueStateTracker",
domain: "Domain",
) -> List[Event]:
intent_to_affirm = tracker.latest_message.intent.get("name")
intent_to_affirm = tracker.latest_message.intent.get(INTENT_NAME_KEY)

intent_ranking = tracker.latest_message.intent.get(INTENT_RANKING_KEY, [])
if (
intent_to_affirm == DEFAULT_NLU_FALLBACK_INTENT_NAME
and len(intent_ranking) > 1
):
intent_to_affirm = intent_ranking[1]["name"]
intent_to_affirm = intent_ranking[1][INTENT_NAME_KEY]

affirmation_message = f"Did you mean '{intent_to_affirm}'?"

Expand Down
11 changes: 6 additions & 5 deletions rasa/core/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ruamel.yaml import YAMLError

import rasa.core.constants
from rasa.nlu.constants import INTENT_NAME_KEY
from rasa.utils.common import (
raise_warning,
lazy_property,
Expand Down Expand Up @@ -675,7 +676,7 @@ def get_parsing_states(self, tracker: "DialogueStateTracker") -> Dict[Text, floa
if not latest_message:
return state_dict

intent_name = latest_message.intent.get("name")
intent_name = latest_message.intent.get(INTENT_NAME_KEY)

if intent_name:
for entity_name in self._get_featurized_entities(latest_message):
Expand All @@ -699,18 +700,18 @@ def get_parsing_states(self, tracker: "DialogueStateTracker") -> Dict[Text, floa

if "intent_ranking" in latest_message.parse_data:
for intent in latest_message.parse_data["intent_ranking"]:
if intent.get("name"):
intent_id = "intent_{}".format(intent["name"])
if intent.get(INTENT_NAME_KEY):
intent_id = "intent_{}".format(intent[INTENT_NAME_KEY])
state_dict[intent_id] = intent["confidence"]

elif intent_name:
intent_id = "intent_{}".format(latest_message.intent["name"])
intent_id = "intent_{}".format(latest_message.intent[INTENT_NAME_KEY])
state_dict[intent_id] = latest_message.intent.get("confidence", 1.0)

return state_dict

def _get_featurized_entities(self, latest_message: UserUttered) -> Set[Text]:
intent_name = latest_message.intent.get("name")
intent_name = latest_message.intent.get(INTENT_NAME_KEY)
intent_config = self.intent_config(intent_name)
entities = latest_message.entities
entity_names = {
Expand Down
17 changes: 11 additions & 6 deletions rasa/core/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
EXTERNAL_MESSAGE_PREFIX,
ACTION_NAME_SENDER_ID_CONNECTOR_STR,
)
from rasa.nlu.constants import INTENT_NAME_KEY

if typing.TYPE_CHECKING:
from rasa.core.trackers import DialogueStateTracker
Expand Down Expand Up @@ -258,7 +259,11 @@ def _from_parse_data(

def __hash__(self) -> int:
return hash(
(self.text, self.intent.get("name"), jsonpickle.encode(self.entities))
(
self.text,
self.intent.get(INTENT_NAME_KEY),
jsonpickle.encode(self.entities),
)
)

def __eq__(self, other) -> bool:
Expand All @@ -267,11 +272,11 @@ def __eq__(self, other) -> bool:
else:
return (
self.text,
self.intent.get("name"),
self.intent.get(INTENT_NAME_KEY),
[jsonpickle.encode(ent) for ent in self.entities],
) == (
other.text,
other.intent.get("name"),
other.intent.get(INTENT_NAME_KEY),
[jsonpickle.encode(ent) for ent in other.entities],
)

Expand Down Expand Up @@ -324,11 +329,11 @@ def as_story_string(self, e2e: bool = False) -> Text:
ent_string = ""

parse_string = "{intent}{entities}".format(
intent=self.intent.get("name", ""), entities=ent_string
intent=self.intent.get(INTENT_NAME_KEY, ""), entities=ent_string
)
if e2e:
message = md_format_message(self.text, self.intent, self.entities)
return "{}: {}".format(self.intent.get("name"), message)
return "{}: {}".format(self.intent.get(INTENT_NAME_KEY), message)
else:
return parse_string
else:
Expand All @@ -344,7 +349,7 @@ def create_external(
) -> "UserUttered":
return UserUttered(
text=f"{EXTERNAL_MESSAGE_PREFIX}{intent_name}",
intent={"name": intent_name},
intent={INTENT_NAME_KEY: intent_name},
metadata={IS_EXTERNAL: True},
entities=entity_list or [],
)
Expand Down
7 changes: 4 additions & 3 deletions rasa/core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from rasa.core import constants
from rasa.core.trackers import DialogueStateTracker
from rasa.core.constants import INTENT_MESSAGE_PREFIX
from rasa.nlu.constants import INTENT_NAME_KEY
from rasa.utils.common import raise_warning, class_from_module_path
from rasa.utils.endpoints import EndpointConfig

Expand Down Expand Up @@ -171,8 +172,8 @@ def synchronous_parse(

return {
"text": message_text,
"intent": {"name": intent, "confidence": confidence},
"intent_ranking": [{"name": intent, "confidence": confidence}],
"intent": {INTENT_NAME_KEY: intent, "confidence": confidence},
"intent_ranking": [{INTENT_NAME_KEY: intent, "confidence": confidence}],
"entities": entities,
}

Expand All @@ -195,7 +196,7 @@ async def parse(
Return a default value if the parsing of the text failed."""

default_return = {
"intent": {"name": "", "confidence": 0.0},
"intent": {INTENT_NAME_KEY: "", "confidence": 0.0},
"entities": [],
"text": "",
}
Expand Down
3 changes: 2 additions & 1 deletion rasa/core/policies/mapping_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from rasa.constants import DOCS_URL_POLICIES, DOCS_URL_MIGRATION_GUIDE
import rasa.utils.io
from rasa.nlu.constants import INTENT_NAME_KEY
from rasa.utils import common as common_utils

from rasa.core.actions.action import (
Expand Down Expand Up @@ -108,7 +109,7 @@ def predict_action_probabilities(

result = self._default_predictions(domain)

intent = tracker.latest_message.intent.get("name")
intent = tracker.latest_message.intent.get(INTENT_NAME_KEY)
if intent == USER_INTENT_RESTART:
action = ACTION_RESTART_NAME
elif intent == USER_INTENT_BACK:
Expand Down
3 changes: 2 additions & 1 deletion rasa/core/policies/two_stage_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from rasa.core.policies.policy import confidence_scores_for
from rasa.core.trackers import DialogueStateTracker
from rasa.core.constants import FALLBACK_POLICY_PRIORITY
from rasa.nlu.constants import INTENT_NAME_KEY

if typing.TYPE_CHECKING:
from rasa.core.policies.ensemble import PolicyEnsemble
Expand Down Expand Up @@ -121,7 +122,7 @@ def predict_action_probabilities(
"""Predicts the next action if NLU confidence is low."""

nlu_data = tracker.latest_message.parse_data
last_intent_name = nlu_data["intent"].get("name", None)
last_intent_name = nlu_data["intent"].get(INTENT_NAME_KEY, None)
should_nlu_fallback = self.should_nlu_fallback(
nlu_data, tracker.latest_action_name
)
Expand Down
5 changes: 3 additions & 2 deletions rasa/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from rasa.core.policies.ensemble import PolicyEnsemble
from rasa.core.tracker_store import TrackerStore
from rasa.core.trackers import DialogueStateTracker, EventVerbosity
from rasa.nlu.constants import INTENT_NAME_KEY
from rasa.utils.common import raise_warning
from rasa.utils.endpoints import EndpointConfig

Expand Down Expand Up @@ -436,7 +437,7 @@ def _check_for_unseen_features(self, parse_data: Dict[Text, Any]) -> None:
if not self.domain or self.domain.is_empty():
return

intent = parse_data["intent"]["name"]
intent = parse_data["intent"][INTENT_NAME_KEY]
if intent:
known_intents = self.domain.intents + DEFAULT_INTENTS
if intent not in known_intents:
Expand Down Expand Up @@ -520,7 +521,7 @@ async def _handle_message_with_tracker(
def _should_handle_message(tracker: DialogueStateTracker):
return (
not tracker.is_paused()
or tracker.latest_message.intent.get("name") == USER_INTENT_RESTART
or tracker.latest_message.intent.get(INTENT_NAME_KEY) == USER_INTENT_RESTART
)

def is_action_limit_reached(
Expand Down
5 changes: 4 additions & 1 deletion rasa/core/tracker_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from rasa.core.events import SessionStarted
from rasa.core.trackers import ActionExecuted, DialogueStateTracker, EventVerbosity
import rasa.cli.utils as rasa_cli_utils
from rasa.nlu.constants import INTENT_NAME_KEY
from rasa.utils.common import class_from_module_path, raise_warning, arguments_of
from rasa.utils.endpoints import EndpointConfig
import sqlalchemy as sa
Expand Down Expand Up @@ -910,7 +911,9 @@ def save(self, tracker: DialogueStateTracker) -> None:

for event in events:
data = event.as_dict()
intent = data.get("parse_data", {}).get("intent", {}).get("name")
intent = (
data.get("parse_data", {}).get("intent", {}).get(INTENT_NAME_KEY)
)
action = data.get("name")
timestamp = data.get("timestamp")

Expand Down
35 changes: 23 additions & 12 deletions rasa/core/training/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from aiohttp import ClientError
from colorclass import Color

from rasa.nlu.constants import INTENT_NAME_KEY
from rasa.nlu.training_data.loading import MARKDOWN, RASA
from sanic import Sanic, response
from sanic.exceptions import NotFound
Expand Down Expand Up @@ -322,12 +323,19 @@ def _selection_choices_from_intent_prediction(
) -> List[Dict[Text, Any]]:
""""Given a list of ML predictions create a UI choice list."""

sorted_intents = sorted(predictions, key=lambda k: (-k["confidence"], k["name"]))
sorted_intents = sorted(
predictions, key=lambda k: (-k["confidence"], k[INTENT_NAME_KEY])
)

choices = []
for p in sorted_intents:
name_with_confidence = f'{p.get("confidence"):03.2f} {p.get("name"):40}'
choice = {"name": name_with_confidence, "value": p.get("name")}
name_with_confidence = (
f'{p.get("confidence"):03.2f} {p.get(INTENT_NAME_KEY):40}'
)
choice = {
INTENT_NAME_KEY: name_with_confidence,
"value": p.get(INTENT_NAME_KEY),
}
choices.append(choice)

return choices
Expand Down Expand Up @@ -416,15 +424,15 @@ async def _request_intent_from_user(

predictions = latest_message.get("parse_data", {}).get("intent_ranking", [])

predicted_intents = {p["name"] for p in predictions}
predicted_intents = {p[INTENT_NAME_KEY] for p in predictions}

for i in intents:
if i not in predicted_intents:
predictions.append({"name": i, "confidence": 0.0})
predictions.append({INTENT_NAME_KEY: i, "confidence": 0.0})

# convert intents to ui list and add <other> as a free text alternative
choices = [
{"name": "<create_new_intent>", "value": OTHER_INTENT}
{INTENT_NAME_KEY: "<create_new_intent>", "value": OTHER_INTENT}
] + _selection_choices_from_intent_prediction(predictions)

intent_name = await _request_selection_from_intents(
Expand All @@ -433,11 +441,12 @@ async def _request_intent_from_user(

if intent_name == OTHER_INTENT:
intent_name = await _request_free_text_intent(conversation_id, endpoint)
selected_intent = {"name": intent_name, "confidence": 1.0}
selected_intent = {INTENT_NAME_KEY: intent_name, "confidence": 1.0}
else:
# returns the selected intent with the original probability value
selected_intent = next(
(x for x in predictions if x["name"] == intent_name), {"name": None}
(x for x in predictions if x[INTENT_NAME_KEY] == intent_name),
{INTENT_NAME_KEY: None},
)

return selected_intent
Expand Down Expand Up @@ -479,7 +488,7 @@ def colored(txt: Text, color: Text) -> Text:

def format_user_msg(user_event: UserUttered, max_width: int) -> Text:
intent = user_event.intent or {}
intent_name = intent.get("name", "")
intent_name = intent.get(INTENT_NAME_KEY, "")
_confidence = intent.get("confidence", 1.0)
_md = _as_md_message(user_event.parse_data)

Expand Down Expand Up @@ -745,7 +754,9 @@ def _collect_messages(events: List[Dict[Text, Any]]) -> List[Message]:
if event.get("event") == UserUttered.type_name:
data = event.get("parse_data", {})
rasa_nlu_training_data_utils.remove_untrainable_entities_from(data)
msg = Message.build(data["text"], data["intent"]["name"], data["entities"])
msg = Message.build(
data["text"], data["intent"][INTENT_NAME_KEY], data["entities"]
)
messages.append(msg)
elif event.get("event") == UserUtteranceReverted.type_name and messages:
messages.pop() # user corrected the nlu, remove incorrect example
Expand Down Expand Up @@ -1117,7 +1128,7 @@ def _validate_user_regex(latest_message: Dict[Text, Any], intents: List[Text]) -
`/greet`. Return `True` if the intent is a known one."""

parse_data = latest_message.get("parse_data", {})
intent = parse_data.get("intent", {}).get("name")
intent = parse_data.get("intent", {}).get(INTENT_NAME_KEY)

if intent in intents:
return True
Expand All @@ -1134,7 +1145,7 @@ async def _validate_user_text(

parse_data = latest_message.get("parse_data", {})
text = _as_md_message(parse_data)
intent = parse_data.get("intent", {}).get("name")
intent = parse_data.get("intent", {}).get(INTENT_NAME_KEY)
entities = parse_data.get("entities", [])
if entities:
message = (
Expand Down
3 changes: 2 additions & 1 deletion rasa/core/training/story_reader/markdown_story_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from rasa.core.training.story_reader.story_reader import StoryReader
from rasa.core.training.structures import StoryStep, FORM_PREFIX
from rasa.data import MARKDOWN_FILE_EXTENSION
from rasa.nlu.constants import INTENT_NAME_KEY
from rasa.utils.common import raise_warning

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -209,7 +210,7 @@ async def _parse_message(self, message: Text, line_num: int) -> UserUttered:
utterance = UserUttered(
message, parse_data.get("intent"), parse_data.get("entities"), parse_data
)
intent_name = utterance.intent.get("name")
intent_name = utterance.intent.get(INTENT_NAME_KEY)
if self.domain and intent_name not in self.domain.intents:
raise_warning(
f"Found unknown intent '{intent_name}' on line {line_num}. "
Expand Down
3 changes: 2 additions & 1 deletion rasa/core/training/story_reader/yaml_story_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from rasa.core.training.story_reader.story_reader import StoryReader
from rasa.core.training.structures import StoryStep
from rasa.data import YAML_FILE_EXTENSIONS
from rasa.nlu.constants import INTENT_NAME_KEY

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -249,7 +250,7 @@ def _parse_user_utterance(self, step: Dict[Text, Any]) -> None:
self.current_step_builder.add_user_messages([utterance])

def _validate_that_utterance_is_in_domain(self, utterance: UserUttered) -> None:
intent_name = utterance.intent.get("name")
intent_name = utterance.intent.get(INTENT_NAME_KEY)

if not self.domain:
logger.debug(
Expand Down
Loading