Skip to content

Commit

Permalink
Load gazetteer lazily in NgramFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
adrienball committed Mar 13, 2019
1 parent 7886f8d commit 1379636
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 5 deletions.
15 changes: 11 additions & 4 deletions snips_nlu/slot_filler/feature_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def __init__(self, factory_config, **shared):
self.use_stemming = self.args["use_stemming"]
self.common_words_gazetteer_name = self.args[
"common_words_gazetteer_name"]
self.gazetteer = None
self._gazetteer = None
self._language = None
self.language = self.args.get("language_code")

Expand All @@ -226,9 +226,16 @@ def language(self, value):
if value is not None:
self._language = value
self.args["language_code"] = self.language
if self.common_words_gazetteer_name is not None:
self.gazetteer = get_gazetteer(
self.resources, self.common_words_gazetteer_name)

@property
def gazetteer(self):
    """Return the common-words gazetteer, fetching it on first access.

    Returns None when ``common_words_gazetteer_name`` is None. Otherwise
    the value returned by ``get_gazetteer`` is stored in
    ``self._gazetteer`` and reused by later accesses.
    """
    gazetteer_name = self.common_words_gazetteer_name
    if gazetteer_name is None:
        return None
    cached = self._gazetteer
    if cached is None:
        # First access: fetch from the resources and remember the result
        cached = get_gazetteer(self.resources, gazetteer_name)
        self._gazetteer = cached
    return cached

@property
def feature_name(self):
Expand Down
40 changes: 39 additions & 1 deletion snips_nlu/tests/test_crf_slot_filler.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,44 @@ def test_should_not_log_weights_when_not_fitted(self):
with self.assertRaises(NotTrained):
slot_filler.log_weights()

def test_refit(self):
    """Check that an already fitted CRF slot filler can be fitted again.

    The slot filler uses an ngram feature factory configured with a
    common-words gazetteer. It is fitted on a one-utterance dataset, then
    fitted a second time on a two-utterance variant. There is no explicit
    assertion: the test passes as long as the second ``fit`` does not raise.
    """
    # Given
    intent_name = "my_intent"

    first_yaml = io.StringIO("""
---
type: intent
name: my_intent
utterances:
- this is [entity1](my first entity)""")
    first_dataset = Dataset.from_yaml_files("en", [first_yaml]).json

    second_yaml = io.StringIO("""
---
type: intent
name: my_intent
utterances:
- this is [entity1](my first entity)
- this is [entity1](my first entity) again""")
    second_dataset = Dataset.from_yaml_files("en", [second_yaml]).json

    ngram_factory_config = {
        "args": {
            "common_words_gazetteer_name": "top_10000_words_stemmed",
            "use_stemming": True,
            "n": 1
        },
        "factory_name": "ngram",
        "offsets": [-2, -1, 0, 1, 2]
    }
    slot_filler_config = CRFSlotFillerConfig(
        feature_factory_configs=[ngram_factory_config])

    # When
    fitted_slot_filler = CRFSlotFiller(slot_filler_config).fit(
        first_dataset, intent_name)

    # Then
    fitted_slot_filler.fit(second_dataset, intent_name)

def test_should_fit_with_naughty_strings_no_tags(self):
# Given
naughty_strings_path = TEST_PATH / "resources" / "naughty_strings.txt"
Expand Down Expand Up @@ -755,7 +793,7 @@ def test_should_fit_and_parse_empty_intent(self):
slot_filler.fit(dataset, "dummy_intent")
slot_filler.get_slots("ya")

def test___ensure_safe(self):
def test_ensure_safe(self):
unsafe_examples = [
([[]], [[]]),
([[], []], [[], []]),
Expand Down

0 comments on commit 1379636

Please sign in to comment.