Skip to content

Commit

Permalink
Adapt message store api to include ids
Browse files Browse the repository at this point in the history
  • Loading branch information
ggozad committed Sep 5, 2024
1 parent 1ada561 commit e5ea429
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 20 deletions.
6 changes: 6 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
Changelog
=========

0.5.1 -
------------------

- Add (id) column to message table.
[ggozad]

0.5.0 - 2024-09-04
------------------

Expand Down
6 changes: 4 additions & 2 deletions src/oterm/app/chat_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ async def on_submit(self, event: Input.Submitted) -> None:
if not event.value:
return

messages: Sequence[tuple[Author, str]] = await store.get_messages(self.chat_id)
messages: Sequence[tuple[int, Author, str]] = await store.get_messages(
self.chat_id
)
with open(event.value, "w", encoding="utf-8") as file:
for message in messages:
author, text = message
_, author, text = message
file.write(f"*{author.value}*\n")
file.write(f"{text}\n")
file.write("\n---\n")
Expand Down
23 changes: 13 additions & 10 deletions src/oterm/app/widgets/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

class ChatContainer(Widget):
ollama = OllamaLLM()
messages: reactive[list[tuple[Author, str]]] = reactive([])
messages: reactive[list[tuple[int, Author, str]]] = reactive([])
chat_name: str
system: str | None
format: Literal["", "json"]
Expand All @@ -52,7 +52,7 @@ def __init__(
db_id: int,
chat_name: str,
model: str = "llama3.1",
messages: list[tuple[Author, str]] = [],
messages: list[tuple[int, Author, str]] = [],
system: str | None = None,
format: Literal["", "json"] = "",
parameters: Options,
Expand All @@ -67,7 +67,7 @@ def __init__(
if author == Author.USER
else {"role": "assistant", "content": message}
)
for author, message in messages
for _, author, message in messages
]

self.ollama = OllamaLLM(
Expand Down Expand Up @@ -95,7 +95,7 @@ async def load_messages(self) -> None:
if self.loaded:
return
message_container = self.query_one("#messageContainer")
for author, message in self.messages:
for _, author, message in self.messages:
chat_item = ChatItem()
chat_item.text = message
chat_item.author = author
Expand All @@ -116,7 +116,6 @@ async def on_submit(self, event: FlexibleInput.Submitted) -> None:

async def response_task() -> None:
input.clear()
self.messages.append((Author.USER, message))
user_chat_item = ChatItem()
user_chat_item.text = message
user_chat_item.author = Author.USER
Expand All @@ -138,21 +137,25 @@ async def response_task() -> None:
response_chat_item.text = text
if message_container.can_view(response_chat_item):
message_container.scroll_end()
self.messages.append((Author.OLLAMA, response))
self.images = []

# Save to db
store = await Store.get_store()
await store.save_message( # type: ignore
id = await store.save_message(
id=None,
chat_id=self.db_id,
author=Author.USER.value,
text=message,
)
await store.save_message( # type: ignore
self.messages.append((id, Author.USER, message))

id = await store.save_message(
id=None,
chat_id=self.db_id,
author=Author.OLLAMA.value,
text=response,
)
self.messages.append((id, Author.OLLAMA, response))
except asyncio.CancelledError:
user_chat_item.remove()
response_chat_item.remove()
Expand Down Expand Up @@ -203,7 +206,7 @@ async def action_edit_chat(self) -> None:
if author == Author.USER
else {"role": "assistant", "content": message}
)
for author, message in self.messages
for _, author, message in self.messages
]

self.ollama = OllamaLLM(
Expand Down Expand Up @@ -238,7 +241,7 @@ def on_history_selected(text: str | None) -> None:
prompt.focus()

prompts = [
message for author, message in self.messages if author == Author.USER
message for _, author, message in self.messages if author == Author.USER
]
prompts.reverse()
screen = PromptHistory(prompts)
Expand Down
6 changes: 3 additions & 3 deletions src/oterm/store/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
-- name: delete_chat
DELETE FROM chat WHERE id = :id;
-- name: save_message
INSERT INTO message(chat_id, author, text)
VALUES(:chat_id, :author, :text);
INSERT OR REPLACE INTO message(id, chat_id, author, text)
VALUES(:id, :chat_id, :author, :text) RETURNING id;
-- name: get_messages
SELECT author, text FROM message WHERE chat_id = :chat_id;
SELECT id, author, text FROM message WHERE chat_id = :chat_id;
"""

queries = aiosql.from_str(chat_sqlite, "aiosqlite")
12 changes: 8 additions & 4 deletions src/oterm/store/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,23 @@ async def delete_chat(self, id: int) -> None:
await chat_queries.delete_chat(connection, id=id) # type: ignore
await connection.commit()

async def save_message(self, chat_id: int, author: str, text: str) -> None:
async def save_message(
self, id: int | None, chat_id: int, author: str, text: str
) -> int:
async with aiosqlite.connect(self.db_path) as connection:
await chat_queries.save_message( # type: ignore
res = await chat_queries.save_message( # type: ignore
connection,
id=id,
chat_id=chat_id,
author=author,
text=text,
)
await connection.commit()
return res[0][0]

async def get_messages(self, chat_id: int) -> list[tuple[Author, str]]:
async def get_messages(self, chat_id: int) -> list[tuple[int, Author, str]]:

async with aiosqlite.connect(self.db_path) as connection:
messages = await chat_queries.get_messages(connection, chat_id=chat_id) # type: ignore
messages = [(Author(author), text) for author, text in messages]
messages = [(id, Author(author), text) for id, author, text in messages]
return messages
1 change: 0 additions & 1 deletion src/oterm/store/upgrades/v0_5_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@


async def add_id_to_messages(db_path: Path) -> None:
print("KAKAKAKAAKAK")
async with aiosqlite.connect(db_path) as connection:
try:
await connection.executescript(
Expand Down

0 comments on commit e5ea429

Please sign in to comment.