diff --git a/snips_nlu/dataset/validation.py b/snips_nlu/dataset/validation.py
index 179763ce0..4cf876796 100644
--- a/snips_nlu/dataset/validation.py
+++ b/snips_nlu/dataset/validation.py
@@ -18,8 +18,10 @@
 from snips_nlu.exceptions import DatasetFormatError
 from snips_nlu.preprocessing import tokenize_light
 from snips_nlu.string_variations import get_string_variations
-from snips_nlu.common.dataset_utils import validate_type, validate_key, \
-    validate_keys
+from snips_nlu.common.dataset_utils import (
+    validate_type, validate_key, validate_keys)
+
+NUMBER_VARIATIONS_THRESHOLD = 1e3
 
 
 def validate_and_format_dataset(dataset):
@@ -168,6 +170,14 @@ def _validate_and_format_custom_entity(entity, queries_entities, language,
         if s and s not in validated_utterances:
             validated_utterances[s] = ent_value
 
+    # Number variations of entity values are expensive to compute because
+    # each entity value is parsed with the builtin entity parser before the
+    # variations are created. We skip these variations when there are too
+    # many entity values
+    number_variations_limit = 0
+    if len(entity[DATA]) < NUMBER_VARIATIONS_THRESHOLD:
+        number_variations_limit = None
+
     # Add variations if not colliding
     all_original_values = _extract_entity_values(entity)
     variations = dict()
@@ -178,10 +188,13 @@ def _validate_and_format_custom_entity(entity, queries_entities, language,
             values_to_variate.update(set(data[SYNONYMS]))
         variations[ent_value] = set(
             v for value in values_to_variate
-            for v in get_string_variations(value, language,
-                                           builtin_entity_parser))
+            for v in get_string_variations(
+                value, language, builtin_entity_parser,
+                number_variations_limit=number_variations_limit
+            )
+        )
     variation_counter = Counter(
-        [v for vars in itervalues(variations) for v in vars])
+        [v for variations in itervalues(variations) for v in variations])
     non_colliding_variations = {
         value: [
             v for v in variations if
@@ -197,7 +210,10 @@ def _validate_and_format_custom_entity(entity, queries_entities, language,
 
     # Merge queries entities
     queries_entities_variations = {
-        ent: get_string_variations(ent, language, builtin_entity_parser)
+        ent: get_string_variations(
+            ent, language, builtin_entity_parser,
+            number_variations_limit=number_variations_limit
+        )
         for ent in queries_entities
     }
     for original_ent, variations in iteritems(queries_entities_variations):
diff --git a/snips_nlu/string_variations.py b/snips_nlu/string_variations.py
index 27ac14065..e0c30ff93 100644
--- a/snips_nlu/string_variations.py
+++ b/snips_nlu/string_variations.py
@@ -155,7 +155,8 @@ def flatten(results):
     return set(i for r in results for i in r)
 
 
-def get_string_variations(string, language, builtin_entity_parser):
+def get_string_variations(string, language, builtin_entity_parser,
+                          number_variations_limit=None):
     variations = {string}
     variations.update(flatten(case_variations(v) for v in variations))
     variations.update(flatten(normalization_variations(v) for v in variations))
@@ -165,9 +166,14 @@ def get_string_variations(string, language, builtin_entity_parser):
     variations.update(flatten(and_variations(v, language) for v in variations))
     variations.update(
         flatten(punctuation_variations(v, language) for v in variations))
-    variations.update(
-        flatten(numbers_variations(v, language, builtin_entity_parser)
-                for v in variations))
+
+    # Special case of number variations, which are slow to generate because
+    # the BuiltinEntityParser runs on each variation
+    if (number_variations_limit is None or
+            len(variations) < number_variations_limit):
+        variations.update(
+            flatten(numbers_variations(v, language, builtin_entity_parser)
+                    for v in variations))
     # Add single space variations
     single_space_variations = set(" ".join(v.split()) for v in variations)
     variations.update(single_space_variations)
diff --git a/snips_nlu/tests/test_dataset_validation.py b/snips_nlu/tests/test_dataset_validation.py
index ee4981b39..a88c12488 100644
--- a/snips_nlu/tests/test_dataset_validation.py
+++ b/snips_nlu/tests/test_dataset_validation.py
@@ -1,14 +1,15 @@
 # coding=utf-8
 from __future__ import unicode_literals
 
-from builtins import str
+from builtins import range, str
 
-from mock import mock
+from mock import mock, patch
 
 from snips_nlu.constants import ENTITIES, SNIPS_DATETIME
 from snips_nlu.dataset import validate_and_format_dataset
+from snips_nlu.dataset.validation import _validate_and_format_custom_entity
 from snips_nlu.exceptions import DatasetFormatError
-from snips_nlu.tests.utils import SnipsTest
+from snips_nlu.tests.utils import SnipsTest, EntityParserMock
 
 
 class TestDatasetValidation(SnipsTest):
@@ -189,8 +190,10 @@ def test_should_format_dataset_by_adding_synonyms(
             self, mocked_get_string_variations):
         # Given
         # pylint: disable=unused-argument
-        def mock_get_string_variations(variation, language,
-                                       builtin_entity_parser):
+        def mock_get_string_variations(
+                variation, language, builtin_entity_parser,
+                number_variations_limit
+        ):
             return {variation.lower(), variation.title()}
 
         mocked_get_string_variations.side_effect = mock_get_string_variations
@@ -242,8 +245,10 @@ def test_should_format_dataset_by_adding_entity_values(
             self, mocked_get_string_variations):
         # Given
         # pylint: disable=unused-argument
-        def mock_get_string_variations(variation, language,
-                                       builtin_entity_parser):
+        def mock_get_string_variations(
+                variation, language, builtin_entity_parser,
+                number_variations_limit
+        ):
             return {variation, variation.title()}
 
         mocked_get_string_variations.side_effect = mock_get_string_variations
@@ -355,8 +360,10 @@ def test_should_add_missing_reference_entity_values_when_not_use_synonyms(
             self, mocked_get_string_variations):
         # Given
         # pylint: disable=unused-argument
-        def mock_get_string_variations(variation, language,
-                                       builtin_entity_parser):
+        def mock_get_string_variations(
+                variation, language, builtin_entity_parser,
+                number_variations_limit
+        ):
             return {variation}
 
         mocked_get_string_variations.side_effect = mock_get_string_variations
@@ -496,8 +503,10 @@ def test_should_remove_empty_entities_value_and_empty_synonyms(
             self, mocked_get_string_variations):
         # Given
         # pylint: disable=unused-argument
-        def mock_get_string_variations(variation, language,
-                                       builtin_entity_parser):
+        def mock_get_string_variations(
+                variation, language, builtin_entity_parser,
+                number_variations_limit
+        ):
             return {variation, variation.title()}
 
         mocked_get_string_variations.side_effect = mock_get_string_variations
@@ -610,8 +619,10 @@ def test_should_add_capitalize_field(
             self, mocked_get_string_variations):
         # Given
         # pylint: disable=unused-argument
-        def mock_get_string_variations(variation, language,
-                                       builtin_entity_parser):
+        def mock_get_string_variations(
+                variation, language, builtin_entity_parser,
+                number_variations_limit
+        ):
             return {variation, variation.title()}
 
         mocked_get_string_variations.side_effect = mock_get_string_variations
@@ -786,8 +797,10 @@ def test_should_normalize_synonyms(
             self, mocked_get_string_variations):
         # Given
         # pylint: disable=unused-argument
-        def mock_get_string_variations(variation, language,
-                                       builtin_entity_parser):
+        def mock_get_string_variations(
+                variation, language, builtin_entity_parser,
+                number_variations_limit
+        ):
             return {variation.lower(), variation.title()}
 
         mocked_get_string_variations.side_effect = mock_get_string_variations
@@ -861,8 +874,10 @@ def test_dataset_should_handle_synonyms(
            self, mocked_get_string_variations):
         # Given
         # pylint: disable=unused-argument
-        def mock_get_string_variations(variation, language,
-                                       builtin_entity_parser):
+        def mock_get_string_variations(
+                variation, language, builtin_entity_parser,
+                number_variations_limit
+        ):
             return {variation.lower(), variation.title()}
 
         mocked_get_string_variations.side_effect = mock_get_string_variations
@@ -965,3 +980,53 @@ def test_should_not_avoid_synomyms_variations_collision(self):
             "favorïte": "a"
         }
         self.assertDictEqual(expected_utterances, entity["utterances"])
+
+    def test_should_create_number_variation(self):
+        # Given
+        num_values = 1
+        entity = {
+            "matching_strictness": 1.0,
+            "use_synonyms": False,
+            "automatically_extensible": False,
+            "data": [
+                {"value": str(i), "synonyms": []}
+                for i in range(num_values)]
+        }
+        builtin_entity_parser = EntityParserMock(dict())
+
+        # When
+        with patch("snips_nlu.dataset.validation"
+                   ".get_string_variations") as mocked_string_variations:
+            mocked_string_variations.return_value = []
+            _validate_and_format_custom_entity(
+                entity, [], "en", builtin_entity_parser)
+        # Then
+        self.assertGreater(mocked_string_variations.call_count, 0)
+        for call in mocked_string_variations.mock_calls:
+            kwargs = call[2]
+            self.assertIsNone(kwargs["number_variations_limit"])
+
+    def test_should_not_create_number_variation(self):
+        # Given
+        num_values = 10001
+        entity = {
+            "matching_strictness": 1.0,
+            "use_synonyms": False,
+            "automatically_extensible": False,
+            "data": [
+                {"value": str(i), "synonyms": []}
+                for i in range(num_values)]
+        }
+        builtin_entity_parser = EntityParserMock(dict())
+
+        # When
+        with patch("snips_nlu.dataset.validation"
+                   ".get_string_variations") as mocked_string_variations:
+            mocked_string_variations.return_value = []
+            _validate_and_format_custom_entity(
+                entity, [], "en", builtin_entity_parser)
+        # Then
+        self.assertGreater(mocked_string_variations.call_count, 0)
+        for call in mocked_string_variations.mock_calls:
+            kwargs = call[2]
+            self.assertEqual(0, kwargs["number_variations_limit"])
diff --git a/snips_nlu/tests/test_string_variations.py b/snips_nlu/tests/test_string_variations.py
index c07d65cef..6bf8ee87a 100644
--- a/snips_nlu/tests/test_string_variations.py
+++ b/snips_nlu/tests/test_string_variations.py
@@ -1,6 +1,8 @@
 # coding=utf-8
 from __future__ import unicode_literals
 
+from mock import MagicMock
+
 from snips_nlu.constants import (LANGUAGE_EN, LANGUAGE_FR, RES_MATCH_RANGE,
                                  SNIPS_NUMBER, START)
 from snips_nlu.entity_parser import BuiltinEntityParser
@@ -164,3 +166,19 @@ def test_numbers_variations_should_handle_floats(self):
             "7.62 mm caliber two and 6",
         }
         self.assertSetEqual(variations, expected_variations)
+
+    def test_get_string_variations_should_not_generate_number_variations(self):
+        # Given
+        number_variations_limit = 0
+
+        builtin_entity_parser = MagicMock()
+        mocked_parse = MagicMock(return_value=[])
+        builtin_entity_parser.parse = mocked_parse
+
+        # When/Then
+        get_string_variations("", "en", builtin_entity_parser,
+                              number_variations_limit=number_variations_limit)
+        mocked_parse.assert_not_called()
+        get_string_variations(
+            "", "en", builtin_entity_parser, number_variations_limit=None)
+        self.assertGreater(mocked_parse.call_count, 0)
diff --git a/snips_nlu/tests/utils.py b/snips_nlu/tests/utils.py
index e7d91fd04..e2b41a94e 100644
--- a/snips_nlu/tests/utils.py
+++ b/snips_nlu/tests/utils.py
@@ -221,4 +221,4 @@ def from_path(cls, path):
         return cls(entities)
 
     def _parse(self, text, scope=None):
-        return self.entities.get(text)
+        return self.entities.get(text, [])
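
Note for reviewers (not part of the patch): the sketch below is a minimal, self-contained illustration of how the new `NUMBER_VARIATIONS_THRESHOLD` gate is meant to behave. The helper name `pick_number_variations_limit` is hypothetical; in the patch the same logic lives inline in `_validate_and_format_custom_entity` and is consumed by `get_string_variations`.

```python
# Illustrative sketch only -- mirrors the gating logic added by this patch.
# `pick_number_variations_limit` is a hypothetical helper, not a real API.

NUMBER_VARIATIONS_THRESHOLD = 1e3  # same constant as in the patch


def pick_number_variations_limit(entity_data):
    """Return the value passed as `number_variations_limit`.

    None -> number variations are generated as before (small entities)
    0    -> numbers_variations() is skipped, so the builtin entity parser
            never runs on the variations (large entities)
    """
    if len(entity_data) < NUMBER_VARIATIONS_THRESHOLD:
        return None
    return 0


small_entity = [{"value": str(i), "synonyms": []} for i in range(10)]
large_entity = [{"value": str(i), "synonyms": []} for i in range(5000)]

print(pick_number_variations_limit(small_entity))  # None -> variations kept
print(pick_number_variations_limit(large_entity))  # 0    -> variations skipped
```

With `number_variations_limit=0`, the `len(variations) < number_variations_limit` check in `get_string_variations` can never be true, which is why the new tests assert that the mocked builtin entity parser is not called in that case.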