Skip to content

Commit

Permalink
Merge pull request #7716 from RasaHQ/iss-7707
Browse files Browse the repository at this point in the history
Add `split_entities_config` to TEDPolicy
  • Loading branch information
koernerfelicia authored Jan 25, 2021
2 parents 88d97e1 + 4263639 commit 7352104
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 11 deletions.
2 changes: 2 additions & 0 deletions changelog/7707.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add an option to TEDPolicy to configure whether extracted entities should be split by comma (`","`). Fixes a
crash when this parameter is accessed during extraction.
30 changes: 30 additions & 0 deletions docs/docs/policies.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,27 @@ If you want to fine-tune your model, start by modifying the following parameters
set `weight_sparsity` to 1 as this would result in all kernel weights being 0, i.e. the model is not able
to learn.

* `split_entities_by_comma`:
This parameter defines whether adjacent entities separated by a comma should be treated as one, or split. For example,
entities with the type `ingredients`, like "apple, banana" can be split into "apple" and "banana". An entity with type
`address`, like "Schönhauser Allee 175, 10119 Berlin" should be treated as one.

Can either be
`True`/`False` globally:
```yaml-rasa title="config.yml"
policies:
- name: TEDPolicy
split_entities_by_comma: True
```
or set per entity type, such as:
```yaml-rasa title="config.yml"
policies:
- name: TEDPolicy
split_entities_by_comma:
address: False
ingredients: True
```

The above configuration parameters are the ones you should configure to fit your model to your data.
However, additional parameters exist that can be adapted.

