Skip to content

Commit

Permalink
Changes
Browse files Browse the repository at this point in the history
- Reverting all tests that were async from tests/test_train.py back to how it was & all calls of train_nlu without await statement.
- Move vars(args) as the last argument and changed the type so it is the same as the type of the function parameter where it is used.
- Replaced default values of run_nlu_test_async to None.
- Changing train_nlu_async to public & removed coroutine check (main)
- Passing directly the specific arguments to run_nlu_test_async
- Implementing the use of TrainingDataImporter in rasa test (run_evaluation)
  • Loading branch information
Imod7 committed Jan 7, 2021
1 parent 957de84 commit 046fd94
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 45 deletions.
90 changes: 68 additions & 22 deletions rasa/cli/test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import logging
import os
from typing import List
from typing import List, Optional, Text, Dict, Union, Any

from rasa.cli import SubParsersAction
import rasa.shared.data
Expand All @@ -20,6 +20,7 @@
from rasa.core.test import FAILED_STORIES_FILE
import rasa.shared.utils.validation as validation_utils
import rasa.cli.utils
import rasa.utils.common

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -115,29 +116,54 @@ def run_core_test(args: argparse.Namespace) -> None:
)


def run_nlu_test(args: argparse.Namespace) -> None:
"""Run NLU tests."""
async def run_nlu_test_async(
config: Optional[Union[Text, List[Text]]],
data_path: Text,
models_path: Text,
output_dir: Text,
cross_validation: bool,
percentages: List[int],
runs: int,
no_errors: bool,
all_args: Dict[Text, Any],
) -> None:
"""Runs NLU tests.
Args:
all_args: all arguments gathered in a Dict so we can pass it as one argument
to other functions.
config: it refers to the model configuration file. It can be a single file or
a list of multiple files or a folder with multiple config files inside.
data_path: path for the nlu data.
models_path: path to a trained Rasa model.
output_dir: output path for any files created during the evaluation.
cross_validation: indicates if it should test the model using cross validation
or not.
percentages: defines the exclusion percentage of the training data.
runs: number of comparison runs to make.
no_errors: indicates if incorrect predictions should be written to a file
or not.
"""
from rasa.test import compare_nlu_models, perform_nlu_cross_validation, test_nlu

nlu_data = rasa.cli.utils.get_validated_path(args.nlu, "nlu", DEFAULT_DATA_PATH)
nlu_data = rasa.cli.utils.get_validated_path(data_path, "nlu", DEFAULT_DATA_PATH)
nlu_data = rasa.shared.data.get_nlu_directory(nlu_data)
output = args.out or DEFAULT_RESULTS_PATH
args.errors = not args.no_errors

output = output_dir or DEFAULT_RESULTS_PATH
all_args["errors"] = not no_errors
rasa.shared.utils.io.create_directory(output)

if args.config is not None and len(args.config) == 1:
args.config = os.path.abspath(args.config[0])
if os.path.isdir(args.config):
args.config = rasa.shared.utils.io.list_files(args.config)
if config is not None and len(config) == 1:
config = os.path.abspath(config[0])
if os.path.isdir(config):
config = rasa.shared.utils.io.list_files(config)

if isinstance(args.config, list):
if isinstance(config, list):
logger.info(
"Multiple configuration files specified, running nlu comparison mode."
)

config_files = []
for file in args.config:
for file in config:
try:
validation_utils.validate_yaml_schema(
rasa.shared.utils.io.read_file(file), CONFIG_SCHEMA_FILE,
Expand All @@ -148,26 +174,46 @@ def run_nlu_test(args: argparse.Namespace) -> None:
f"Ignoring file '{file}' as it is not a valid config file."
)
continue

