Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename EntityExtractorMixin #11356

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/10225.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Rename `EntityExtractorMixin` to `EntityExtractor` for consistency, as discussed in issue 10225.
VitorLamego marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 2 additions & 2 deletions rasa/core/policies/ted_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from rasa.engine.storage.storage import ModelStorage
from rasa.exceptions import ModelNotFound
from rasa.nlu.constants import TOKENS_NAMES
from rasa.nlu.extractors.extractor import EntityTagSpec, EntityExtractorMixin
from rasa.nlu.extractors.extractor import EntityTagSpec, EntityExtractor
import rasa.core.actions.action
from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
Expand Down Expand Up @@ -900,7 +900,7 @@ def _create_optional_event_for_entities(
else:
parsed_message = Message(data={TEXT: text})
tokens = parsed_message.get(TOKENS_NAMES[TEXT])
entities = EntityExtractorMixin.convert_predictions_into_entities(
entities = EntityExtractor.convert_predictions_into_entities(
text,
tokens,
predicted_tags,
Expand Down
4 changes: 2 additions & 2 deletions rasa/nlu/classifiers/diet_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
from rasa.engine.storage.resource import Resource
from rasa.engine.storage.storage import ModelStorage
from rasa.nlu.extractors.extractor import EntityExtractorMixin
from rasa.nlu.extractors.extractor import EntityExtractor
from rasa.nlu.classifiers.classifier import IntentClassifier
import rasa.shared.utils.io
import rasa.utils.io as io_utils
Expand Down Expand Up @@ -127,7 +127,7 @@
],
is_trainable=True,
)
class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractor):
"""A multi-task model for intent classification and entity extraction.

DIET is Dual Intent and Entity Transformer.
Expand Down
4 changes: 2 additions & 2 deletions rasa/nlu/classifiers/regex_message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
from rasa.engine.storage.resource import Resource
from rasa.engine.storage.storage import ModelStorage
from rasa.nlu.extractors.extractor import EntityExtractorMixin
from rasa.nlu.extractors.extractor import EntityExtractor
from rasa.shared.core.domain import Domain
from rasa.shared.core.training_data.story_reader.yaml_story_reader import (
YAMLStoryReader,
Expand All @@ -19,7 +19,7 @@
@DefaultV1Recipe.register(
DefaultV1Recipe.ComponentType.INTENT_CLASSIFIER, is_trainable=False
)
class RegexMessageHandler(GraphComponent, EntityExtractorMixin):
class RegexMessageHandler(GraphComponent, EntityExtractor):
"""Handles hardcoded NLU predictions from messages starting with a `/`."""

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions rasa/nlu/extractors/crf_entity_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from rasa.engine.storage.storage import ModelStorage
from rasa.nlu.test import determine_token_labels
from rasa.nlu.tokenizers.spacy_tokenizer import POS_TAG_KEY
from rasa.nlu.extractors.extractor import EntityExtractorMixin
from rasa.nlu.extractors.extractor import EntityExtractor
from rasa.nlu.tokenizers.tokenizer import Token, Tokenizer
from rasa.shared.nlu.training_data.training_data import TrainingData
from rasa.shared.nlu.training_data.message import Message
Expand Down Expand Up @@ -85,7 +85,7 @@ class CRFEntityExtractorOptions(str, Enum):
@DefaultV1Recipe.register(
DefaultV1Recipe.ComponentType.ENTITY_EXTRACTOR, is_trainable=True
)
class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
class CRFEntityExtractor(GraphComponent, EntityExtractor):
"""Implements conditional random fields (CRF) to do named entity recognition."""

CONFIG_FEATURES = "features"
Expand Down
4 changes: 2 additions & 2 deletions rasa/nlu/extractors/duckling_entity_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from rasa.engine.storage.storage import ModelStorage
from rasa.shared.constants import DOCS_URL_COMPONENTS
from rasa.shared.nlu.constants import ENTITIES, TEXT
from rasa.nlu.extractors.extractor import EntityExtractorMixin
from rasa.nlu.extractors.extractor import EntityExtractor
from rasa.shared.nlu.training_data.message import Message
import rasa.shared.utils.io

Expand Down Expand Up @@ -58,7 +58,7 @@ def convert_duckling_format_to_rasa(
@DefaultV1Recipe.register(
DefaultV1Recipe.ComponentType.ENTITY_EXTRACTOR, is_trainable=False
)
class DucklingEntityExtractor(GraphComponent, EntityExtractorMixin):
class DucklingEntityExtractor(GraphComponent, EntityExtractor):
"""Searches for structured entities, e.g. dates, using a duckling server."""

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions rasa/nlu/extractors/entity_synonyms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from rasa.shared.nlu.training_data.training_data import TrainingData
from rasa.shared.nlu.training_data.message import Message
from rasa.nlu.utils import write_json_to_file
from rasa.nlu.extractors.extractor import EntityExtractorMixin
from rasa.nlu.extractors.extractor import EntityExtractor
import rasa.utils.io
import rasa.shared.utils.io
from rasa.engine.storage.resource import Resource
Expand All @@ -22,7 +22,7 @@
@DefaultV1Recipe.register(
DefaultV1Recipe.ComponentType.ENTITY_EXTRACTOR, is_trainable=True
)
class EntitySynonymMapper(GraphComponent, EntityExtractorMixin):
class EntitySynonymMapper(GraphComponent, EntityExtractor):
"""Maps entities to their synonyms if they appear in the training data."""

SYNONYM_FILENAME = "synonyms.json"
Expand Down
16 changes: 8 additions & 8 deletions rasa/nlu/extractors/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class EntityTagSpec(NamedTuple):
num_tags: int


class EntityExtractorMixin(abc.ABC):
class EntityExtractor(abc.ABC):
"""Provides functionality for components that do entity extraction.