Expand Down Expand Up @@ -320,6 +341,15 @@ However, additional parameters exist that can be adapted.
| entity_recognition | True | If 'True' entity recognition is trained and entities are |
| | | extracted. |
+---------------------------------------+------------------------+--------------------------------------------------------------+
| split_entities_by_comma | True | Splits a list of extracted entities by comma to treat each |
| | | one of them as a single entity. Can either be `True`/`False` |
| | | globally, or set per entity type, such as: |
| | | ``` |
| | | - name: TEDPolicy |
| | | split_entities_by_comma: |
| | | address: False |
| | | ``` |
+---------------------------------------+------------------------+--------------------------------------------------------------+
```
:::note
Expand Down
17 changes: 16 additions & 1 deletion rasa/core/policies/ted_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
ENTITY_ATTRIBUTE_TYPE,
ENTITY_TAGS,
EXTRACTOR,
SPLIT_ENTITIES_BY_COMMA,
SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
)
from rasa.shared.nlu.interpreter import NaturalLanguageInterpreter
from rasa.core.policies.policy import Policy, PolicyPrediction
Expand Down Expand Up @@ -272,6 +274,10 @@ class TEDPolicy(Policy):
FEATURIZERS: [],
# If set to true, entities are predicted in user utterances.
ENTITY_RECOGNITION: True,
# Split entities by comma, this makes sense e.g. for a list of
# ingredients in a recipe, but it doesn't make sense for the parts of
# an address
SPLIT_ENTITIES_BY_COMMA: SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
}

@staticmethod
Expand All @@ -292,6 +298,11 @@ def __init__(
**kwargs: Any,
) -> None:
"""Declare instance variables with default values."""
self.split_entities_config = rasa.utils.train_utils.init_split_entities(
kwargs.get(SPLIT_ENTITIES_BY_COMMA, SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE),
self.defaults[SPLIT_ENTITIES_BY_COMMA],
)

if not featurizer:
featurizer = self._standard_featurizer(max_history)

Expand Down Expand Up @@ -662,7 +673,11 @@ def _create_optional_event_for_entities(
parsed_message = interpreter.featurize_message(Message(data={TEXT: text}))
tokens = parsed_message.get(TOKENS_NAMES[TEXT])
entities = EntityExtractor.convert_predictions_into_entities(
text, tokens, predicted_tags, confidences=confidence_values
text,
tokens,
predicted_tags,
self.split_entities_config,
confidences=confidence_values,
)

# add the extractor name
Expand Down
46 changes: 37 additions & 9 deletions rasa/nlu/extractors/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,36 +28,64 @@
SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
SINGLE_ENTITY_ALLOWED_INTERLEAVING_CHARSET,
)
import rasa.utils.train_utils


class EntityExtractor(Component):
"""Entity extractors are components which extract entities.
They can be placed in the pipeline like other components, and can extract
entities like a person's name, or a location.
"""

def add_extractor_name(
    self, entities: List[Dict[Text, Any]]
) -> List[Dict[Text, Any]]:
    """Tag every entity in the list with the name of this extractor.

    Args:
        entities: the extracted entities.

    Returns:
        the same list; each entity now stores this extractor's name under
        the `EXTRACTOR` key.
    """
    for extracted in entities:
        # Entities are updated in place, so the caller's list is modified.
        extracted.update({EXTRACTOR: self.name})
    return entities

def add_processor_name(self, entity: Dict[Text, Any]) -> Dict[Text, Any]:
    """Record this extractor as one of the processors of an entity.

    Args:
        entity: the extracted entity and its metadata.

    Returns:
        the same entity, with this extractor's name appended to its
        "processors" list (the list is created when it is missing).
    """
    # `setdefault` covers both cases: reuse the existing list or start one.
    entity.setdefault("processors", []).append(self.name)
    return entity

def init_split_entities(self):
"""Initialise the behaviour for splitting entities by comma (or not)."""
def init_split_entities(self) -> Dict[Text, bool]:
"""Initialises the behaviour for splitting entities by comma (or not).
Returns:
Defines desired behaviour for splitting specific entity types and
default behaviour for splitting any entity types for which no
behaviour is defined.
"""
split_entities_config = self.component_config.get(
SPLIT_ENTITIES_BY_COMMA, SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE
)
if isinstance(split_entities_config, bool):
split_entities_config = {SPLIT_ENTITIES_BY_COMMA: split_entities_config}
else:
split_entities_config[SPLIT_ENTITIES_BY_COMMA] = self.defaults[
SPLIT_ENTITIES_BY_COMMA
]
return split_entities_config
default_value = self.defaults.get(
SPLIT_ENTITIES_BY_COMMA, SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE
)
return rasa.utils.train_utils.init_split_entities(
split_entities_config, default_value
)

@staticmethod
def filter_irrelevant_entities(extracted: list, requested_dimensions: set) -> list:
Expand Down
27 changes: 26 additions & 1 deletion rasa/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
NUM_TRANSFORMER_LAYERS,
DENSE_DIMENSION,
)
from rasa.shared.nlu.constants import ACTION_NAME, INTENT, ENTITIES
from rasa.shared.nlu.constants import (
ACTION_NAME,
INTENT,
ENTITIES,
SPLIT_ENTITIES_BY_COMMA,
)
from rasa.shared.core.constants import ACTIVE_LOOP, SLOTS
from rasa.core.constants import DIALOGUE

Expand Down Expand Up @@ -335,3 +340,23 @@ def override_defaults(
config[key] = custom[key]

return config


def init_split_entities(
    split_entities_config, default_split_entity: bool
) -> Dict[Text, bool]:
    """Initialise the behaviour for splitting entities by comma (or not).

    Args:
        split_entities_config: either a single boolean that applies to all
            entity types, or a mapping from entity type to a boolean.
        default_split_entity: behaviour for entity types that are not named
            in `split_entities_config`; only used when a mapping is given.

    Returns:
        Defines desired behaviour for splitting specific entity types and
        default behaviour for splitting any entity types for which no behaviour
        is defined. The default is stored under the `SPLIT_ENTITIES_BY_COMMA`
        key. The passed-in configuration is never modified.
    """
    if isinstance(split_entities_config, bool):
        # All entities will be split according to `split_entities_config`
        return {SPLIT_ENTITIES_BY_COMMA: split_entities_config}

    # Work on a copy so the caller's configuration (e.g. the policy or
    # component config that is later persisted) does not gain an extra key.
    initialized_config = dict(split_entities_config)
    # All entities not named in `split_entities_config` will be split
    # according to `default_split_entity`
    initialized_config[SPLIT_ENTITIES_BY_COMMA] = default_split_entity
    return initialized_config
35 changes: 35 additions & 0 deletions tests/utils/test_train_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from typing import Any, Dict

import numpy as np
import pytest

import rasa.utils.train_utils as train_utils
from rasa.nlu.constants import NUMBER_OF_SUB_TOKENS
from rasa.nlu.tokenizers.tokenizer import Token
from rasa.shared.nlu.constants import (
SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
SPLIT_ENTITIES_BY_COMMA,
)


def test_align_token_features():
Expand All @@ -26,3 +33,31 @@ def test_align_token_features():
assert np.all(actual_features[0][3] == np.mean(token_features[0][3:5], axis=0))
# embedding is split into 4 sub-tokens
assert np.all(actual_features[0][4] == np.mean(token_features[0][5:10], axis=0))


@pytest.mark.parametrize(
    "split_entities_config, expected_initialized_config",
    [
        # A global boolean equal to the default value.
        (
            SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
            {SPLIT_ENTITIES_BY_COMMA: SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE},
        ),
        # A global boolean that differs from the default: proves the
        # configured value wins over the default that is passed alongside it.
        (
            not SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
            {SPLIT_ENTITIES_BY_COMMA: not SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE},
        ),
        # A per-entity mapping: the default is added for unnamed entities.
        (
            {"address": False, "ingredients": True},
            {
                "address": False,
                "ingredients": True,
                SPLIT_ENTITIES_BY_COMMA: SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE,
            },
        ),
    ],
)
def test_init_split_entities_config(
    split_entities_config: Any, expected_initialized_config: Dict[str, bool],
):
    """Check that `init_split_entities` normalises bool and dict configs."""
    assert (
        train_utils.init_split_entities(
            split_entities_config, SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE
        )
        == expected_initialized_config
    )

0 comments on commit 7352104

Please sign in to comment.