Skip to content

Commit

Permalink
test and fix writing YAML stories
Browse files Browse the repository at this point in the history
  • Loading branch information
wochinge committed Dec 1, 2020
1 parent 321e937 commit a8d8e04
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 4 deletions.
16 changes: 13 additions & 3 deletions rasa/shared/core/training_data/story_writer/yaml_story_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

from ruamel import yaml
from ruamel.yaml.comments import CommentedMap
from ruamel.yaml.scalarstring import DoubleQuotedScalarString
from ruamel.yaml.scalarstring import DoubleQuotedScalarString, LiteralScalarString

import rasa.shared.utils.io
import rasa.shared.core.constants
from rasa.shared.constants import LATEST_TRAINING_DATA_FORMAT_VERSION
import rasa.shared.core.events
from rasa.shared.core.events import (
UserUttered,
ActionExecuted,
Expand Down Expand Up @@ -216,12 +217,21 @@ def process_user_utterance(
)

if user_utterance.text and (
# We only print the utterance text if it was an end-to-end prediction
user_utterance.use_text_for_featurization
or user_utterance.use_text_for_featurization is None
# or if we want to print a conversation test story.
or is_test_story
):
result[KEY_USER_MESSAGE] = user_utterance.as_story_string()
result[KEY_USER_MESSAGE] = LiteralScalarString(
rasa.shared.core.events.md_format_message(
user_utterance.text,
user_utterance.intent_name,
user_utterance.entities,
)
)

if len(user_utterance.entities):
if len(user_utterance.entities) and not is_test_story:
entities = []
for entity in user_utterance.entities:
if entity["value"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

from rasa.shared.core.constants import ACTION_SESSION_START_NAME, ACTION_LISTEN_NAME
from rasa.shared.core.domain import Domain
from rasa.shared.core.events import ActionExecuted, UserUttered
from rasa.shared.core.events import (
ActionExecuted,
UserUttered,
DefinePrevUserUtteredFeaturization,
)
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.core.training_data.story_reader.markdown_story_reader import (
MarkdownStoryReader,
Expand Down Expand Up @@ -105,6 +109,8 @@ def test_yaml_writer_dumps_user_messages():
- story: default
steps:
- intent: greet
user: |-
Hello
- action: utter_greet
"""
Expand Down Expand Up @@ -177,3 +183,116 @@ def test_yaml_writer_stories_to_yaml(default_domain: Domain):
assert isinstance(result, OrderedDict)
assert "stories" in result
assert len(result["stories"]) == 1


def test_writing_end_to_end_stories(default_domain: Domain):
story_name = "test_writing_end_to_end_stories"
events = [
# Training story story with intent and action labels
ActionExecuted(ACTION_LISTEN_NAME),
UserUttered(intent={"name": "greet"}),
ActionExecuted("utter_greet"),
ActionExecuted(ACTION_LISTEN_NAME),
# Prediction story story with intent and action labels
ActionExecuted(ACTION_LISTEN_NAME),
UserUttered(text="Hi", intent={"name": "greet"}),
ActionExecuted("utter_greet"),
ActionExecuted(ACTION_LISTEN_NAME),
# End-To-End Training Story
UserUttered(text="Hi"),
DefinePrevUserUtteredFeaturization(use_text_for_featurization=True),
ActionExecuted(action_text="Hi, I'm a bot."),
ActionExecuted(ACTION_LISTEN_NAME),
# End-To-End Prediction Story
UserUttered("Hi", intent={"name": "greet"}),
DefinePrevUserUtteredFeaturization(use_text_for_featurization=True),
ActionExecuted(action_text="Hi, I'm a bot."),
ActionExecuted(ACTION_LISTEN_NAME),
]

tracker = DialogueStateTracker.from_events(story_name, events)
dump = YAMLStoryWriter().dumps(tracker.as_story().story_steps)

assert (
dump.strip()
== textwrap.dedent(
f"""
version: "2.0"
stories:
- story: {story_name}
steps:
- intent: greet
- action: utter_greet
- intent: greet
- action: utter_greet
- user: |-
Hi
- bot: Hi, I'm a bot.
- user: |-
Hi
- bot: Hi, I'm a bot.
"""
).strip()
)


def test_writing_end_to_end_stories_in_test_mode(default_domain: Domain):
story_name = "test_writing_end_to_end_stories_in_test_mode"
events = [
# Conversation tests (end-to-end _testing_)
ActionExecuted(ACTION_LISTEN_NAME),
UserUttered(text="Hi", intent={"name": "greet"}),
ActionExecuted("utter_greet"),
ActionExecuted(ACTION_LISTEN_NAME),
# Conversation tests (end-to-end _testing_) and entities
UserUttered(
text="Hi",
intent={"name": "greet"},
entities=[{"value": "Hi", "entity": "test", "start": 0, "end": 2}],
),
ActionExecuted("utter_greet"),
ActionExecuted(ACTION_LISTEN_NAME),
# Conversation test with actual end-to-end utterances
UserUttered(text="Hi", intent={"name": "greet"}),
DefinePrevUserUtteredFeaturization(use_text_for_featurization=True),
ActionExecuted(action_text="Hi, I'm a bot."),
ActionExecuted(ACTION_LISTEN_NAME),
# Conversation test with actual end-to-end utterances
UserUttered(
text="Hi",
intent={"name": "greet"},
entities=[{"value": "Hi", "entity": "test", "start": 0, "end": 2}],
),
DefinePrevUserUtteredFeaturization(use_text_for_featurization=True),
ActionExecuted(action_text="Hi, I'm a bot."),
ActionExecuted(ACTION_LISTEN_NAME),
]

tracker = DialogueStateTracker.from_events(story_name, events)
dump = YAMLStoryWriter().dumps(tracker.as_story().story_steps, is_test_story=True)

assert (
dump.strip()
== textwrap.dedent(
f"""
version: "2.0"
stories:
- story: {story_name}
steps:
- intent: greet
user: |-
Hi
- action: utter_greet
- intent: greet
user: |-
[Hi](test)
- action: utter_greet
- user: |-
Hi
- bot: Hi, I'm a bot.
- user: |-
[Hi](test)
- bot: Hi, I'm a bot.
"""
).strip()
)

0 comments on commit a8d8e04

Please sign in to comment.