Skip to content

Commit

Permalink
Merge pull request #16 from Vernacular-ai/feat/#13
Browse files Browse the repository at this point in the history
feat: #13
  • Loading branch information
ltbringer authored Apr 13, 2021
2 parents fcb5ba4 + 181a80e commit ac7333b
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 22 deletions.
29 changes: 23 additions & 6 deletions dialogy/types/intent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,24 @@ def apply(self, rules: Rule) -> "Intent":
if not rule:
return self

for slot_name, entity_type in rule.items():
self.slots[slot_name] = Slot(name=slot_name, type_=[entity_type], values=[])
for slot_name, entity_types in rule.items():
if isinstance(entity_types, str):
entity_type = entity_types
self.slots[slot_name] = Slot(
name=slot_name, types=[entity_type], values=[]
)
elif isinstance(entity_types, list) and all(
isinstance(type_, str) for type_ in entity_types
):
self.slots[slot_name] = Slot(
name=slot_name, types=entity_types, values=[]
)
else:
raise TypeError(
f"Expected entity_types={entity_types} in the rule"
f" {rule} to be a List[str] but {type(entity_types)} was found."
)

return self

def add_parser(self, postprocessor: PluginFn) -> "Intent":
Expand All @@ -95,16 +111,17 @@ def fill_slot(self, entity: BaseEntity, fill_multiple: bool = False) -> "Intent"
entity (BaseEntity): [entities](../../docs/entity/__init__.html)
"""
log.debug("Looping through slot_names for each entity.")
for slot_name in entity.slot_names:
log.debug("intent slots: %s", self.slots)
for slot_name, slot in self.slots.items():
log.debug("slot_name: %s", slot_name)
log.debug("intent slots: %s", self.slots)
if slot_name in self.slots:
log.debug("slot type: %s", slot.types)
if entity.type in slot.types:
if fill_multiple:
self.slots[slot_name].add(entity)
return self

if not self.slots[slot_name].values:
log.debug("filling %s into %s", entity, self.name)
log.debug("filling %s into %s.", entity, self.name)
self.slots[slot_name].add(entity)
else:
log.debug(
Expand Down
8 changes: 4 additions & 4 deletions dialogy/types/slots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class Slot:
- `values` list of entities extracted
"""

def __init__(self, name: str, type_: List[str], values: List[BaseEntity]) -> None:
def __init__(self, name: str, types: List[str], values: List[BaseEntity]) -> None:
self.name = name
self.type = type_
self.types = types
self.values = values

def add(self, entity: BaseEntity) -> "Slot":
Expand All @@ -52,10 +52,10 @@ def json(self) -> Dict[str, Any]:
entities_json = [entity.json() for entity in self.values]
slot_json = {
"name": self.name,
"type": self.type,
"type": self.types,
const.EntityKeys.VALUES: entities_json,
}
return slot_json


Rule = Dict[str, Dict[str, str]]
Rule = Dict[str, Dict[str, Any]]
10 changes: 0 additions & 10 deletions tests/postprocess/text/slot_filler/test_rule_slot_filler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def access(workflow: Workflow) -> Any:
dim="default",
type="entity_1",
values=[{"key": "value"}],
slot_names=["entity_1_slot"],
)

# The RuleBasedSlotFillerPlugin specifies that it expects `Tuple[Intent, List[Entity])` on `access(workflow)`.
Expand Down Expand Up @@ -93,7 +92,6 @@ def access(workflow: Workflow) -> Any:
dim="default",
type="entity_2",
values=[{"key": "value"}],
slot_names=["entity_2_slot"],
)

# The RuleBasedSlotFillerPlugin specifies that it expects `Tuple[Intent, List[Entity])` on `access(workflow)`.
Expand Down Expand Up @@ -133,7 +131,6 @@ def access(workflow: Workflow) -> Any:
dim="default",
type="entity_1",
values=[{"key": "value"}],
slot_names=["entity_1_slot"],
)

