Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add guards to session management #101

Merged
merged 17 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/goose/cli/prompt/overwrite_session_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Any

from rich.prompt import Prompt


class OverwriteSessionPrompt(Prompt):
def __init__(self, *args: tuple[Any], **kwargs: dict[str, Any]) -> None:
super().__init__(*args, **kwargs)
self.choices = {
"yes": "Overwrite the existing session",
"no": "Pick a new session name",
"resume": "Resume the existing session",
}
self.default = "resume"

def check_choice(self, choice: str) -> bool:
for key in self.choices:
normalized_choice = choice.lower()
if normalized_choice == key or normalized_choice[0] == key[0]:
return True
return False

def pre_prompt(self) -> str:
print("Would you like to overwrite it?")
print()
for key, value in self.choices.items():
first_letter, remaining = key[0], key[1:]
rendered_key = rf"[{first_letter}]{remaining}"
print(f" {rendered_key:10} {value}")
print()
71 changes: 59 additions & 12 deletions src/goose/cli/session.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
import traceback
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Optional

from exchange import Message, ToolResult, ToolUse, Text
from exchange import Message, Text, ToolResult, ToolUse
from rich import print
from rich.markdown import Markdown
from rich.panel import Panel
from rich.prompt import Prompt
from rich.status import Status

from goose.cli.config import ensure_config, session_path, LOG_PATH
from goose._logger import get_logger, setup_logging
from goose.cli.config import LOG_PATH, ensure_config, session_path
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.cli.session_notifier import SessionNotifier
from goose.cli.prompt.overwrite_session_prompt import OverwriteSessionPrompt
from goose.notifier import Notifier
from goose.profile import Profile
from goose.utils import droid, load_plugins
from goose.utils._cost_calculator import get_total_cost_message
from goose.utils._create_exchange import create_exchange
from goose.utils.session_file import read_or_create_file, save_latest_session
from goose.utils.session_file import is_empty_session, is_existing_session, read_or_create_file, save_latest_session

RESUME_MESSAGE = "I see we were interrupted. How can I help you?"

