Skip to content

Commit

Permalink
Merge pull request #659 from snipsco/hotfix/slot-filler-fitting
Browse files Browse the repository at this point in the history
Hotfix/slot filler fitting
  • Loading branch information
ClemDoum authored Sep 6, 2018
2 parents 4fa6b93 + 909749e commit 802acf5
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 9 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Changelog
All notable changes to this project will be documented in this file.


## [0.16.5] - 2018-0906
### Fixed
- Segfault in CRFSuite when the `CRFSlotFiller` is fitted only on empty utterances

## [0.16.4] - 2018-08-30
### Fixed
- Issue with the `CrfSlotFiller` file names in the `ProbabilisticIntentParser` serialization
Expand Down Expand Up @@ -135,6 +140,7 @@ several commands.
- Fix compiling issue with `bindgen` dependency when installing from source
- Fix issue in `CRFSlotFiller` when handling builtin entities

[0.16.5]: https://github.com/snipsco/snips-nlu/compare/0.16.4...0.16.5
[0.16.4]: https://github.com/snipsco/snips-nlu/compare/0.16.3...0.16.4
[0.16.3]: https://github.com/snipsco/snips-nlu/compare/0.16.2...0.16.3
[0.16.2]: https://github.com/snipsco/snips-nlu/compare/0.16.1...0.16.2
Expand Down
2 changes: 1 addition & 1 deletion snips_nlu/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
__email__ = "[email protected], [email protected]"
__license__ = "Apache License, Version 2.0"

__version__ = "0.16.4"
__version__ = "0.16.5"
__model_version__ = "0.16.0"

__download_url__ = "https://github.com/snipsco/snips-nlu-language-resources/releases/download"
Expand Down
27 changes: 25 additions & 2 deletions snips_nlu/slot_filler/crf_slot_filler.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,17 @@ def fit(self, dataset, intent):
for factory in self.features_factories:
factory.fit(dataset, intent)

# Ensure that X, Y are safe and that the OUTSIDE label is learnt to
# avoid segfault at inference time
# pylint: disable=C0103
X = [self.compute_features(sample[TOKENS], drop_out=True)
for sample in crf_samples]
Y = [[tag for tag in sample[TAGS]] for sample in crf_samples]
X, Y = _ensure_safe(X, Y)

# ensure ascii tags
Y = [[_encode_tag(tag) for tag in sample[TAGS]]
for sample in crf_samples]
Y = [[_encode_tag(tag) for tag in y] for y in Y]

# pylint: enable=C0103
self.crf_model = _get_crf_model(self.config.crf_args)
self.crf_model.fit(X, Y)
Expand Down Expand Up @@ -494,3 +499,21 @@ def _crf_model_from_path(crf_model_path):
f.flush()
crf = CRF(model_filename=f.name)
return crf

# pylint: disable=invalid-name
def _ensure_safe(X, Y):
"""Ensure that Y has at least one not empty label, otherwise the CRF model
does not contain any label and crashes at
Args:
X: features
Y: labels
Returns: (safe_X, safe_Y) a pair of safe features and labels
"""
safe_X = list(X)
safe_Y = list(Y)
if not any(X) or not any(Y):
safe_X.append([""]) # empty feature
safe_Y.append([OUTSIDE]) # outside label
return safe_X, safe_Y
52 changes: 46 additions & 6 deletions snips_nlu/tests/test_crf_slot_filler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from __future__ import unicode_literals

from builtins import range
from pathlib import Path

from mock import MagicMock
from pathlib import Path
from sklearn_crfsuite import CRF

from snips_nlu.constants import (
DATA, END, ENTITY, ENTITY_KIND, LANGUAGE_EN, RES_MATCH_RANGE, SLOT_NAME,
Expand All @@ -13,10 +13,12 @@
from snips_nlu.pipeline.configs import CRFSlotFillerConfig
from snips_nlu.preprocessing import Token, tokenize
from snips_nlu.result import unresolved_slot
from snips_nlu.slot_filler.crf_slot_filler import (
CRFSlotFiller, _disambiguate_builtin_entities,
_filter_overlapping_builtins, _get_slots_permutations,
_spans_to_tokens_indexes)
from snips_nlu.slot_filler.crf_slot_filler import (CRFSlotFiller,
_disambiguate_builtin_entities,
_ensure_safe,
_filter_overlapping_builtins,
_get_slots_permutations,
_spans_to_tokens_indexes)
from snips_nlu.slot_filler.crf_utils import (
BEGINNING_PREFIX, INSIDE_PREFIX, TaggingScheme)
from snips_nlu.slot_filler.feature_factory import (
Expand Down Expand Up @@ -721,3 +723,41 @@ def test_generate_slots_permutations(self):
"O||O",
}
self.assertSetEqual(expected_permutations, slots_permutations)

def test_should_fit_and_parse_empty_intent(self):
# Given
dataset = {
"intents": {
"dummy_intent": {
"utterances": [
{
"data": [
{
"text": " "
}
]
}
]
}
},
"language": "en",
"entities": dict()
}

slot_filler = CRFSlotFiller()

# When
slot_filler.fit(dataset, "dummy_intent")
slot_filler.get_slots("ya")

def test___ensure_safe(self):
unsafe_examples = [
([[]], [[]]),
([[], []], [[], []]),
]

# We don't assert anything here but it segfault otherwise
for X, Y in unsafe_examples:
X, Y = _ensure_safe(X, Y)
model = CRF().fit(X, Y)
model.predict_single([""])
61 changes: 61 additions & 0 deletions snips_nlu/tests/test_nlu_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,67 @@ def test_nlu_engine_should_raise_error_with_bytes_input(self):
self.assertTrue("Expected unicode but received" in message)


def test_should_fit_and_parse_empty_intent(self):
# Given
dataset = {
"intents": {
"dummy_intent": {
"utterances": [
{
"data": [
{
"text": " "
}
]
}
]
}
},
"language": "en",
"entities": dict()
}

engine = SnipsNLUEngine()

# When / Then
engine.fit(dataset)
engine.parse("ya", intents=["dummy_intent"])

def test_should_fit_and_parse_empty_intent_with_empty_slot(self):
dataset = {
"intents": {
"dummy_intent": {
"utterances": [
{
"data": [
{
"text": " ",
"slot_name": "dummy_slot",
"entity": "dummy_entity"
}
]
}
],
}
},
"entities": {
"dummy_entity": {
"use_synonyms": True,
"automatically_extensible": True,
"parser_threshold": 1.0,
"data": [
{
"value": " ",
"synonyms": []
}
]
}
},
"language": "en",
}



class TestIntentParser1Config(ProcessingUnitConfig):
unit_name = "test_intent_parser1"

Expand Down

0 comments on commit 802acf5

Please sign in to comment.