Skip to content

Commit 9e48bbb

Browse files
Luke Alvoeirolily-de
Luke Alvoeiro
authored andcommitted
fix: resuming sessions (#35)
1 parent b98d1d9 commit 9e48bbb

File tree

3 files changed

+55
-4
lines changed

3 files changed

+55
-4
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "goose-ai"
33
description = "a programming agent that runs on your machine"
4-
version = "0.8.4"
4+
version = "0.8.5"
55
readme = "README.md"
66
requires-python = ">=3.10"
77
dependencies = [

src/goose/cli/session.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pathlib import Path
33
from typing import Any, Dict, List, Optional
44

5-
from exchange import Message, ToolResult, ToolUse
5+
from exchange import Message, ToolResult, ToolUse, Text
66
from prompt_toolkit.shortcuts import confirm
77
from rich import print
88
from rich.console import RenderableType
@@ -24,6 +24,8 @@
2424
from goose.utils import droid, load_plugins
2525
from goose.utils.session_file import read_from_file, write_to_file
2626

27+
RESUME_MESSAGE = "I see we were interrupted. How can I help you?"
28+
2729

2830
def load_provider() -> str:
2931
# We try to infer a provider, by going in order of what will auth
@@ -91,8 +93,22 @@ def __init__(
9193

9294
if name is not None and self.session_file_path.exists():
9395
messages = self.load_session()
96+
9497
if messages and messages[-1].role == "user":
98+
if type(messages[-1].content[-1]) is Text:
99+
# remove the last user message
100+
messages.pop()
101+
elif type(messages[-1].content[-1]) is ToolResult:
102+
# if we remove this message, we would need to remove
103+
# the previous assistant message as well. instead of doing
104+
# that, we just add a new assistant message to prompt the user
105+
messages.append(Message.assistant(RESUME_MESSAGE))
106+
if messages and type(messages[-1].content[-1]) is ToolUse:
107+
# remove the last request for a tool to be used
95108
messages.pop()
109+
110+
# add a new assistant text message to prompt the user
111+
messages.append(Message.assistant(RESUME_MESSAGE))
96112
self.exchange.messages.extend(messages)
97113

98114
if len(self.exchange.messages) == 0 and plan:

tests/cli/test_session.py

+37-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from unittest.mock import MagicMock, patch
22

33
import pytest
4-
from exchange import Message
4+
from exchange import Message, ToolUse, ToolResult
55
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
66
from goose.cli.prompt.user_input import PromptAction, UserInput
77
from goose.cli.session import Session
@@ -32,7 +32,7 @@ def create_session(session_attributes: dict = {}):
3232
yield create_session
3333

3434

35-
def test_session_does_not_extend_last_user_message_on_init(
35+
def test_session_does_not_extend_last_user_text_message_on_init(
3636
create_session_with_mock_configs, mock_sessions_path, create_session_file
3737
):
3838
messages = [Message.user("Hello"), Message.assistant("Hi"), Message.user("Last should be removed")]
@@ -44,6 +44,41 @@ def test_session_does_not_extend_last_user_message_on_init(
4444
assert [message.text for message in session.exchange.messages] == ["Hello", "Hi"]
4545

4646

47+
def test_session_adds_resume_message_if_last_message_is_tool_result(
48+
create_session_with_mock_configs, mock_sessions_path, create_session_file
49+
):
50+
messages = [
51+
Message.user("Hello"),
52+
Message(role="assistant", content=[ToolUse(id="1", name="first_tool", parameters={})]),
53+
Message(role="user", content=[ToolResult(tool_use_id="1", output="output")]),
54+
]
55+
create_session_file(messages, mock_sessions_path / f"{SESSION_NAME}.jsonl")
56+
57+
session = create_session_with_mock_configs({"name": SESSION_NAME})
58+
print("Messages after session init:", session.exchange.messages) # Debugging line
59+
assert len(session.exchange.messages) == 4
60+
assert session.exchange.messages[-1].role == "assistant"
61+
assert session.exchange.messages[-1].text == "I see we were interrupted. How can I help you?"
62+
63+
64+
def test_session_removes_tool_use_and_adds_resume_message_if_last_message_is_tool_use(
65+
create_session_with_mock_configs, mock_sessions_path, create_session_file
66+
):
67+
messages = [
68+
Message.user("Hello"),
69+
Message(role="assistant", content=[ToolUse(id="1", name="first_tool", parameters={})]),
70+
]
71+
create_session_file(messages, mock_sessions_path / f"{SESSION_NAME}.jsonl")
72+
73+
session = create_session_with_mock_configs({"name": SESSION_NAME})
74+
print("Messages after session init:", session.exchange.messages) # Debugging line
75+
assert len(session.exchange.messages) == 2
76+
assert [message.text for message in session.exchange.messages] == [
77+
"Hello",
78+
"I see we were interrupted. How can I help you?",
79+
]
80+
81+
4782
def test_save_session_create_session(mock_sessions_path, create_session_with_mock_configs, mock_specified_session_name):
4883
session = create_session_with_mock_configs()
4984
session.exchange.messages.append(Message.assistant("Hello"))

0 commit comments

Comments
 (0)