Expand Down Expand Up @@ -60,7 +62,7 @@ def __init__(
profile: Optional[str] = None,
plan: Optional[dict] = None,
log_level: Optional[str] = "INFO",
**kwargs: Dict[str, Any],
**kwargs: dict[str, Any],
) -> None:
if name is None:
self.name = droid()
Expand All @@ -69,7 +71,7 @@ def __init__(
self.profile_name = profile
self.prompt_session = GoosePromptSession()
self.status_indicator = Status("", spinner="dots")
self.notifier = SessionNotifier(self.status_indicator)
self.notifier = Notifier(self.status_indicator)

self.exchange = create_exchange(profile=load_profile(profile), notifier=self.notifier)
setup_logging(log_file_directory=LOG_PATH, log_level=log_level)
Expand All @@ -81,7 +83,7 @@ def __init__(

self.prompt_session = GoosePromptSession()

def _get_initial_messages(self) -> List[Message]:
def _get_initial_messages(self) -> list[Message]:
messages = self.load_session()

if messages and messages[-1].role == "user":
Expand Down Expand Up @@ -151,8 +153,11 @@ def run(self) -> None:
Runs the main loop to handle user inputs and responses.
Continues until an empty string is returned from the prompt.
"""
print(f"[dim]starting session | name:[cyan]{self.name}[/] profile:[cyan]{self.profile_name or 'default'}[/]")
print(f"[dim]saving to {self.session_file_path}")
if is_existing_session(self.session_file_path):
self._prompt_overwrite_session()

profile_name = self.profile_name or "default"
print(f"[dim]starting session | name: [cyan]{self.name}[/cyan] profile: [cyan]{profile_name}[/cyan][/dim]")
print()
message = self.process_first_message()
while message: # Loop until no input (empty string).
Expand All @@ -178,6 +183,7 @@ def run(self) -> None:
user_input = self.prompt_session.get_user_input()
message = Message.user(text=user_input.text) if user_input.to_continue() else None

self._remove_empty_session()
self._log_cost()

def reply(self) -> None:
Expand Down Expand Up @@ -234,12 +240,53 @@ def interrupt_reply(self) -> None:
def session_file_path(self) -> Path:
return session_path(self.name)

def load_session(self) -> List[Message]:
def load_session(self) -> list[Message]:
return read_or_create_file(self.session_file_path)

def _log_cost(self) -> None:
get_logger().info(get_total_cost_message(self.exchange.get_token_usage()))
print(f"[dim]you can view the cost and token usage in the log directory {LOG_PATH}")
print(f"[dim]you can view the cost and token usage in the log directory {LOG_PATH}[/]")

def _prompt_overwrite_session(self) -> None:
print(f"[yellow]Session already exists at {self.session_file_path}.[/]")

choice = OverwriteSessionPrompt.ask("Enter your choice", show_choices=False)
match choice:
case "y" | "yes":
print("Overwriting existing session")

case "n" | "no":
while True:
new_session_name = Prompt.ask("Enter a new session name")
if not is_existing_session(session_path(new_session_name)):
self.name = new_session_name
break
print(f"[yellow]Session '{new_session_name}' already exists[/]")

case "r" | "resume":
self.exchange.messages.extend(self.load_session())

def _remove_empty_session(self) -> bool:
"""
Removes the session file only when it's empty.

Note: This is because a session file is created at the start of the run
loop. When a user aborts before their first message empty session files
will be created, causing confusion when resuming sessions (which
depends on most recent mtime and is non-empty).

Returns:
bool: True if the session file was removed, False otherwise.
"""
logger = get_logger()
try:
if is_empty_session(self.session_file_path):
logger.debug(f"deleting empty session file: {self.session_file_path}")
self.session_file_path.unlink()
return True
except Exception as e:
logger.error(f"error deleting empty session file: {e}")
return False


if __name__ == "__main__":
Expand Down
10 changes: 9 additions & 1 deletion src/goose/utils/session_file.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import json
import os
from pathlib import Path
import tempfile
from pathlib import Path
from typing import Dict, Iterator, List

from exchange import Message

from goose.cli.config import SESSION_FILE_SUFFIX


def is_existing_session(path: Path) -> bool:
return path.is_file() and path.stat().st_size > 0


def is_empty_session(path: Path) -> bool:
return path.is_file() and path.stat().st_size == 0


def write_to_file(file_path: Path, messages: List[Message]) -> None:
with open(file_path, "w") as f:
_write_messages_to_file(f, messages)
Expand Down
86 changes: 65 additions & 21 deletions tests/cli/test_session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
from unittest.mock import MagicMock, patch

import pytest
from exchange import Exchange, Message, ToolUse, ToolResult
from exchange import Exchange, Message, ToolResult, ToolUse
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.cli.prompt.user_input import PromptAction, UserInput
from goose.cli.session import Session
Expand All @@ -22,7 +23,7 @@ def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profi
with (
patch("goose.cli.session.create_exchange") as mock_exchange,
patch("goose.cli.session.load_profile", return_value=profile_factory()),
patch("goose.cli.session.SessionNotifier") as mock_session_notifier,
patch("goose.cli.session.Notifier") as mock_session_notifier,
patch("goose.cli.session.load_provider", return_value="provider"),
):
mock_session_notifier.return_value = MagicMock()
Expand Down Expand Up @@ -123,36 +124,79 @@ def test_log_log_cost(create_session_with_mock_configs):
mock_logger.info.assert_called_once_with(cost_message)


def test_run_should_auto_save_session(create_session_with_mock_configs, mock_sessions_path):
@patch.object(GoosePromptSession, "get_user_input", name="get_user_input")
@patch.object(Exchange, "generate", name="mock_generate")
@patch("goose.cli.session.save_latest_session", name="mock_save_latest_session")
def test_run_should_auto_save_session(
mock_save_latest_session,
mock_generate,
mock_get_user_input,
create_session_with_mock_configs,
mock_sessions_path,
):
def custom_exchange_generate(self, *args, **kwargs):
message = Message.assistant("Response")
self.add(message)
return message

def mock_generate_side_effect(*args, **kwargs):
return custom_exchange_generate(session.exchange, *args, **kwargs)

def save_latest_session(file, messages):
file.write_text("\n".join(json.dumps(m.to_dict()) for m in messages))

user_inputs = [
UserInput(action=PromptAction.CONTINUE, text="Question1"),
UserInput(action=PromptAction.CONTINUE, text="Question2"),
UserInput(action=PromptAction.EXIT),
]

session = create_session_with_mock_configs({"name": SESSION_NAME})
with (
patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs),
patch.object(Exchange, "generate") as mock_generate,
patch("goose.cli.session.save_latest_session") as mock_save_latest_session,
):
mock_generate.side_effect = lambda *args, **kwargs: custom_exchange_generate(session.exchange, *args, **kwargs)
session.run()

session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl"
assert session.exchange.generate.call_count == 2
assert mock_save_latest_session.call_count == 2
assert mock_save_latest_session.call_args_list[0][0][0] == session_file
assert session_file.exists()
mock_get_user_input.side_effect = user_inputs
mock_generate.side_effect = mock_generate_side_effect
mock_save_latest_session.side_effect = save_latest_session

session.run()

session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl"

assert mock_generate.call_count == 2
assert mock_save_latest_session.call_count == 2
assert mock_save_latest_session.call_args_list[0][0][0] == session_file
assert session_file.exists()

with open(session_file, "r") as f:
saved_messages = [json.loads(line) for line in f]

expected_messages = [
Message.user("Question1"),
Message.assistant("Response"),
Message.user("Question2"),
Message.assistant("Response"),
]

assert len(saved_messages) == len(expected_messages)
for saved, expected in zip(saved_messages, expected_messages):
assert saved["role"] == expected.role
assert saved["content"][0]["text"] == expected.text


@patch("goose.cli.session.droid", return_value="generated_session_name", name="mock_droid")
def test_set_generated_session_name(mock_droid, create_session_with_mock_configs, mock_sessions_path):
session = create_session_with_mock_configs({"name": None})
assert session.name == "generated_session_name"


@patch("goose.cli.session.is_existing_session", name="mock_is_existing")
@patch("goose.cli.session.Session._prompt_overwrite_session", name="mock_prompt")
def test_existing_session_prompt(mock_prompt, mock_is_existing, create_session_with_mock_configs):
session = create_session_with_mock_configs({"name": SESSION_NAME})

mock_is_existing.return_value = True
session.run()
mock_prompt.assert_called_once()

def test_set_generated_session_name(create_session_with_mock_configs, mock_sessions_path):
generated_session_name = "generated_session_name"
with patch("goose.cli.session.droid", return_value=generated_session_name):
session = create_session_with_mock_configs({"name": None})
assert session.name == generated_session_name
mock_prompt.reset_mock()
mock_is_existing.return_value = False
session.run()
mock_prompt.assert_not_called()
21 changes: 21 additions & 0 deletions tests/utils/test_session_file.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
from pathlib import Path
from unittest.mock import patch

import pytest
from exchange import Message
from goose.utils.session_file import (
is_empty_session,
list_sorted_session_files,
read_from_file,
read_or_create_file,
Expand Down Expand Up @@ -115,3 +117,22 @@ def create_session_file(file_path, file_name) -> Path:
file = file_path / f"{file_name}.jsonl"
file.touch()
return file


@patch("pathlib.Path.is_file", return_value=True, name="mock_is_file")
@patch("pathlib.Path.stat", name="mock_stat")
def test_is_empty_session(mock_stat, mock_is_file):
mock_stat.return_value.st_size = 0
assert is_empty_session(Path("empty_file.json"))


@patch("pathlib.Path.is_file", return_value=True, name="mock_is_file")
@patch("pathlib.Path.stat", name="mock_stat")
def test_is_not_empty_session(mock_stat, mock_is_file):
mock_stat.return_value.st_size = 100
assert not is_empty_session(Path("non_empty_file.json"))


@patch("pathlib.Path.is_file", return_value=False, name="mock_is_file")
def test_is_not_empty_session_file_not_found(mock_is_file):
assert not is_empty_session(Path("file_not_found.json"))