Skip to content

Commit

Permalink
Add get_chat_history method to AsyncSimpleOpenai and SimpleOpenai cla…
Browse files Browse the repository at this point in the history
…sses
  • Loading branch information
schleising committed Aug 10, 2024
1 parent 4d01d10 commit 289c8bd
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/simple_openai/async_simple_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,3 +333,18 @@ async def get_image_url(

# Return the response
return response

def get_chat_history(self, chat_id: str) -> str:
"""Get the chat history
Args:
chat_id (str): The ID of the chat
Returns:
str: The chat history
"""
# Get the chat history
chat_history = self._chat.get_chat(chat_id)

# Return the chat history
return chat_history
24 changes: 24 additions & 0 deletions src/simple_openai/chat_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,27 @@ def add_message(

# Return the chat
return chat

def get_chat(self, chat_id: str = DEFAULT_CHAT_ID) -> str:
"""Get the chat
Args:
chat_id (str, optional): The ID of the chat to get. Defaults to DEFAULT_CHAT_ID.
Returns:
str: The chat
"""
# If the chat ID is not in the messages, create a new deque
if chat_id not in self._messages:
return ""

# Get the chat
chat = self._messages[chat_id]

# Parse the chat to a string with each name and message on a new line
chat_str = "\n".join(
[f"{message.name}: {message.content}" for message in chat]
)

# Return the chat
return chat_str
16 changes: 16 additions & 0 deletions src/simple_openai/simple_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,19 @@ def get_image_url(self, prompt: str, style: str = "vivid") -> SimpleOpenaiRespon

# Return the response
return response


def get_chat_history(self, chat_id: str) -> str:
"""Get the chat history
Args:
chat_id (str): The ID of the chat
Returns:
str: The chat history
"""
# Get the chat history
chat_history = self._chat.get_chat(chat_id)

# Return the chat history
return chat_history
16 changes: 16 additions & 0 deletions tests/test_AsyncSimpleOpenai.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,20 @@ async def load_and_summarise_chat():
print(f"Group 2 Summary: {response.message}")
print()

def test_chat_history():
# Create the client
client = AsyncSimpleOpenai(api_key, "", Path("storage"))

# Test the chat history
chat = client.get_chat_history("Group 1")
print("Group 1 Chat:")
print(chat)
print()

chat = client.get_chat_history("Group 2")
print("Group 2 Chat:")
print(chat)


async def test_functions():
# Create a system message
Expand Down Expand Up @@ -216,6 +230,8 @@ async def main():
# Test functions
await test_functions()

# Test chat history
test_chat_history()

if __name__ == "__main__":
# Run the main function
Expand Down
18 changes: 18 additions & 0 deletions tests/test_SimpleOpenai.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,21 @@ def load_and_summarise_chat():
print()


def test_chat_history():
# Create the client
client = SimpleOpenai(api_key, "", Path("storage"))

# Test the chat history
chat = client.get_chat_history("Group 1")
print("Group 1 Chat:")
print(chat)
print()

chat = client.get_chat_history("Group 2")
print("Group 2 Chat:")
print(chat)


def test_functions():
# Create a system message
system_message = """
Expand Down Expand Up @@ -204,6 +219,9 @@ def main():
# Test functions
test_functions()

# Test chat history
test_chat_history()


if __name__ == "__main__":
# Run the main function
Expand Down

0 comments on commit 289c8bd

Please sign in to comment.