compare_nlu_models(
await compare_nlu_models(
configs=config_files,
nlu=nlu_data,
output=output,
runs=args.runs,
exclusion_percentages=args.percentages,
runs=runs,
exclusion_percentages=percentages,
)
elif args.cross_validation:
elif cross_validation:
logger.info("Test model using cross validation.")
config = rasa.cli.utils.get_validated_path(
args.config, "config", DEFAULT_CONFIG_PATH
config, "config", DEFAULT_CONFIG_PATH
)
perform_nlu_cross_validation(config, nlu_data, output, vars(args))
perform_nlu_cross_validation(config, nlu_data, output, all_args)
else:
model_path = rasa.cli.utils.get_validated_path(
args.model, "model", DEFAULT_MODELS_PATH
models_path, "model", DEFAULT_MODELS_PATH
)

test_nlu(model_path, nlu_data, output, vars(args))
await test_nlu(model_path, nlu_data, output, all_args)


def run_nlu_test(args: argparse.Namespace) -> None:
"""Runs NLU tests.
Args:
args: the parsed CLI arguments for 'rasa test nlu'.
"""
rasa.utils.common.run_in_loop(
run_nlu_test_async(
args.config,
args.nlu,
args.model,
args.out,
args.cross_validation,
args.percentages,
args.runs,
args.no_errors,
vars(args),
)
)


def test(args: argparse.Namespace):
Expand Down
16 changes: 9 additions & 7 deletions rasa/nlu/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from rasa.nlu.components import Component
from rasa.nlu.tokenizers.tokenizer import Token
from rasa.utils.tensorflow.constants import ENTITY_RECOGNITION
from rasa.shared.importers.importer import TrainingDataImporter

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1415,7 +1416,7 @@ def remove_pretrained_extractors(pipeline: List[Component]) -> List[Component]:
return pipeline


def run_evaluation(
async def run_evaluation(
data_path: Text,
model_path: Text,
output_directory: Optional[Text] = None,
Expand Down Expand Up @@ -1448,9 +1449,10 @@ def run_evaluation(
interpreter = Interpreter.load(model_path, component_builder)

interpreter.pipeline = remove_pretrained_extractors(interpreter.pipeline)
test_data = rasa.shared.nlu.training_data.loading.load_data(
data_path, interpreter.model_metadata.language
test_data_importer = TrainingDataImporter.load_from_dict(
training_data_paths=[data_path]
)
test_data = await test_data_importer.get_nlu_data()

result: Dict[Text, Optional[Dict]] = {
"intent_evaluation": None,
Expand Down Expand Up @@ -1822,7 +1824,7 @@ def compute_metrics(
)


def compare_nlu(
async def compare_nlu(
configs: List[Text],
data: TrainingData,
exclusion_percentages: List[int],
Expand Down Expand Up @@ -1850,7 +1852,7 @@ def compare_nlu(
Returns: training examples per run
"""

from rasa.train import train_nlu
from rasa.train import train_nlu_async

training_examples_per_run = []

Expand Down Expand Up @@ -1895,7 +1897,7 @@ def compare_nlu(
)

try:
model_path = train_nlu(
model_path = await train_nlu_async(
nlu_config,
train_split_path,
model_output_path,
Expand All @@ -1911,7 +1913,7 @@ def compare_nlu(
model_path = os.path.join(get_model(model_path), "nlu")

output_path = os.path.join(model_output_path, f"{model_name}_report")
result = run_evaluation(
result = await run_evaluation(
test_path, model_path, output_directory=output_path, errors=True
)

Expand Down
2 changes: 1 addition & 1 deletion rasa/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,7 +1151,7 @@ async def _evaluate_model_using_test_set(
model_directory = eval_agent.model_directory
_, nlu_model = model.get_model_subdirectories(model_directory)

return run_evaluation(
return await run_evaluation(
data_path, nlu_model, disable_plotting=True, report_as_dict=True
)

Expand Down
4 changes: 2 additions & 2 deletions rasa/shared/importers/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def load_nlu_importer_from_config(

@staticmethod
def load_from_dict(
config: Optional[Dict],
config_path: Text,
config: Optional[Dict] = None,
config_path: Optional[Text] = None,
domain_path: Optional[Text] = None,
training_data_paths: Optional[List[Text]] = None,
training_type: Optional[TrainingType] = TrainingType.BOTH,
Expand Down
16 changes: 11 additions & 5 deletions rasa/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def test(
additional_arguments = {}

test_core(model, stories, output, additional_arguments)
test_nlu(model, nlu_data, output, additional_arguments)
rasa.utils.common.run_in_loop(
test_nlu(model, nlu_data, output, additional_arguments)
)


def test_core(
Expand All @@ -110,6 +112,7 @@ def test_core(
output: Text = DEFAULT_RESULTS_PATH,
additional_arguments: Optional[Dict] = None,
) -> None:
"""Tests a trained Core model against a set of test stories."""
import rasa.model
from rasa.shared.nlu.interpreter import RegexInterpreter
from rasa.core.agent import Agent
Expand Down Expand Up @@ -154,12 +157,13 @@ def test_core(
rasa.utils.common.run_in_loop(test(stories, _agent, out_directory=output, **kwargs))


def test_nlu(
async def test_nlu(
model: Optional[Text],
nlu_data: Optional[Text],
output_directory: Text = DEFAULT_RESULTS_PATH,
additional_arguments: Optional[Dict] = None,
):
"""Tests the NLU Model."""
from rasa.nlu.test import run_evaluation
from rasa.model import get_model

Expand All @@ -180,15 +184,17 @@ def test_nlu(
kwargs = rasa.shared.utils.common.minimal_kwargs(
additional_arguments, run_evaluation, ["data_path", "model"]
)
run_evaluation(nlu_data, nlu_model, output_directory=output_directory, **kwargs)
await run_evaluation(
nlu_data, nlu_model, output_directory=output_directory, **kwargs
)
else:
rasa.shared.utils.cli.print_error(
"Could not find any model. Use 'rasa train nlu' to train a "
"Rasa model and provide it via the '--model' argument."
)


def compare_nlu_models(
async def compare_nlu_models(
configs: List[Text],
nlu: Text,
output: Text,
Expand All @@ -214,7 +220,7 @@ def compare_nlu_models(
model_name: [[] for _ in range(runs)] for model_name in model_names
}

training_examples_per_run = compare_nlu(
training_examples_per_run = await compare_nlu(
configs,
data,
exclusion_percentages,
Expand Down
5 changes: 3 additions & 2 deletions rasa/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def train_nlu(
"""
return rasa.utils.common.run_in_loop(
_train_nlu_async(
train_nlu_async(
config,
nlu_data,
output,
Expand All @@ -713,7 +713,7 @@ def train_nlu(
)


async def _train_nlu_async(
async def train_nlu_async(
config: Text,
nlu_data: Text,
output: Text,
Expand All @@ -725,6 +725,7 @@ async def _train_nlu_async(
model_to_finetune: Optional[Text] = None,
finetuning_epoch_fraction: float = 1.0,
) -> Optional[Text]:
"""Trains an NLU model asynchronously."""
if not nlu_data:
rasa.shared.utils.cli.print_error(
"No NLU data given. Please provide NLU data in order to train "
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import rasa.core.run
from rasa.core.tracker_store import InMemoryTrackerStore, TrackerStore
from rasa.model import get_model
from rasa.train import train_async, _train_nlu_async
from rasa.train import train_async, train_nlu_async
from rasa.utils.common import TempDirectoryPath
from tests.core.conftest import (
DEFAULT_DOMAIN_PATH_WITH_SLOTS,
Expand Down Expand Up @@ -223,7 +223,7 @@ async def _train_nlu(
if output_path is None:
output_path = str(tmpdir_factory.mktemp("models"))

return await _train_nlu_async(*args, output=output_path, **kwargs)
return await train_nlu_async(*args, output=output_path, **kwargs)

return _train_nlu

Expand Down
8 changes: 4 additions & 4 deletions tests/nlu/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,8 @@ def test_drop_intents_below_freq():
assert clean_td.intents == {"affirm", "restaurant_search"}


def test_run_evaluation(unpacked_trained_moodbot_path: Text):
result = run_evaluation(
async def test_run_evaluation(unpacked_trained_moodbot_path: Text):
result = await run_evaluation(
DEFAULT_DATA_PATH,
os.path.join(unpacked_trained_moodbot_path, "nlu"),
errors=False,
Expand Down Expand Up @@ -919,7 +919,7 @@ def test_label_replacement():
assert substitute_labels(original_labels, "O", "no_entity") == target_labels


def test_nlu_comparison(tmp_path: Path):
async def test_nlu_comparison(tmp_path: Path):
config = {
"language": "en",
"pipeline": [
Expand All @@ -933,7 +933,7 @@ def test_nlu_comparison(tmp_path: Path):
configs = [write_file_config(config).name, write_file_config(config).name]

output = str(tmp_path)
compare_nlu_models(
await compare_nlu_models(
configs, DEFAULT_DATA_PATH, output, runs=2, exclusion_percentages=[50, 80]
)

Expand Down

0 comments on commit 046fd94

Please sign in to comment.