Skip to content

Commit

Permalink
moving unbounded buffered chat to a different file
Browse files Browse the repository at this point in the history
  • Loading branch information
aditya.kurniawan committed Dec 14, 2024
1 parent e31e100 commit 1942071
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 37 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from ._buffered_chat_completion_context import (
BufferedChatCompletionContext,
UnboundedBufferedChatCompletionContext,
)
from ._buffered_chat_completion_context import BufferedChatCompletionContext
from ._chat_completion_context import ChatCompletionContext
from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
from ._unbounded_buffered_chat_completion_context import (
UnboundedBufferedChatCompletionContext,
)

__all__ = [
"ChatCompletionContext",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,36 +43,3 @@ def save_state(self) -> Mapping[str, Any]:
def load_state(self, state: Mapping[str, Any]) -> None:
self._messages = state["messages"]
self._buffer_size = state["buffer_size"]


class UnboundedBufferedChatCompletionContext(ChatCompletionContext):
"""A buffered chat completion context that keeps a view of the last n messages,
where n is the buffer size. The buffer size is set at initialization.
Args:
buffer_size (int): The size of the buffer.
"""

def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
self._messages: List[LLMMessage] = initial_messages or []

async def add_message(self, message: LLMMessage) -> None:
"""Add a message to the memory."""
self._messages.append(message)

async def get_messages(self) -> List[LLMMessage]:
"""Get at most `buffer_size` recent messages."""
return self._messages

async def clear(self) -> None:
"""Clear the message memory."""
self._messages = []

def save_state(self) -> Mapping[str, Any]:
return {
"messages": [message for message in self._messages],
}

def load_state(self, state: Mapping[str, Any]) -> None:
self._messages = state["messages"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Any, List, Mapping

from ..models import LLMMessage
from ._chat_completion_context import ChatCompletionContext


class UnboundedBufferedChatCompletionContext(ChatCompletionContext):
"""An unbounded buffered chat completion context that keeps a view of the all the messages."""

def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
self._messages: List[LLMMessage] = initial_messages or []

async def add_message(self, message: LLMMessage) -> None:
"""Add a message to the memory."""
self._messages.append(message)

async def get_messages(self) -> List[LLMMessage]:
"""Get at most `buffer_size` recent messages."""
return self._messages

async def clear(self) -> None:
"""Clear the message memory."""
self._messages = []

def save_state(self) -> Mapping[str, Any]:
return {
"messages": [message for message in self._messages],
}

def load_state(self, state: Mapping[str, Any]) -> None:
self._messages = state["messages"]

0 comments on commit 1942071

Please sign in to comment.