Skip to content

Commit

Permalink
Avoid generating number variations when not needed
Browse files Browse the repository at this point in the history
  • Loading branch information
ClemDoum committed May 10, 2019
1 parent 760c278 commit 1b1b27c
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 28 deletions.
26 changes: 20 additions & 6 deletions snips_nlu/dataset/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
from snips_nlu.exceptions import DatasetFormatError
from snips_nlu.preprocessing import tokenize_light
from snips_nlu.string_variations import get_string_variations
from snips_nlu.common.dataset_utils import validate_type, validate_key, \
validate_keys
from snips_nlu.common.dataset_utils import (
validate_type, validate_key, validate_keys)

NUMBER_VARIATIONS_THRESHOLD = 1e3


def validate_and_format_dataset(dataset):
Expand Down Expand Up @@ -168,6 +170,12 @@ def _validate_and_format_custom_entity(entity, queries_entities, language,
if s and s not in validated_utterances:
validated_utterances[s] = ent_value

# Number variations in entities values are expensive since each entity
# value is parsed with the builtin entity parser before creating the
# variations. We avoid generating these variations if there's enough entity
# values
number_variations = len(entity[DATA]) < NUMBER_VARIATIONS_THRESHOLD

# Add variations if not colliding
all_original_values = _extract_entity_values(entity)
variations = dict()
Expand All @@ -178,10 +186,13 @@ def _validate_and_format_custom_entity(entity, queries_entities, language,
values_to_variate.update(set(data[SYNONYMS]))
variations[ent_value] = set(
v for value in values_to_variate
for v in get_string_variations(value, language,
builtin_entity_parser))
for v in get_string_variations(
value, language, builtin_entity_parser,
number_variations=number_variations
)
)
variation_counter = Counter(
[v for vars in itervalues(variations) for v in vars])
[v for variations_ in itervalues(variations) for v in variations_])
non_colliding_variations = {
value: [
v for v in variations if
Expand All @@ -197,7 +208,10 @@ def _validate_and_format_custom_entity(entity, queries_entities, language,

# Merge queries entities
queries_entities_variations = {
ent: get_string_variations(ent, language, builtin_entity_parser)
ent: get_string_variations(
ent, language, builtin_entity_parser,
number_variations=number_variations
)
for ent in queries_entities
}
for original_ent, variations in iteritems(queries_entities_variations):
Expand Down
14 changes: 10 additions & 4 deletions snips_nlu/string_variations.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ def flatten(results):
return set(i for r in results for i in r)


def get_string_variations(string, language, builtin_entity_parser):
def get_string_variations(string, language, builtin_entity_parser,
number_variations=True):
variations = {string}
variations.update(flatten(case_variations(v) for v in variations))
variations.update(flatten(normalization_variations(v) for v in variations))
Expand All @@ -165,9 +166,14 @@ def get_string_variations(string, language, builtin_entity_parser):
variations.update(flatten(and_variations(v, language) for v in variations))
variations.update(
flatten(punctuation_variations(v, language) for v in variations))
variations.update(
flatten(numbers_variations(v, language, builtin_entity_parser)
for v in variations))

# Special case of number variation which are long to generate due to the
# BuilinEntityParser running on each variation
if number_variations:
variations.update(
flatten(numbers_variations(v, language, builtin_entity_parser)
for v in variations)
)
# Add single space variations
single_space_variations = set(" ".join(v.split()) for v in variations)
variations.update(single_space_variations)
Expand Down
99 changes: 82 additions & 17 deletions snips_nlu/tests/test_dataset_validation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# coding=utf-8
from __future__ import unicode_literals

from builtins import str
from builtins import range, str

from mock import mock
from mock import mock, patch

from snips_nlu.constants import ENTITIES, SNIPS_DATETIME
from snips_nlu.dataset import validate_and_format_dataset
from snips_nlu.dataset.validation import _validate_and_format_custom_entity
from snips_nlu.exceptions import DatasetFormatError
from snips_nlu.tests.utils import SnipsTest
from snips_nlu.tests.utils import SnipsTest, EntityParserMock


class TestDatasetValidation(SnipsTest):
Expand Down Expand Up @@ -189,8 +190,10 @@ def test_should_format_dataset_by_adding_synonyms(
self, mocked_get_string_variations):
# Given
# pylint: disable=unused-argument
def mock_get_string_variations(variation, language,
builtin_entity_parser):
def mock_get_string_variations(
variation, language, builtin_entity_parser,
number_variations
):
return {variation.lower(), variation.title()}

mocked_get_string_variations.side_effect = mock_get_string_variations
Expand Down Expand Up @@ -242,8 +245,10 @@ def test_should_format_dataset_by_adding_entity_values(
self, mocked_get_string_variations):
# Given
# pylint: disable=unused-argument
def mock_get_string_variations(variation, language,
builtin_entity_parser):
def mock_get_string_variations(
variation, language, builtin_entity_parser,
number_variations
):
return {variation, variation.title()}

mocked_get_string_variations.side_effect = mock_get_string_variations
Expand Down Expand Up @@ -355,8 +360,10 @@ def test_should_add_missing_reference_entity_values_when_not_use_synonyms(
self, mocked_get_string_variations):
# Given
# pylint: disable=unused-argument
def mock_get_string_variations(variation, language,
builtin_entity_parser):
def mock_get_string_variations(
variation, language, builtin_entity_parser,
number_variations
):
return {variation}

mocked_get_string_variations.side_effect = mock_get_string_variations
Expand Down Expand Up @@ -496,8 +503,10 @@ def test_should_remove_empty_entities_value_and_empty_synonyms(
self, mocked_get_string_variations):
# Given
# pylint: disable=unused-argument
def mock_get_string_variations(variation, language,
builtin_entity_parser):
def mock_get_string_variations(
variation, language, builtin_entity_parser,
number_variations
):
return {variation, variation.title()}

mocked_get_string_variations.side_effect = mock_get_string_variations
Expand Down Expand Up @@ -610,8 +619,10 @@ def test_should_add_capitalize_field(
self, mocked_get_string_variations):
# Given
# pylint: disable=unused-argument
def mock_get_string_variations(variation, language,
builtin_entity_parser):
def mock_get_string_variations(
variation, language, builtin_entity_parser,
number_variations
):
return {variation, variation.title()}

mocked_get_string_variations.side_effect = mock_get_string_variations
Expand Down Expand Up @@ -786,8 +797,10 @@ def test_should_normalize_synonyms(
self, mocked_get_string_variations):
# Given
# pylint: disable=unused-argument
def mock_get_string_variations(variation, language,
builtin_entity_parser):
def mock_get_string_variations(
variation, language, builtin_entity_parser,
number_variations
):
return {variation.lower(), variation.title()}

mocked_get_string_variations.side_effect = mock_get_string_variations
Expand Down Expand Up @@ -861,8 +874,10 @@ def test_dataset_should_handle_synonyms(
self, mocked_get_string_variations):
# Given
# pylint: disable=unused-argument
def mock_get_string_variations(variation, language,
builtin_entity_parser):
def mock_get_string_variations(
variation, language, builtin_entity_parser,
number_variations
):
return {variation.lower(), variation.title()}

mocked_get_string_variations.side_effect = mock_get_string_variations
Expand Down Expand Up @@ -965,3 +980,53 @@ def test_should_not_avoid_synomyms_variations_collision(self):
"favorïte": "a"
}
self.assertDictEqual(expected_utterances, entity["utterances"])

def test_should_create_number_variation(self):
# Given
num_values = 1
entity = {
"matching_strictness": 1.0,
"use_synonyms": False,
"automatically_extensible": False,
"data": [
{"value": str(i), "synonyms": []}
for i in range(num_values)]
}
builtin_entity_parser = EntityParserMock(dict())

# When
with patch("snips_nlu.dataset.validation"
".get_string_variations") as mocked_string_variations:
mocked_string_variations.return_value = []
_validate_and_format_custom_entity(
entity, [], "en", builtin_entity_parser)
# Then
self.assertGreater(mocked_string_variations.call_count, 0)
for call in mocked_string_variations.mock_calls:
kwargs = call[2]
self.assertTrue(kwargs["number_variations"])

def test_should_not_create_number_variation(self):
# Given
num_values = 1001
entity = {
"matching_strictness": 1.0,
"use_synonyms": False,
"automatically_extensible": False,
"data": [
{"value": str(i), "synonyms": []}
for i in range(num_values)]
}
builtin_entity_parser = EntityParserMock(dict())

# When
with patch("snips_nlu.dataset.validation"
".get_string_variations") as mocked_string_variations:
mocked_string_variations.return_value = []
_validate_and_format_custom_entity(
entity, [], "en", builtin_entity_parser)
# Then
self.assertGreater(mocked_string_variations.call_count, 0)
for call in mocked_string_variations.mock_calls:
kwargs = call[2]
self.assertFalse(kwargs["number_variations"])
16 changes: 16 additions & 0 deletions snips_nlu/tests/test_string_variations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# coding=utf-8
from __future__ import unicode_literals

from mock import MagicMock

from snips_nlu.constants import (LANGUAGE_EN, LANGUAGE_FR, RES_MATCH_RANGE,
SNIPS_NUMBER, START)
from snips_nlu.entity_parser import BuiltinEntityParser
Expand Down Expand Up @@ -164,3 +166,17 @@ def test_numbers_variations_should_handle_floats(self):
"7.62 mm caliber two and 6",
}
self.assertSetEqual(variations, expected_variations)

def test_get_string_variations_should_not_generate_number_variations(self):
# Given
builtin_entity_parser = MagicMock()
mocked_parse = MagicMock(return_value=[])
builtin_entity_parser.parse = mocked_parse

# When/Then
get_string_variations("", "en", builtin_entity_parser,
number_variations=False)
mocked_parse.assert_not_called()
get_string_variations(
"", "en", builtin_entity_parser, number_variations=True)
self.assertGreater(mocked_parse.call_count, 0)
2 changes: 1 addition & 1 deletion snips_nlu/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,4 +221,4 @@ def from_path(cls, path):
return cls(entities)

def _parse(self, text, scope=None):
return self.entities.get(text)
return self.entities.get(text, [])

0 comments on commit 1b1b27c

Please sign in to comment.