Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model_id to events in the processor. #9917

Merged
merged 9 commits into from
Oct 25, 2021
1 change: 1 addition & 0 deletions changelog/8914.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Every conversation event now includes in its metadata the ID of the model which loaded at the time it was created.
4 changes: 3 additions & 1 deletion rasa/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,11 @@ def get_tracker(self, conversation_id: Text) -> DialogueStateTracker:
"""
conversation_id = conversation_id or DEFAULT_SENDER_ID

return self.tracker_store.get_or_create_tracker(
tracker = self.tracker_store.get_or_create_tracker(
conversation_id, append_action_listen=False
)
tracker.model_id = self.model_metadata.model_id
return tracker

def get_trackers_for_all_conversation_sessions(
self, conversation_id: Text
Expand Down
7 changes: 7 additions & 0 deletions rasa/shared/core/trackers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
ACTION_TEXT,
ACTION_NAME,
ENTITIES,
METADATA_MODEL_ID,
)
from rasa.shared.core import events
from rasa.shared.core.constants import (
Expand Down Expand Up @@ -221,6 +222,9 @@ def __init__(
self._reset()
self.active_loop: "TrackerActiveLoop" = {}

# Optional model_id to add to all events.
self.model_id: Optional[Text] = None

###
# Public tracker interface
###
Expand Down Expand Up @@ -641,6 +645,9 @@ def update(self, event: Event, domain: Optional[Domain] = None) -> None:
if not isinstance(event, Event): # pragma: no cover
raise ValueError("event to log must be an instance of a subclass of Event.")

if self.model_id and METADATA_MODEL_ID not in event.metadata:
event.metadata = {**event.metadata, METADATA_MODEL_ID: self.model_id}

self.events.append(event)
event.apply_to(self)

Expand Down
1 change: 1 addition & 0 deletions rasa/shared/nlu/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
METADATA = "metadata"
METADATA_INTENT = "intent"
METADATA_EXAMPLE = "example"
METADATA_MODEL_ID = "model_id"
INTENT_RANKING_KEY = "intent_ranking"
PREDICTED_CONFIDENCE_KEY = "confidence"

Expand Down
13 changes: 12 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from typing import Text, List, Optional, Dict, Any
from unittest.mock import Mock

from rasa.shared.nlu.constants import METADATA_MODEL_ID
import rasa.shared.utils.io
from rasa import server
from rasa.core.agent import Agent, load_agent
Expand All @@ -37,7 +38,7 @@
from rasa.nlu.utils.spacy_utils import SpacyNLP, SpacyModel
from rasa.shared.constants import LATEST_TRAINING_DATA_FORMAT_VERSION
from rasa.shared.core.domain import SessionConfig, Domain
from rasa.shared.core.events import UserUttered
from rasa.shared.core.events import Event, UserUttered
from rasa.core.exporter import Exporter

import rasa.core.run
Expand Down Expand Up @@ -772,3 +773,13 @@ def enable_cache(cache_dir: Path):
@pytest.fixture()
def whitespace_tokenizer() -> WhitespaceTokenizer:
return WhitespaceTokenizer(WhitespaceTokenizer.get_default_config())


def with_model_ids(events: List[Event], model_id: Text) -> List[Event]:
return [with_model_id(event, model_id) for event in events]


def with_model_id(event: Event, model_id: Text) -> Event:
new_event = copy.deepcopy(event)
new_event.metadata[METADATA_MODEL_ID] = model_id
return new_event
100 changes: 62 additions & 38 deletions tests/core/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from rasa.shared.core.domain import Domain
from rasa.shared.constants import INTENT_MESSAGE_PREFIX
from rasa.utils.endpoints import EndpointConfig
from tests.conftest import with_model_ids
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optional: nit but ideally just import the module and use that to refer to the function (python code conventions)



def model_server_app(model_path: Text, model_hash: Text = "somehash") -> Sanic:
Expand Down Expand Up @@ -226,69 +227,92 @@ async def test_agent_load_on_invalid_model_path(model_path: Optional[Text]):


async def test_agent_handle_message_full_model(default_agent: Agent):
model_id = default_agent.model_id
sender_id = uuid.uuid4().hex
message = UserMessage("hello", sender_id=sender_id)
await default_agent.handle_message(message)
tracker = default_agent.tracker_store.get_or_create_tracker(sender_id)
expected_events = [
ActionExecuted(action_name="action_session_start"),
SessionStarted(),
ActionExecuted(action_name="action_listen"),
UserUttered(text="hello", intent={"name": "greet"}),
DefinePrevUserUtteredFeaturization(False),
ActionExecuted(action_name="utter_greet"),
BotUttered("hey there None!"),
ActionExecuted(action_name="action_listen"),
]
expected_events = with_model_ids(
[
ActionExecuted(action_name="action_session_start"),
SessionStarted(),
ActionExecuted(action_name="action_listen"),
UserUttered(text="hello", intent={"name": "greet"},),
DefinePrevUserUtteredFeaturization(False),
ActionExecuted(action_name="utter_greet"),
BotUttered(
"hey there None!",
{
"elements": None,
"quick_replies": None,
"buttons": None,
"attachment": None,
"image": None,
"custom": None,
},
{"utter_action": "utter_greet"},
),
ActionExecuted(action_name="action_listen"),
],
model_id,
)
assert len(tracker.events) == len(expected_events)
for e1, e2 in zip(tracker.events, expected_events):
assert e1 == e1
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops 😁

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's how all the tests passed 😁

assert e1 == e2


async def test_agent_handle_message_only_nlu(trained_nlu_model: Text):
agent = await load_agent(model_path=trained_nlu_model)
model_id = agent.model_id
sender_id = uuid.uuid4().hex
message = UserMessage("hello", sender_id=sender_id)
await agent.handle_message(message)
tracker = agent.tracker_store.get_or_create_tracker(sender_id)
expected_events = [
ActionExecuted(action_name="action_session_start"),
SessionStarted(),
ActionExecuted(action_name="action_listen"),
UserUttered(text="hello", intent={"name": "greet"}),
]
expected_events = with_model_ids(
[
ActionExecuted(action_name="action_session_start"),
SessionStarted(),
ActionExecuted(action_name="action_listen"),
UserUttered(text="hello", intent={"name": "greet"},),
],
model_id,
)
assert len(tracker.events) == len(expected_events)
for e1, e2 in zip(tracker.events, expected_events):
assert e1 == e2


async def test_agent_handle_message_only_core(trained_core_model: Text):
agent = await load_agent(model_path=trained_core_model)
model_id = agent.model_id
sender_id = uuid.uuid4().hex
message = UserMessage("/greet", sender_id=sender_id)
await agent.handle_message(message)
tracker = agent.tracker_store.get_or_create_tracker(sender_id)
expected_events = [
ActionExecuted(action_name="action_session_start"),
SessionStarted(),
ActionExecuted(action_name="action_listen"),
UserUttered(text="/greet", intent={"name": "greet"}),
DefinePrevUserUtteredFeaturization(False),
ActionExecuted(action_name="utter_greet"),
BotUttered(
"hey there None!",
{
"elements": None,
"quick_replies": None,
"buttons": None,
"attachment": None,
"image": None,
"custom": None,
},
{"utter_action": "utter_greet"},
),
ActionExecuted(action_name="action_listen"),
]
expected_events = with_model_ids(
[
ActionExecuted(action_name="action_session_start"),
SessionStarted(),
ActionExecuted(action_name="action_listen"),
UserUttered(text="/greet", intent={"name": "greet"},),
DefinePrevUserUtteredFeaturization(False),
ActionExecuted(action_name="utter_greet"),
BotUttered(
"hey there None!",
{
"elements": None,
"quick_replies": None,
"buttons": None,
"attachment": None,
"image": None,
"custom": None,
},
{"utter_action": "utter_greet"},
),
ActionExecuted(action_name="action_listen"),
],
model_id,
)
assert len(tracker.events) == len(expected_events)
for e1, e2 in zip(tracker.events, expected_events):
assert e1 == e2
Expand Down
Loading