From 829d513ac464e0421a264fd64d8b94f59a09875e Mon Sep 17 00:00:00 2001 From: Adrien Ball Date: Fri, 9 Aug 2019 14:48:57 +0200 Subject: [PATCH] Allow to fit SnipsNLUEngine with Dataset object (#840) * Allow to fit SnipsNLUEngine with Dataset object * Fix linting annotations * Update Changelog --- CHANGELOG.md | 1 + snips_nlu/dataset/validation.py | 5 +- snips_nlu/tests/test_custom_entity_parser.py | 4 +- snips_nlu/tests/test_dataset_validation.py | 68 +++++++++++++++++++- snips_nlu/tests/utils.py | 2 +- 5 files changed, 73 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b323ae57..074bf33d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ All notable changes to this project will be documented in this file. - Allow to bypass the model version check [#830](https://github.com/snipsco/snips-nlu/pull/830) - Persist `CustomEntityParser` license when needed [#832](https://github.com/snipsco/snips-nlu/pull/832) - Document metrics CLI [#839](https://github.com/snipsco/snips-nlu/pull/839) +- Allow to fit SnipsNLUEngine with a `Dataset` object [#840](https://github.com/snipsco/snips-nlu/pull/840) ### Fixed - Invalidate importlib caches after dynamically installing module [#838](https://github.com/snipsco/snips-nlu/pull/838) diff --git a/snips_nlu/dataset/validation.py b/snips_nlu/dataset/validation.py index 2a956768e..c2ac20883 100644 --- a/snips_nlu/dataset/validation.py +++ b/snips_nlu/dataset/validation.py @@ -13,7 +13,7 @@ AUTOMATICALLY_EXTENSIBLE, CAPITALIZE, DATA, ENTITIES, ENTITY, INTENTS, LANGUAGE, MATCHING_STRICTNESS, SLOT_NAME, SYNONYMS, TEXT, USE_SYNONYMS, UTTERANCES, VALIDATED, VALUE, LICENSE_INFO) -from snips_nlu.dataset import extract_utterance_entities +from snips_nlu.dataset import extract_utterance_entities, Dataset from snips_nlu.entity_parser.builtin_entity_parser import ( BuiltinEntityParser, is_builtin_entity) from snips_nlu.exceptions import DatasetFormatError @@ -32,6 +32,9 @@ def validate_and_format_dataset(dataset): """ from snips_nlu_parsers import get_all_languages + if isinstance(dataset, Dataset): + dataset = dataset.json + # Make this function idempotent if dataset.get(VALIDATED, False): return dataset diff --git a/snips_nlu/tests/test_custom_entity_parser.py b/snips_nlu/tests/test_custom_entity_parser.py index 4399637d8..d61ac225c 100644 --- a/snips_nlu/tests/test_custom_entity_parser.py +++ b/snips_nlu/tests/test_custom_entity_parser.py @@ -348,14 +348,12 @@ def test_create_custom_entity_parser_configuration(self): self.assertDictEqual(expected_dict, config) -# pylint: disable=unused-argument def _persist_parser(path): path = Path(path) with path.open("w", encoding="utf-8") as f: f.write("nothing interesting here") -# pylint: disable=unused-argument def _load_parser(path): path = Path(path) with path.open("r", encoding="utf-8") as f: @@ -365,3 +363,5 @@ def _load_parser(path): # pylint: disable=unused-argument def _stem(string, language): return string[:-1] + +# pylint: enable=unused-argument diff --git a/snips_nlu/tests/test_dataset_validation.py b/snips_nlu/tests/test_dataset_validation.py index a97bfc7ba..1d8a61a59 100644 --- a/snips_nlu/tests/test_dataset_validation.py +++ b/snips_nlu/tests/test_dataset_validation.py @@ -1,14 +1,16 @@ # coding=utf-8 from __future__ import unicode_literals +import io from builtins import range, str from future.utils import iteritems from mock import mock, patch -from snips_nlu.constants import ENTITIES, SNIPS_DATETIME -from snips_nlu.dataset import validate_and_format_dataset -from 
snips_nlu.dataset.validation import _validate_and_format_custom_entity +from snips_nlu.constants import ENTITIES, SNIPS_DATETIME, VALIDATED +from snips_nlu.dataset import Dataset +from snips_nlu.dataset.validation import ( + validate_and_format_dataset, _validate_and_format_custom_entity) from snips_nlu.exceptions import DatasetFormatError from snips_nlu.tests.utils import SnipsTest, EntityParserMock @@ -1174,3 +1176,63 @@ def test_should_keep_license_info(self): "validated": True } self.assertDictEqual(expected_dataset, validated_dataset) + + def test_validate_should_be_idempotent(self): + # Given + dataset_stream = io.StringIO(""" +# getWeather Intent +--- +type: intent +name: getWeather +utterances: + - what is the weather in [weatherLocation:location](Paris)? + - is it raining in [weatherLocation] [weatherDate:snips/datetime] + +# Location Entity +--- +type: entity +name: location +automatically_extensible: true +values: +- [new york, big apple] +- london + """) + + dataset = Dataset.from_yaml_files("en", [dataset_stream]) + validated_dataset = validate_and_format_dataset(dataset) + + # When + validated_dataset_2 = validate_and_format_dataset(validated_dataset) + + # Then + self.assertDictEqual(validated_dataset, validated_dataset_2) + self.assertTrue(validated_dataset.get(VALIDATED, False)) + + def test_validate_should_accept_dataset_object(self): + # Given + dataset_stream = io.StringIO(""" +# getWeather Intent +--- +type: intent +name: getWeather +utterances: + - what is the weather in [weatherLocation:location](Paris)? + - is it raining in [weatherLocation] [weatherDate:snips/datetime] + +# Location Entity +--- +type: entity +name: location +automatically_extensible: true +values: +- [new york, big apple] +- london + """) + + dataset = Dataset.from_yaml_files("en", [dataset_stream]) + + # When + validated_dataset = validate_and_format_dataset(dataset) + + # Then + self.assertTrue(validated_dataset.get(VALIDATED, False)) diff --git a/snips_nlu/tests/utils.py b/snips_nlu/tests/utils.py index e1792f302..9fc57abab 100644 --- a/snips_nlu/tests/utils.py +++ b/snips_nlu/tests/utils.py @@ -158,7 +158,7 @@ def persist(self, path): f.write(json_string(unit_dict)) @classmethod - def from_path(cls, path, **shared): # pylint:disable=unused-argument + def from_path(cls, path, **_): with (path / "metadata.json").open(encoding="utf8") as f: metadata = json.load(f) fitted = metadata["fitted"]
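
For context, a hedged usage sketch (not part of the diff) of what this change enables: `validate_and_format_dataset` now converts a `Dataset` object to its JSON dict via `dataset.json` and short-circuits on the `validated` flag, so `SnipsNLUEngine.fit` can be given a `Dataset` directly. The YAML fixture below is copied from the new tests; the default engine config and pre-installed English resources are assumptions of this sketch, not something this patch provides.

# Usage sketch for this change (illustrative only, not part of the patch).
# Assumes snips-nlu with this change applied and the English resources
# installed beforehand, e.g. with `python -m snips_nlu download en`.
import io

from snips_nlu import SnipsNLUEngine
from snips_nlu.dataset import Dataset
from snips_nlu.dataset.validation import validate_and_format_dataset

# Same YAML fixture as the new tests in test_dataset_validation.py.
dataset_stream = io.StringIO("""
# getWeather Intent
---
type: intent
name: getWeather
utterances:
  - what is the weather in [weatherLocation:location](Paris)?
  - is it raining in [weatherLocation] [weatherDate:snips/datetime]

# Location Entity
---
type: entity
name: location
automatically_extensible: true
values:
- [new york, big apple]
- london
""")

dataset = Dataset.from_yaml_files("en", [dataset_stream])

# The validation helper now accepts the Dataset object itself (converting it
# to dataset.json internally) and stays idempotent: a second call on the
# already validated dict returns it unchanged.
validated = validate_and_format_dataset(dataset)
assert validated["validated"]
assert validate_and_format_dataset(validated) == validated

# Consequently the engine can be fitted with the Dataset object directly,
# where engine.fit(dataset.json) was previously required.
engine = SnipsNLUEngine().fit(dataset)
print(engine.parse("is it raining in london?"))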