Skip to content

Commit

Permalink
Changing train_nlu_async to public & removed coroutine check (main)
Browse files Browse the repository at this point in the history
  • Loading branch information
Imod7 committed Jan 5, 2021
1 parent 3e48236 commit 76677e2
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 31 deletions.
5 changes: 1 addition & 4 deletions rasa/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,7 @@ def main() -> None:
set_log_and_warnings_filters()
rasa.telemetry.initialize_telemetry()
rasa.telemetry.initialize_error_reporting()
if inspect.iscoroutinefunction(cmdline_arguments.func):
rasa.utils.common.run_in_loop(cmdline_arguments.func(cmdline_arguments))
else:
cmdline_arguments.func(cmdline_arguments)
cmdline_arguments.func(cmdline_arguments)
elif hasattr(cmdline_arguments, "version"):
print_version()
else:
Expand Down
46 changes: 26 additions & 20 deletions rasa/cli/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,29 +117,32 @@ def run_core_test(args: argparse.Namespace) -> None:


async def run_nlu_test_async(
config: Optional[Union[Text, Dict]] = None,
data_path: Optional[Text] = None,
models_path: Optional[Text] = None,
output_dir: Optional[Text] = None,
cross_validation: Optional[bool] = None,
percentages: Optional[List[int]] = None,
runs: Optional[int] = None,
no_errors: Optional[bool] = None,
all_args: Optional[Dict[Text, Any]] = None,
config: Optional[Union[Text, List[Text]]],
data_path: Text,
models_path: Text,
output_dir: Text,
cross_validation: bool,
percentages: List[int],
runs: int,
no_errors: bool,
all_args: Dict[Text, Any],
) -> None:
"""Runs NLU tests.
Args:
all_args: all arguments gathered in a Dict so we can pass them to 'perform_nlu_cross_validation' and
'test_nlu' as parameter.
config: config file or a list of multiple config files.
all_args: all arguments gathered in a Dict so we can pass it as one argument
to other functions.
config: it refers to the model configuration file. It can be a single file or
a list of multiple files or a folder with multiple config files inside.
data_path: path for the nlu data.
models_path: the path for the nlu data.
output_dir: the directory for the results to be saved.
cross_validation: boolean value that indicates if it should test the model using cross validation or not.
models_path: path to a trained Rasa model.
output_dir: output path for any files created during the evaluation.
cross_validation: indicates if it should test the model using cross validation
or not.
percentages: defines the exclusion percentage of the training data.
runs: used in 'compare_nlu_models' and indicates the number of comparison runs.
no_errors: boolean value that indicates if incorrect predictions should be written to a file or not.
runs: number of comparison runs to make.
no_errors: indicates if incorrect predictions should be written to a file
or not.
"""
from rasa.test import compare_nlu_models, perform_nlu_cross_validation, test_nlu

Expand Down Expand Up @@ -193,10 +196,13 @@ async def run_nlu_test_async(


def run_nlu_test(args: argparse.Namespace) -> None:
"""Adding this function layer to be able to run run_nlu_test_async in the event loop.
"""Runs NLU tests.
I have run_nlu_test_async to be able to have await calls inside. That way I can call functions
test_nlu and compare_nlu_models with await statements since they are async functions.
Args:
args: all arguments that were set or omitted in the command line and then
were parsed or populated with their default values respectively.
These arguments define the specific parameters/conditions under which
the NLU tests should run.
"""
rasa.utils.common.run_in_loop(
run_nlu_test_async(
Expand Down
4 changes: 2 additions & 2 deletions rasa/nlu/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1852,7 +1852,7 @@ async def compare_nlu(
Returns: training examples per run
"""

from rasa.train import train_nlu
from rasa.train import train_nlu_async

training_examples_per_run = []

Expand Down Expand Up @@ -1897,7 +1897,7 @@ async def compare_nlu(
)

try:
model_path = train_nlu(
model_path = await train_nlu_async(
nlu_config,
train_split_path,
model_output_path,
Expand Down
2 changes: 1 addition & 1 deletion rasa/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_core(
output: Text = DEFAULT_RESULTS_PATH,
additional_arguments: Optional[Dict] = None,
) -> None:
"""Checks if models are present and then runs the test function in the event loop."""
"""Tests a trained Core model against a set of test stories."""
import rasa.model
from rasa.shared.nlu.interpreter import RegexInterpreter
from rasa.core.agent import Agent
Expand Down
5 changes: 3 additions & 2 deletions rasa/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def train_nlu(
"""
return rasa.utils.common.run_in_loop(
_train_nlu_async(
train_nlu_async(
config,
nlu_data,
output,
Expand All @@ -713,7 +713,7 @@ def train_nlu(
)


async def _train_nlu_async(
async def train_nlu_async(
config: Text,
nlu_data: Text,
output: Text,
Expand All @@ -725,6 +725,7 @@ async def _train_nlu_async(
model_to_finetune: Optional[Text] = None,
finetuning_epoch_fraction: float = 1.0,
) -> Optional[Text]:
"""Trains an NLU model asynchronously."""
if not nlu_data:
rasa.shared.utils.cli.print_error(
"No NLU data given. Please provide NLU data in order to train "
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import rasa.core.run
from rasa.core.tracker_store import InMemoryTrackerStore, TrackerStore
from rasa.model import get_model
from rasa.train import train_async, _train_nlu_async
from rasa.train import train_async, train_nlu_async
from rasa.utils.common import TempDirectoryPath
from tests.core.conftest import (
DEFAULT_DOMAIN_PATH_WITH_SLOTS,
Expand Down Expand Up @@ -223,7 +223,7 @@ async def _train_nlu(
if output_path is None:
output_path = str(tmpdir_factory.mktemp("models"))

return await _train_nlu_async(*args, output=output_path, **kwargs)
return await train_nlu_async(*args, output=output_path, **kwargs)

return _train_nlu

Expand Down

0 comments on commit 76677e2

Please sign in to comment.