Skip to content

Commit

Permalink
use constant for training data version everywhere (#10909)
Browse files Browse the repository at this point in the history
* rebase

* fix import of constant name

* fix quoting

* Always use DoubleQuotedScalarString for training data version

* Add another case of DoubleQuotedScalarString

* missing fstrings
  • Loading branch information
indam23 authored Feb 22, 2022
1 parent dfc47bc commit 19439e1
Show file tree
Hide file tree
Showing 27 changed files with 296 additions and 254 deletions.
15 changes: 13 additions & 2 deletions rasa/core/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from pathlib import Path
from typing import List, Dict, Text, Any, Tuple, Optional, Union

from ruamel.yaml.scalarstring import DoubleQuotedScalarString

import rasa.shared.utils.io
import rasa.shared.utils.cli
from rasa.shared.constants import REQUIRED_SLOTS_KEY, IGNORED_INTENTS
Expand All @@ -13,6 +15,7 @@
MAPPING_TYPE,
SLOT_MAPPINGS,
)
from rasa.shared.constants import LATEST_TRAINING_DATA_FORMAT_VERSION
from rasa.shared.core.domain import KEY_ENTITIES, KEY_SLOTS, KEY_FORMS, Domain
from rasa.shared.exceptions import RasaException

Expand Down Expand Up @@ -172,7 +175,9 @@ def _assemble_new_domain(
elif key == KEY_FORMS:
new_domain.update({key: new_forms})
elif key == "version":
new_domain.update({key: '"3.0"'})
new_domain.update(
{key: DoubleQuotedScalarString(LATEST_TRAINING_DATA_FORMAT_VERSION)}
)
else:
new_domain.update({key: value})
return new_domain
Expand Down Expand Up @@ -226,7 +231,13 @@ def _migrate_domain_files(

if KEY_SLOTS not in original_content and KEY_FORMS not in original_content:
if isinstance(original_content, dict):
original_content.update({"version": '"3.0"'})
original_content.update(
{
"version": DoubleQuotedScalarString(
LATEST_TRAINING_DATA_FORMAT_VERSION
)
}
)

# this is done so that the other domain files can be moved
# in the migrated directory
Expand Down
6 changes: 5 additions & 1 deletion rasa/shared/core/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
Iterable,
)

from ruamel.yaml.scalarstring import DoubleQuotedScalarString

from rasa.shared.constants import (
DEFAULT_SESSION_EXPIRATION_TIME_IN_MINUTES,
DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION,
Expand Down Expand Up @@ -1608,7 +1610,9 @@ def as_yaml(self, clean_before_dump: bool = False) -> Text:
# thanks to the `should_preserve_key_order` argument
# of `dump_obj_as_yaml_to_string`
domain_data: Dict[Text, Any] = {
KEY_TRAINING_DATA_FORMAT_VERSION: LATEST_TRAINING_DATA_FORMAT_VERSION
KEY_TRAINING_DATA_FORMAT_VERSION: DoubleQuotedScalarString(
LATEST_TRAINING_DATA_FORMAT_VERSION
)
}
if clean_before_dump:
domain_data.update(self.cleaned_domain())
Expand Down
11 changes: 6 additions & 5 deletions tests/cli/test_rasa_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from _pytest.monkeypatch import MonkeyPatch
from _pytest.pytester import RunResult
from rasa.cli import data
from rasa.shared.constants import LATEST_TRAINING_DATA_FORMAT_VERSION
from rasa.shared.importers.importer import TrainingDataImporter
from rasa.validator import Validator
import rasa.shared.utils.io
Expand Down Expand Up @@ -156,7 +157,7 @@ def test_validate_files_action_not_found_invalid_domain(
file_name = tmp_path / f"{file_type}.yml"
file_name.write_text(
f"""
version: "3.0"
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
{file_type}:
- {data_type}: test path
steps:
Expand All @@ -183,7 +184,7 @@ def test_validate_files_form_not_found_invalid_domain(
file_name = tmp_path / f"{file_type}.yml"
file_name.write_text(
f"""
version: "3.0"
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
{file_type}:
- {data_type}: test path
steps:
Expand Down Expand Up @@ -229,8 +230,8 @@ def test_validate_files_with_active_loop_null(tmp_path: Path):
def test_validate_files_form_slots_not_matching(tmp_path: Path):
domain_file_name = tmp_path / "domain.yml"
domain_file_name.write_text(
"""
version: "3.0"
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
forms:
name_form:
required_slots:
Expand Down Expand Up @@ -290,7 +291,7 @@ def test_validate_files_invalid_slot_mappings(tmp_path: Path):
slot_name = "started_booking_form"
domain.write_text(
f"""
version: "3.0"
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
intents:
- activate_booking
entities:
Expand Down
2 changes: 1 addition & 1 deletion tests/cli/test_rasa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_test_core_warnings(run_in_simple_project_with_model: Callable[..., RunR
)

simple_test_story_yaml = """
version: "3.0"
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
stories:
- story: unlikely path
steps:
Expand Down
28 changes: 16 additions & 12 deletions tests/core/actions/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from rasa.core.policies.policy import PolicyPrediction
from rasa.core.actions import action
from rasa.core.actions.action import ActionExecutionRejection, ActionExtractSlots
from rasa.shared.constants import REQUIRED_SLOTS_KEY, IGNORED_INTENTS
from rasa.shared.constants import (
LATEST_TRAINING_DATA_FORMAT_VERSION,
REQUIRED_SLOTS_KEY,
IGNORED_INTENTS,
)
from rasa.shared.core.constants import ACTION_LISTEN_NAME, REQUESTED_SLOT
from rasa.core.actions.forms import FormAction
from rasa.core.channels import CollectingOutputChannel
Expand Down Expand Up @@ -119,7 +123,7 @@ async def test_switch_forms_with_same_slot(default_agent: Agent):
utter_ask_form_2 = f"Please provide the value for {slot_a} of form 2"

domain = f"""
version: "3.0"
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
nlu:
- intent: order_status
examples: |
Expand Down Expand Up @@ -448,7 +452,7 @@ async def test_validate_slots(
tracker = DialogueStateTracker.from_events(sender_id="bla", evts=events)

domain = f"""
version: "3.0"
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
entities:
- num_tables
Expand Down Expand Up @@ -722,7 +726,7 @@ def test_temporary_tracker():
sender_id = "test"
domain = Domain.from_yaml(
f"""
version: "3.0"
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
slots:
{extra_slot}:
type: any
Expand Down Expand Up @@ -1407,8 +1411,8 @@ async def test_extract_other_slots_with_matched_mapping_conditions():

domain = Domain.from_yaml(
textwrap.dedent(
"""
version: "3.0"
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
intent:
- greet
- inform
Expand Down Expand Up @@ -1479,8 +1483,8 @@ async def test_extract_other_slots_raises_no_matched_conditions():

domain = Domain.from_yaml(
textwrap.dedent(
"""
version: "3.0"
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
intent:
- greet
- inform
Expand Down Expand Up @@ -1549,8 +1553,8 @@ async def test_extract_other_slots_raises_no_matched_conditions():

async def test_action_extract_slots_custom_mapping_with_condition():
domain_yaml = textwrap.dedent(
"""
version: "3.0"
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
slots:
custom_slot:
Expand Down Expand Up @@ -1613,8 +1617,8 @@ async def test_action_extract_slots_custom_mapping_with_condition():
async def test_form_slots_empty_with_restart():
domain = Domain.from_yaml(
textwrap.dedent(
"""
version: "3.0"
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
intent:
- greet
- inform
Expand Down
7 changes: 5 additions & 2 deletions tests/core/actions/test_two_stage_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

from rasa.core.policies.policy import PolicyPrediction
from rasa.core.processor import MessageProcessor
from rasa.shared.constants import DEFAULT_NLU_FALLBACK_INTENT_NAME
from rasa.shared.constants import (
DEFAULT_NLU_FALLBACK_INTENT_NAME,
LATEST_TRAINING_DATA_FORMAT_VERSION,
)
from rasa.core.actions.two_stage_fallback import TwoStageFallbackAction
from rasa.core.channels import CollectingOutputChannel
from rasa.shared.core.domain import Domain
Expand Down Expand Up @@ -156,7 +159,7 @@ async def test_ask_rephrase_after_failed_affirmation():

domain = Domain.from_yaml(
f"""
version: "3.0"
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
responses:
utter_ask_rephrase:
- text: {rephrase_text}
Expand Down
31 changes: 16 additions & 15 deletions tests/core/nlg/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from _pytest.logging import LogCaptureFixture

from rasa.core.nlg.response import TemplatedNaturalLanguageGenerator
from rasa.shared.constants import LATEST_TRAINING_DATA_FORMAT_VERSION
from rasa.shared.core.domain import Domain
from rasa.shared.core.slots import TextSlot, AnySlot, CategoricalSlot, BooleanSlot
from rasa.shared.core.trackers import DialogueStateTracker
Expand Down Expand Up @@ -250,7 +251,7 @@ async def test_nlg_conditional_response_variations_with_diff_slot_types(
async def test_nlg_non_matching_channel():
domain = Domain.from_yaml(
"""
version: "3.0"
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
responses:
utter_hi:
- text: "Hello"
Expand All @@ -266,8 +267,8 @@ async def test_nlg_non_matching_channel():

async def test_nlg_conditional_response_variations_with_none_slot():
domain = Domain.from_yaml(
"""
version: "3.0"
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
responses:
utter_action:
- text: "text A"
Expand All @@ -288,8 +289,8 @@ async def test_nlg_conditional_response_variations_with_none_slot():

async def test_nlg_conditional_response_variations_with_slot_not_a_constraint():
domain = Domain.from_yaml(
"""
version: "3.0"
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
responses:
utter_action:
- text: "text A"
Expand All @@ -310,8 +311,8 @@ async def test_nlg_conditional_response_variations_with_slot_not_a_constraint():

async def test_nlg_conditional_response_variations_with_null_slot():
domain = Domain.from_yaml(
"""
version: "3.0"
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
responses:
utter_action:
- text: "text for null"
Expand All @@ -336,8 +337,8 @@ async def test_nlg_conditional_response_variations_with_null_slot():

async def test_nlg_conditional_response_variations_channel_no_condition_met():
domain = Domain.from_yaml(
"""
version: "3.0"
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
responses:
utter_action:
- text: "example with channel"
Expand All @@ -357,8 +358,8 @@ async def test_nlg_conditional_response_variations_channel_no_condition_met():

async def test_nlg_conditional_response_variation_condition_met_channel_mismatch():
domain = Domain.from_yaml(
"""
version: "3.0"
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
responses:
utter_action:
- text: "example with channel"
Expand Down Expand Up @@ -423,8 +424,8 @@ async def test_nlg_conditional_response_variation_condition_met_channel_mismatch
)
async def test_nlg_conditional_edgecases(slots, channel, expected_response):
domain = Domain.from_yaml(
"""
version: "3.0"
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
responses:
utter_action:
- text: "condition example A with channel"
Expand Down Expand Up @@ -466,8 +467,8 @@ async def test_nlg_conditional_response_variations_condition_logging(
caplog: LogCaptureFixture,
):
domain = Domain.from_yaml(
"""
version: "3.0"
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
responses:
utter_action:
- text: "example"
Expand Down
Loading

0 comments on commit 19439e1

Please sign in to comment.