Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: enhance shell() to know when it is interactive #66

Merged
merged 17 commits into from
Sep 24, 2024
105 changes: 87 additions & 18 deletions src/goose/toolkit/developer.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
from pathlib import Path
from subprocess import CompletedProcess, run
from typing import List, Dict
import os
from goose.utils.check_shell_command import is_dangerous_command
import re
import subprocess
import time
from pathlib import Path
from typing import Dict, List

from exchange import Message
from goose.toolkit.base import Toolkit, tool
from goose.toolkit.utils import get_language, render_template
from goose.utils.ask import ask_an_ai
from goose.utils.check_shell_command import is_dangerous_command
from rich import box
from rich.markdown import Markdown
from rich.panel import Panel
from rich.prompt import Confirm
from rich.table import Table
from rich.text import Text

from goose.toolkit.base import Toolkit, tool
from goose.toolkit.utils import get_language, render_template


def keep_unsafe_command_prompt(command: str) -> bool:
command_text = Text(command, style="bold red")
Expand Down Expand Up @@ -136,7 +138,7 @@ def read_file(self, path: str) -> str:
@tool
def shell(self, command: str) -> str:
"""
Execute a command on the shell (in OSX)
Execute a command on the shell

This will return the output and error concatenated into a single string, as
you would see from running on the command line. There will also be an indication
Expand All @@ -146,11 +148,7 @@ def shell(self, command: str) -> str:
command (str): The shell command to run. It can support multiline statements
if you need to run more than one at a time
"""
self.notifier.status("planning to run shell command")
# Log the command being executed in a visually structured format (Markdown).
# The `.log` method is used here to log the command execution in the application's UX
# this method is dynamically attached to functions in the Goose framework to handle user-visible
# logging and integrates with the overall UI logging system
self.notifier.log(Panel.fit(Markdown(f"```bash\n{command}\n```"), title="shell"))

if is_dangerous_command(command):
Expand All @@ -159,16 +157,87 @@ def shell(self, command: str) -> str:
if not keep_unsafe_command_prompt(command):
raise RuntimeError(
f"The command {command} was rejected as dangerous by the user."
+ " Do not proceed further, instead ask for instructions."
" Do not proceed further, instead ask for instructions."
)
self.notifier.start()
self.notifier.status("running shell command")
result: CompletedProcess = run(command, shell=True, text=True, capture_output=True, check=False)
if result.returncode == 0:
output = "Command succeeded"

# Define patterns that might indicate the process is waiting for input
interaction_patterns = [
r"Do you want to", # Common prompt phrase
r"Enter password", # Password prompt
r"Are you sure", # Confirmation prompt
r"\(y/N\)", # Yes/No prompt
r"Press any key to continue", # Awaiting keypress
r"Waiting for input", # General waiting message
r"\?\s", # Prompts starting with '? '
]
compiled_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in interaction_patterns]

proc = subprocess.Popen(
command,
shell=True,
stdin=subprocess.DEVNULL,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
# this enables us to read lines without blocking
os.set_blocking(proc.stdout.fileno(), False)

# Accumulate the output logs while checking if it might be blocked
output_lines = []
last_output_time = time.time()
cutoff = 10
while proc.poll() is None:
self.notifier.status("running shell command")
line = proc.stdout.readline()
if line:
output_lines.append(line)
last_output_time = time.time()

# If we see a clear pattern match, we plan to abort
exit_criteria = any(pattern.search(line) for pattern in compiled_patterns)

# and if we haven't seen a new line in 10+s, check with AI to see if it may be stuck
if not exit_criteria and time.time() - last_output_time > cutoff:
self.notifier.status("checking on shell status")
response = ask_an_ai(
input="\n".join([command] + output_lines),
prompt=(
"You will evaluate the output of shell commands to see if they may be stuck."
" Look for commands that appear to be awaiting user input, or otherwise running indefinitely (such as a web service)." # noqa
" A command that will take a while, such as downloading resources is okay." # noqa
" return [Yes] if stuck, [No] otherwise."
),
exchange=self.exchange_view.processor,
with_tools=False,
)
exit_criteria = "[yes]" in response.content[0].text.lower()
# We add exponential backoff for how often we check for the command being stuck
cutoff *= 10

if exit_criteria:
proc.terminate()
raise ValueError(
f"The command `{command}` looks like it will run indefinitely or is otherwise stuck."
f"You may be able to specify inputs if it applies to this command."
f"Otherwise to enable continued iteration, you'll need to ask the user to run this command in another terminal." # noqa
)

# read any remaining lines
while line := proc.stdout.readline():
output_lines.append(line)
output = "".join(output_lines)

# Determine the result based on the return code
if proc.returncode == 0:
result = "Command succeeded"
else:
output = f"Command failed with returncode {result.returncode}"
return "\n".join([output, result.stdout, result.stderr])
result = f"Command failed with returncode {proc.returncode}"

# Return the combined result and outputs if we made it this far
return "\n".join([result, output])

@tool
def write_file(self, path: str, content: str) -> str:
Expand Down
11 changes: 10 additions & 1 deletion src/goose/utils/ask.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from exchange import Exchange, Message, CheckpointData


def ask_an_ai(input: str, exchange: Exchange, prompt: str = "", no_history: bool = True) -> Message:
def ask_an_ai(
input: str,
exchange: Exchange,
prompt: str = "",
no_history: bool = True,
with_tools: bool = True,
) -> Message:
"""Sends a separate message to an LLM using a separate Exchange than the one underlying the Goose session.

Can be used to summarize a file, or submit any other request that you'd like to an AI. The Exchange can have a
Expand Down Expand Up @@ -36,6 +42,9 @@ def ask_an_ai(input: str, exchange: Exchange, prompt: str = "", no_history: bool
if no_history:
exchange = clear_exchange(exchange)

if not with_tools:
exchange = exchange.replace(tools=())

if prompt:
exchange = replace_prompt(exchange, prompt)

Expand Down