From e5bf6a52b6c7eb6a6c1a7c72b16036787fcad491 Mon Sep 17 00:00:00 2001 From: IKostric Date: Fri, 8 Sep 2023 16:21:38 +0200 Subject: [PATCH 1/2] Add dialogue state --- .../dialogue_state_tracker.py | 52 ++++++++++++ .../test_dialogue_state_tracker.py | 82 +++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100644 dialoguekit/dialogue_manager/dialogue_state_tracker.py create mode 100644 tests/dialogue_manager/test_dialogue_state_tracker.py diff --git a/dialoguekit/dialogue_manager/dialogue_state_tracker.py b/dialoguekit/dialogue_manager/dialogue_state_tracker.py new file mode 100644 index 00000000..2eebec30 --- /dev/null +++ b/dialoguekit/dialogue_manager/dialogue_state_tracker.py @@ -0,0 +1,52 @@ +"""A module for tracking the state of a dialogue.""" + +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, List + +from dialoguekit.core.annotated_utterance import AnnotatedUtterance +from dialoguekit.core.annotation import Annotation +from dialoguekit.participant.participant import DialogueParticipant + + +@dataclass +class DialogueState: + """A class to represent the state of a dialogue""" + + history: List[AnnotatedUtterance] = field(default_factory=list) + last_user_intent: str = None + slots: Dict[str, List[Annotation]] = field( + default_factory=lambda: defaultdict(list) + ) + turn_count: int = 0 + + +class DialogueStateTracker: + def __init__(self) -> None: + """Initializes the dialogue state tracker""" + self._dialogue_state = DialogueState() + + def get_state(self) -> DialogueState: + """Returns the current state of the dialogue. + + Returns: + The current state of the dialogue. + """ + return self._dialogue_state + + def update(self, annotated_utterance: AnnotatedUtterance) -> None: + """Updates the dialogue state with the annotated utterance. + + Args: + annotated_utterance: The annotated utterance. + """ + self._dialogue_state.history.append(annotated_utterance) + if annotated_utterance.participant is not DialogueParticipant.USER: + return + + self._dialogue_state.last_user_intent = annotated_utterance.intent + + for annotation in annotated_utterance.annotations: + self._dialogue_state.slots[annotation.slot].append(annotation) + + self._dialogue_state.turn_count += 1 diff --git a/tests/dialogue_manager/test_dialogue_state_tracker.py b/tests/dialogue_manager/test_dialogue_state_tracker.py new file mode 100644 index 00000000..9bc046a7 --- /dev/null +++ b/tests/dialogue_manager/test_dialogue_state_tracker.py @@ -0,0 +1,82 @@ +"""Tests for the DialogueStateTracker class.""" + + +import pytest + +from dialoguekit.core.intent import Intent +from dialoguekit.dialogue_manager.dialogue_state_tracker import ( + AnnotatedUtterance, + Annotation, + DialogueParticipant, + DialogueStateTracker, +) + + +@pytest.fixture +def annotated_utterance(): + return AnnotatedUtterance( + "Hello", + DialogueParticipant.USER, + intent=Intent("greeting"), + annotations=[Annotation("name", "John")], + ) + + +def test_initial_state(): + tracker = DialogueStateTracker() + state = tracker.get_state() + assert state.history == [] + assert state.last_user_intent is None + assert state.slots == {} + assert state.turn_count == 0 + + +def test_agent_participant(annotated_utterance: AnnotatedUtterance): + tracker = DialogueStateTracker() + agent_utterance = AnnotatedUtterance( + "Hi, how can I assist you?", + DialogueParticipant.AGENT, + intent=Intent("offer_help"), + annotations=[], + ) + + tracker.update(annotated_utterance) + assert tracker.get_state().last_user_intent == Intent("greeting") + assert tracker.get_state().turn_count == 1 + + tracker.update(agent_utterance) + assert tracker.get_state().last_user_intent == Intent("greeting") + assert tracker.get_state().turn_count == 1 + assert tracker.get_state().history[-1] == agent_utterance + + +def test_update_history(annotated_utterance: AnnotatedUtterance): + tracker = DialogueStateTracker() + tracker.update(annotated_utterance) + assert tracker.get_state().history == [annotated_utterance] + + +def test_update_intent(annotated_utterance: AnnotatedUtterance): + tracker = DialogueStateTracker() + tracker.update(annotated_utterance) + assert tracker.get_state().last_user_intent == Intent("greeting") + + +def test_update_slots(annotated_utterance: AnnotatedUtterance): + tracker = DialogueStateTracker() + tracker.update(annotated_utterance) + assert tracker.get_state().slots == {"name": [Annotation("name", "John")]} + + +def test_turn_count(annotated_utterance: AnnotatedUtterance): + tracker = DialogueStateTracker() + + annotated_utterance_2 = AnnotatedUtterance( + "How are you?", + DialogueParticipant.USER, + Intent("ask_health"), + annotations=[], + ) + tracker.update(annotated_utterance) + tracker.update(annotated_utterance_2) + assert tracker.get_state().turn_count == 2 From 364b52ebce216811fc1c6ce5ba765ef0eda4c194 Mon Sep 17 00:00:00 2001 From: IKostric Date: Tue, 12 Sep 2023 13:05:58 +0200 Subject: [PATCH 2/2] Fix pydocstyle --- .../dialogue_state_tracker.py | 4 +- .../test_dialogue_state_tracker.py | 46 ++++++++++++++++--- 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/dialoguekit/dialogue_manager/dialogue_state_tracker.py b/dialoguekit/dialogue_manager/dialogue_state_tracker.py index 2eebec30..e6063b8d 100644 --- a/dialoguekit/dialogue_manager/dialogue_state_tracker.py +++ b/dialoguekit/dialogue_manager/dialogue_state_tracker.py @@ -11,7 +11,7 @@ @dataclass class DialogueState: - """A class to represent the state of a dialogue""" + """A class to represent the state of a dialogue.""" history: List[AnnotatedUtterance] = field(default_factory=list) last_user_intent: str = None @@ -23,7 +23,7 @@ class DialogueState: class DialogueStateTracker: def __init__(self) -> None: - """Initializes the dialogue state tracker""" + """Initializes the dialogue state tracker.""" self._dialogue_state = DialogueState() def get_state(self) -> DialogueState: diff --git a/tests/dialogue_manager/test_dialogue_state_tracker.py b/tests/dialogue_manager/test_dialogue_state_tracker.py index 9bc046a7..3e6f9459 100644 --- a/tests/dialogue_manager/test_dialogue_state_tracker.py +++ b/tests/dialogue_manager/test_dialogue_state_tracker.py @@ -13,7 +13,8 @@ @pytest.fixture -def annotated_utterance(): +def annotated_utterance() -> AnnotatedUtterance: + """Return an annotated utterance.""" return AnnotatedUtterance( "Hello", DialogueParticipant.USER, @@ -22,7 +23,8 @@ def annotated_utterance(): ) -def test_initial_state(): +def test_initial_state() -> None: + """Test that the initial state is correct.""" tracker = DialogueStateTracker() state = tracker.get_state() assert state.history == [] @@ -31,7 +33,13 @@ def test_initial_state(): assert state.turn_count == 0 -def test_agent_participant(annotated_utterance: AnnotatedUtterance): +def test_agent_participant(annotated_utterance: AnnotatedUtterance) -> None: + """Test that the agent participant is updated when the user utterance + contains annotations. + + Args: + annotated_utterance: Annotated utterance. + """ tracker = DialogueStateTracker() agent_utterance = AnnotatedUtterance( "Hi, how can I assist you?", @@ -50,25 +58,49 @@ def test_agent_participant(annotated_utterance: AnnotatedUtterance): assert tracker.get_state().history[-1] == agent_utterance -def test_update_history(annotated_utterance: AnnotatedUtterance): +def test_update_history(annotated_utterance: AnnotatedUtterance) -> None: + """Test that the history is updated when the user utterance contains + annotations. + + Args: + annotated_utterance: Annotated utterance. + """ tracker = DialogueStateTracker() tracker.update(annotated_utterance) assert tracker.get_state().history == [annotated_utterance] -def test_update_intent(annotated_utterance: AnnotatedUtterance): +def test_update_intent(annotated_utterance: AnnotatedUtterance) -> None: + """Test that the last user intent is updated when the user utterance + contains annotations. + + Args: + annotated_utterance: Annotated utterance. + """ tracker = DialogueStateTracker() tracker.update(annotated_utterance) assert tracker.get_state().last_user_intent == Intent("greeting") -def test_update_slots(annotated_utterance: AnnotatedUtterance): +def test_update_slots(annotated_utterance: AnnotatedUtterance) -> None: + """Test that the slots are updated when the user utterance contains + annotations. + + Args: + annotated_utterance: Annotated utterance. + """ tracker = DialogueStateTracker() tracker.update(annotated_utterance) assert tracker.get_state().slots == {"name": [Annotation("name", "John")]} -def test_turn_count(annotated_utterance: AnnotatedUtterance): +def test_turn_count(annotated_utterance: AnnotatedUtterance) -> None: + """Test that the turn count is incremented when the user and agent have + both acted. + + Args: + annotated_utterance: Annotated utterance. + """ tracker = DialogueStateTracker() annotated_utterance_2 = AnnotatedUtterance(