Inheriting from this class will add utility functions for entity extraction.
Expand Down Expand Up @@ -182,7 +182,7 @@ def convert_predictions_into_entities(
last_token_end = -1

for idx, token in enumerate(tokens):
current_entity_tag = EntityExtractorMixin.get_tag_for(
current_entity_tag = EntityExtractor.get_tag_for(
tags, ENTITY_ATTRIBUTE_TYPE, idx
)

Expand All @@ -191,11 +191,11 @@ def convert_predictions_into_entities(
last_token_end = token.end
continue

current_group_tag = EntityExtractorMixin.get_tag_for(
current_group_tag = EntityExtractor.get_tag_for(
tags, ENTITY_ATTRIBUTE_GROUP, idx
)
current_group_tag = bilou_utils.tag_without_prefix(current_group_tag)
current_role_tag = EntityExtractorMixin.get_tag_for(
current_role_tag = EntityExtractor.get_tag_for(
tags, ENTITY_ATTRIBUTE_ROLE, idx
)
current_role_tag = bilou_utils.tag_without_prefix(current_role_tag)
Expand Down Expand Up @@ -237,7 +237,7 @@ def convert_predictions_into_entities(

if new_tag_found:
# new entity found
entity = EntityExtractorMixin._create_new_entity(
entity = EntityExtractor._create_new_entity(
list(tags.keys()),
current_entity_tag,
current_group_tag,
Expand All @@ -247,7 +247,7 @@ def convert_predictions_into_entities(
confidences,
)
entities.append(entity)
elif EntityExtractorMixin._check_is_single_entity(
elif EntityExtractor._check_is_single_entity(
text, token, last_token_end, split_entities_config, current_entity_tag
):
# current token has the same entity tag as the token before and
Expand All @@ -256,7 +256,7 @@ def convert_predictions_into_entities(
# and a whitespace.
entities[-1][ENTITY_ATTRIBUTE_END] = token.end
if confidences is not None:
EntityExtractorMixin._update_confidence_values(
EntityExtractor._update_confidence_values(
entities, confidences, idx
)

Expand All @@ -265,7 +265,7 @@ def convert_predictions_into_entities(
# tokens are separated by at least 2 symbols (e.g. multiple spaces,
# a comma and a space, etc.) and also shouldn't be represented as a
# single entity
entity = EntityExtractorMixin._create_new_entity(
entity = EntityExtractor._create_new_entity(
list(tags.keys()),
current_entity_tag,
current_group_tag,
Expand Down
4 changes: 2 additions & 2 deletions rasa/nlu/extractors/mitie_entity_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
ENTITIES,
)
from rasa.nlu.utils.mitie_utils import MitieModel, MitieNLP
from rasa.nlu.extractors.extractor import EntityExtractorMixin
from rasa.nlu.extractors.extractor import EntityExtractor
from rasa.shared.nlu.training_data.training_data import TrainingData
from rasa.shared.nlu.training_data.message import Message
import rasa.shared.utils.io
Expand All @@ -36,7 +36,7 @@
is_trainable=True,
model_from="MitieNLP",
)
class MitieEntityExtractor(GraphComponent, EntityExtractorMixin):
class MitieEntityExtractor(GraphComponent, EntityExtractor):
"""A Mitie Entity Extractor (which is a thin wrapper around `Dlib-ml`)."""

MITIE_RESOURCE_FILE = "mitie_ner.dat"
Expand Down
4 changes: 2 additions & 2 deletions rasa/nlu/extractors/regex_entity_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
TEXT,
ENTITY_ATTRIBUTE_TYPE,
)
from rasa.nlu.extractors.extractor import EntityExtractorMixin
from rasa.nlu.extractors.extractor import EntityExtractor

logger = logging.getLogger(__name__)


@DefaultV1Recipe.register(
DefaultV1Recipe.ComponentType.ENTITY_EXTRACTOR, is_trainable=True
)
class RegexEntityExtractor(GraphComponent, EntityExtractorMixin):
class RegexEntityExtractor(GraphComponent, EntityExtractor):
"""Extracts entities via lookup tables and regexes defined in the training data."""

REGEX_FILE_NAME = "regex.json"
Expand Down
4 changes: 2 additions & 2 deletions rasa/nlu/extractors/spacy_entity_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from rasa.engine.storage.storage import ModelStorage
from rasa.shared.nlu.constants import ENTITIES, TEXT
from rasa.nlu.utils.spacy_utils import SpacyModel, SpacyNLP
from rasa.nlu.extractors.extractor import EntityExtractorMixin
from rasa.nlu.extractors.extractor import EntityExtractor
from rasa.shared.nlu.training_data.message import Message

if typing.TYPE_CHECKING:
Expand All @@ -19,7 +19,7 @@
is_trainable=False,
model_from="SpacyNLP",
)
class SpacyEntityExtractor(GraphComponent, EntityExtractorMixin):
class SpacyEntityExtractor(GraphComponent, EntityExtractor):
"""Entity extractor which uses SpaCy."""

@classmethod
Expand Down
8 changes: 4 additions & 4 deletions tests/nlu/extractors/test_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from rasa.shared.nlu.constants import TEXT, SPLIT_ENTITIES_BY_COMMA
from rasa.shared.nlu.training_data.message import Message
from rasa.nlu.extractors.extractor import EntityExtractorMixin
from rasa.nlu.extractors.extractor import EntityExtractor
from rasa.nlu.tokenizers.whitespace_tokenizer import WhitespaceTokenizer
from rasa.shared.nlu.training_data.formats.rasa_yaml import RasaYAMLReader

Expand Down Expand Up @@ -224,7 +224,7 @@ def test_convert_tags_to_entities(
expected_entities: List[Dict[Text, Any]],
whitespace_tokenizer: WhitespaceTokenizer,
):
extractor = EntityExtractorMixin()
extractor = EntityExtractor()

message = Message(data={TEXT: text})
tokens = whitespace_tokenizer.tokenize(message, TEXT)
Expand Down Expand Up @@ -399,7 +399,7 @@ def test_split_entities_by_comma(
expected_entities: List[Dict[Text, Any]],
whitespace_tokenizer: WhitespaceTokenizer,
):
extractor = EntityExtractorMixin()
extractor = EntityExtractor()

message = Message(data={TEXT: text})
tokens = whitespace_tokenizer.tokenize(message, TEXT)
Expand Down Expand Up @@ -473,7 +473,7 @@ def test_check_correct_entity_annotations(
whitespace_tokenizer.process_training_data(training_data)

with pytest.warns(UserWarning) as record:
EntityExtractorMixin.check_correct_entity_annotations(training_data)
EntityExtractor.check_correct_entity_annotations(training_data)

assert len(record) == warnings
assert all(
Expand Down