entity_2 = BaseEntity(
Expand All @@ -142,7 +139,6 @@ def access(workflow: Workflow) -> Any:
dim="default",
type="entity_2",
values=[{"key": "value"}],
slot_names=["entity_2_slot"],
)

# The RuleBasedSlotFillerPlugin specifies that it expects `Tuple[Intent, List[Entity])` on `access(workflow)`.
Expand Down Expand Up @@ -185,7 +181,6 @@ def access(workflow: Workflow) -> Any:
dim="default",
type="entity_1",
values=[{"key": "value"}],
slot_names=["entity_1_slot"],
)

entity_2 = BaseEntity(
Expand All @@ -194,7 +189,6 @@ def access(workflow: Workflow) -> Any:
dim="default",
type="entity_2",
values=[{"key": "value"}],
slot_names=["entity_2_slot"],
)

# The RuleBasedSlotFillerPlugin specifies that it expects `Tuple[Intent, List[Entity])` on `access(workflow)`.
Expand Down Expand Up @@ -234,7 +228,6 @@ def access(workflow: Workflow) -> Any:
dim="default",
type="entity_1",
values=[{"key": "value_1"}],
slot_names=["entity_1_slot"],
)

entity_2 = BaseEntity(
Expand All @@ -243,7 +236,6 @@ def access(workflow: Workflow) -> Any:
dim="default",
type="entity_1",
values=[{"key": "value_2"}],
slot_names=["entity_1_slot"],
)

# The RuleBasedSlotFillerPlugin specifies that it expects `Tuple[Intent, List[Entity])` on `access(workflow)`.
Expand Down Expand Up @@ -274,7 +266,6 @@ def test_incorrect_access_fn() -> None:
dim="default",
type="basic",
values=[{"key": "value"}],
slot_names=["basic_slot"],
)

workflow.output = (intent, [entity])
Expand All @@ -298,7 +289,6 @@ def test_missing_access_fn() -> None:
dim="default",
type="basic",
values=[{"key": "value"}],
slot_names=["basic_slot"],
)

workflow.output = (intent, [entity])
Expand Down
40 changes: 38 additions & 2 deletions tests/types/intents/test_intents.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Tests for intents
"""
import pytest

from dialogy.types.entity import BaseEntity
from dialogy.types.intent import Intent
Expand Down Expand Up @@ -43,8 +44,8 @@ def test_rule_application() -> None:
assert "date_slot" in intent.slots, "date_slot should be present."
assert "number_slot" in intent.slots, "number_slot should be present."

assert intent.slots["date_slot"].type == ["date"]
assert intent.slots["number_slot"].type == ["number"]
assert intent.slots["date_slot"].types == ["date"]
assert intent.slots["number_slot"].types == ["number"]


def test_missing_rule() -> None:
Expand Down Expand Up @@ -114,3 +115,38 @@ def test_slot_filling() -> None:

intent_json = intent.json()
assert "dim" not in intent_json["slots"]["basic_slot"]["values"][0]


def test_rule_with_multiple_types() -> None:
ordinal_entity = BaseEntity(
range={"from": 0, "to": 15},
body="12th december",
dim="default",
type="ordinal",
values=[{"key": "12th"}],
slot_names=["basic_slot"],
)
number_entity = BaseEntity(
range={"from": 0, "to": 15},
body="12 december",
dim="default",
type="number",
values=[{"key": "12"}],
slot_names=["basic_slot"],
)
rules = {"intent": {"basic_slot": ["ordinal", "number"]}}
intent = Intent(name="intent", score=0.8)
intent.apply(rules)
intent.fill_slot(number_entity, fill_multiple=True)
intent.fill_slot(ordinal_entity, fill_multiple=True)

assert intent.slots["basic_slot"].values[0] == number_entity
assert intent.slots["basic_slot"].values[1] == ordinal_entity


def test_invalid_rule() -> None:
rules = {"intent": {"basic_slot": [12]}}
intent = Intent(name="intent", score=0.8)

with pytest.raises(TypeError):
intent.apply(rules)

0 comments on commit ac7333b

Please sign in to comment.