Skip to content

Commit

Permalink
Feat: Add simple check exit code functionality to command, action (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
klieret authored Nov 5, 2024
1 parent 32b974a commit 6b274e4
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 70 deletions.
38 changes: 33 additions & 5 deletions src/swerex/runtime/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,21 @@ class BashAction(BaseModel):
interactive program, set this to False.
"""

check: bool = False
"""Whether to check for the exit code. If True, we will raise a
`NonZeroExitCodeError` if the command has a non-zero exit code.
"""

error_msg: str = ""
"""This error message will be used in the `NonZeroExitCodeError` if the
command has a non-zero exit code and `check` is True.
"""

expect: list[str] = []
"""Outputs to expect in addition to the PS1"""

session_type: Literal["bash"] = "bash"
"""Used for type discrimination. Do not change."""


Action = Annotated[BashAction, Field(discriminator="session_type")]
Expand Down Expand Up @@ -124,6 +135,16 @@ class Command(BaseModel):
shell: bool = False
"""Same as the `subprocess.run()` `shell` argument."""

check: bool = False
"""Whether to check for the exit code. If True, we will raise a
`NonZeroExitCodeError` if the command has a non-zero exit code.
"""

error_msg: str = ""
"""This error message will be used in the `NonZeroExitCodeError` if the
command has a non-zero exit code and `check` is True.
"""


class CommandResponse(BaseModel):
stdout: str = ""
Expand Down Expand Up @@ -169,7 +190,17 @@ class _ExceptionTransfer(BaseModel):
traceback: str = ""


class SweRexception(RuntimeError): ...
# todo: move?
class SweRexception(Exception):
"""Any exception that is raised by SWE-Rex."""


class SessionNotInitializedError(SweRexception, RuntimeError):
"""Raised if we try to run a command in a shell that is not initialized."""


class NonZeroExitCodeError(SweRexception, RuntimeError):
"""Can be raised if we execute a command in the shell and it has a non-zero exit code."""


class BashIncorrectSyntaxError(SweRexception, RuntimeError):
Expand All @@ -178,10 +209,7 @@ class BashIncorrectSyntaxError(SweRexception, RuntimeError):
"""


class UninitializedShellError(SweRexception, ValueError): ...


class CommandTimeoutError(SweRexception, RuntimeError): ...
class CommandTimeoutError(SweRexception, RuntimeError, TimeoutError): ...


class NoExitCodeError(SweRexception, RuntimeError): ...
Expand Down
43 changes: 39 additions & 4 deletions src/swerex/runtime/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@
CreateSessionResponse,
IsAliveResponse,
NoExitCodeError,
NonZeroExitCodeError,
Observation,
ReadFileRequest,
ReadFileResponse,
SessionDoesNotExistError,
SessionExistsError,
SessionNotInitializedError,
UploadRequest,
UploadResponse,
WriteFileRequest,
Expand Down Expand Up @@ -165,12 +167,31 @@ async def start(self) -> CreateBashSessionResponse:
return CreateBashSessionResponse(output=output)

async def run(self, action: BashAction) -> BashObservation:
"""Run a bash action.
Raises:
SessionNotInitializedError: If the shell is not initialized.
CommandTimeoutError: If the command times out.
NonZeroExitCodeError: If the command has a non-zero exit code and `action.check` is True.
NoExitCodeError: If we cannot get the exit code of the command.
Returns:
BashObservation: The observation of the command.
"""
if self.shell is None:
msg = "shell not initialized"
raise RuntimeError(msg)
raise SessionNotInitializedError(msg)
if action.is_interactive_command or action.is_interactive_quit:
return await self._run_interactive(action)
return await self._run_normal(action)
r = await self._run_normal(action)
# Enforce the exit code only when the caller opted in via `action.check`.
if action.check and r.exit_code != 0:
    # Both literals need the `f` prefix: with implicit string concatenation
    # each piece is formatted on its own, so an unprefixed second piece
    # would emit the text "{r.output!r}" verbatim.
    msg = (
        f"Command {action.command!r} failed with exit code {r.exit_code}. "
        f"Here is the output:\n{r.output!r}"
    )
    # Prepend the caller-supplied context, if any.
    if action.error_msg:
        msg = f"{action.error_msg}: {msg}"
    raise NonZeroExitCodeError(msg)
return r

async def _run_interactive(self, action: BashAction) -> BashObservation:
"""Run an interactive action. This is different because we don't seek to
Expand Down Expand Up @@ -326,17 +347,31 @@ async def close_session(self, request: CloseSessionRequest) -> CloseSessionRespo
return out

async def execute(self, command: Command) -> CommandResponse:
"""Executes a command (independent of any shell session)."""
"""Executes a command (independent of any shell session).
Raises:
CommandTimeoutError: If the command times out.
NonZeroExitCodeError: If the command has a non-zero exit code and `check` is True.
"""
try:
result = subprocess.run(command.command, shell=command.shell, timeout=command.timeout, capture_output=True)
return CommandResponse(
r = CommandResponse(
stdout=result.stdout.decode(errors="backslashreplace"),
stderr=result.stderr.decode(errors="backslashreplace"),
exit_code=result.returncode,
)
except subprocess.TimeoutExpired as e:
msg = f"Timeout ({command.timeout}s) exceeded while running command"
raise CommandTimeoutError(msg) from e
# Enforce the exit code only when the caller opted in via `command.check`.
if command.check and result.returncode != 0:
    # Both literals need the `f` prefix: with implicit string concatenation
    # each piece is formatted on its own, so an unprefixed second piece
    # would emit "{r.stdout!r}" / "{r.stderr!r}" verbatim.
    msg = (
        f"Command {command.command!r} failed with exit code {result.returncode}. "
        f"Stdout:\n{r.stdout!r}\nStderr:\n{r.stderr!r}"
    )
    # Prepend the caller-supplied context, if any.
    if command.error_msg:
        msg = f"{command.error_msg}: {msg}"
    raise NonZeroExitCodeError(msg)
return r

async def read_file(self, request: ReadFileRequest) -> ReadFileResponse:
"""Reads a file"""
Expand Down
112 changes: 51 additions & 61 deletions tests/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
CloseBashSessionRequest,
CommandTimeoutError,
CreateBashSessionRequest,
NonZeroExitCodeError,
ReadFileRequest,
SessionDoesNotExistError,
UploadRequest,
Expand Down Expand Up @@ -59,14 +60,17 @@ async def test_create_close_shell(remote_runtime: RemoteRuntime):


async def test_run_in_shell(runtime_with_default_session: RemoteRuntime):
r = await runtime_with_default_session.run_in_session(A(command="echo 'hello world'"))
assert r.exit_code == 0
r = await runtime_with_default_session.run_in_session(A(command="echo 'hello world'", check=True))
r = await runtime_with_default_session.run_in_session(A(command="doesntexit"))
assert r.exit_code == 127
r = await runtime_with_default_session.run_in_session(A(command="false && true"))
assert r.exit_code == 1
r = await runtime_with_default_session.run_in_session(A(command="false || true"))
assert r.exit_code == 0
r = await runtime_with_default_session.run_in_session(A(command="false || true", check=True))


async def test_run_in_shell_check_exit_code(runtime_with_default_session: RemoteRuntime):
with pytest.raises(NonZeroExitCodeError):
await runtime_with_default_session.run_in_session(A(command="false", check=True))


async def test_run_in_shell_non_existent_session(remote_runtime: RemoteRuntime):
Expand All @@ -93,43 +97,30 @@ async def test_run_in_shell_timeout(runtime_with_default_session: RemoteRuntime)


async def test_run_in_shell_interactive_command(runtime_with_default_session: RemoteRuntime):
r = await runtime_with_default_session.run_in_session(
A(command="python", is_interactive_command=True, expect=[">>> "])
)

r = await runtime_with_default_session.run_in_session(
await runtime_with_default_session.run_in_session(A(command="python", is_interactive_command=True, expect=[">>> "]))
await runtime_with_default_session.run_in_session(
A(command="print('hello world')", is_interactive_command=True, expect=[">>> "])
)

r = await runtime_with_default_session.run_in_session(A(command="quit()\n", is_interactive_quit=True))
assert r.exit_code == 0
await runtime_with_default_session.run_in_session(A(command="quit()\n", is_interactive_quit=True, check=True))


async def test_run_in_shell_multiple_interactive_and_normal_commands(runtime_with_default_session: RemoteRuntime):
run = runtime_with_default_session
r = await run.run_in_session(A(command="ls"))
assert r.exit_code == 0
r = await run.run_in_session(A(command="python", is_interactive_command=True, expect=[">>> "]))
await run.run_in_session(A(command="ls", check=True))
await run.run_in_session(A(command="python", is_interactive_command=True, expect=[">>> "]))

r = await run.run_in_session(A(command="print('hello world')", is_interactive_command=True, expect=[">>> "]))
assert "hello world" in r.output

r = await run.run_in_session(A(command="quit()\n", is_interactive_quit=True))
assert r.exit_code == 0
await run.run_in_session(A(command="quit()\n", is_interactive_quit=True, check=True))

r = await run.run_in_session(A(command="echo 'hello world'"))
assert r.exit_code == 0
r = await run.run_in_session(A(command="echo 'hello world'", check=True))
assert "hello world" in r.output

r = await run.run_in_session(A(command="python", is_interactive_command=True, expect=[">>> "]))

r = await run.run_in_session(A(command="print('hello world')", is_interactive_command=True, expect=[">>> "]))

r = await run.run_in_session(A(command="quit()\n", is_interactive_quit=True))
assert r.exit_code == 0

r = await run.run_in_session(A(command="echo 'hello world'"))
assert r.exit_code == 0
await run.run_in_session(A(command="python", is_interactive_command=True, expect=[">>> "]))
await run.run_in_session(A(command="print('hello world')", is_interactive_command=True, expect=[">>> "]))
await run.run_in_session(A(command="quit()\n", is_interactive_quit=True, check=True))
r = await run.run_in_session(A(command="echo 'hello world'", check=True))
assert "hello world" in r.output


Expand Down Expand Up @@ -159,18 +150,18 @@ async def test_multiple_isolated_shells(remote_runtime: RemoteRuntime):
await remote_runtime.create_session(CreateBashSessionRequest(session="shell2"))

await asyncio.gather(
remote_runtime.run_in_session(A(command="x=42", session="shell1")),
remote_runtime.run_in_session(A(command="y=24", session="shell2")),
remote_runtime.run_in_session(A(command="x=42", session="shell1", check=True)),
remote_runtime.run_in_session(A(command="y=24", session="shell2", check=True)),
)

response1 = await remote_runtime.run_in_session(A(command="echo $x", session="shell1"))
response2 = await remote_runtime.run_in_session(A(command="echo $y", session="shell2"))
response1 = await remote_runtime.run_in_session(A(command="echo $x", session="shell1", check=True))
response2 = await remote_runtime.run_in_session(A(command="echo $y", session="shell2", check=True))

assert response1.output.strip() == "42"
assert response2.output.strip() == "24"

response3 = await remote_runtime.run_in_session(A(command="echo $y", session="shell1"))
response4 = await remote_runtime.run_in_session(A(command="echo $x", session="shell2"))
response3 = await remote_runtime.run_in_session(A(command="echo $y", session="shell1", check=True))
response4 = await remote_runtime.run_in_session(A(command="echo $x", session="shell2", check=True))

assert response3.output.strip() == ""
assert response4.output.strip() == ""
Expand All @@ -180,69 +171,69 @@ async def test_multiple_isolated_shells(remote_runtime: RemoteRuntime):


async def test_empty_command(remote_runtime: RemoteRuntime):
await remote_runtime.execute(C(command="", shell=True))
await remote_runtime.execute(C(command="\n", shell=True))
await remote_runtime.execute(C(command="", shell=True, check=True))
await remote_runtime.execute(C(command="\n", shell=True, check=True))


async def test_command_fails_check_exit_code(runtime_with_default_session: RemoteRuntime):
with pytest.raises(NonZeroExitCodeError):
await runtime_with_default_session.run_in_session(A(command="false", check=True))


async def test_empty_command_in_shell(runtime_with_default_session: RemoteRuntime):
r = await runtime_with_default_session.run_in_session(A(command=""))
assert r.exit_code == 0
r = await runtime_with_default_session.run_in_session(A(command="\n"))
assert r.exit_code == 0
r = await runtime_with_default_session.run_in_session(A(command="\n\n \n"))
assert r.exit_code == 0
await runtime_with_default_session.run_in_session(A(command="", check=True))
await runtime_with_default_session.run_in_session(A(command="\n", check=True))
await runtime_with_default_session.run_in_session(A(command="\n\n \n", check=True))


async def test_command_with_linebreaks(runtime_with_default_session: RemoteRuntime):
await runtime_with_default_session.run_in_session(A(command="\n echo 'test'\n\n"))
await runtime_with_default_session.run_in_session(A(command="\n echo 'test'\n\n", check=True))


async def test_multiple_commands_with_linebreaks_in_shell(runtime_with_default_session: RemoteRuntime):
r = await runtime_with_default_session.run_in_session(
A(command="\n\n\n echo 'test1' \n \n \n echo 'test2' \n\n\n")
A(command="\n\n\n echo 'test1' \n \n \n echo 'test2' \n\n\n", check=True)
)
assert r.exit_code == 0
assert r.output.splitlines() == ["test1", "test2"]


async def test_bash_multiline_command_eof(runtime_with_default_session: RemoteRuntime):
command = "\n".join(["python <<EOF", "print('hello world')", "print('hello world 2')", "EOF"])
r = await runtime_with_default_session.run_in_session(A(command=command))
assert r.exit_code == 0
r = await runtime_with_default_session.run_in_session(A(command=command, check=True))
assert "hello world" in r.output
assert "hello world 2" in r.output


async def test_run_in_shell_subshell_command(runtime_with_default_session: RemoteRuntime):
r = await runtime_with_default_session.run_in_session(A(command="(sleep 10) &"))
assert r.exit_code == 0
await runtime_with_default_session.run_in_session(A(command="(sleep 10) &", check=True))


async def test_run_just_comment(runtime_with_default_session: RemoteRuntime):
r = await runtime_with_default_session.run_in_session(A(command="# echo 'hello world'"))
assert r.exit_code == 0
r = await runtime_with_default_session.run_in_session(A(command="# echo 'hello world'", check=True))
assert r.output == ""


async def test_run_in_shell_multiple_commands(runtime_with_default_session: RemoteRuntime):
r = await runtime_with_default_session.run_in_session(A(command="echo 'hello world'; echo 'hello again'"))
assert r.exit_code == 0
r = await runtime_with_default_session.run_in_session(
A(command="echo 'hello world'; echo 'hello again'", check=True)
)
assert r.output.splitlines() == ["hello world", "hello again"]
r = await runtime_with_default_session.run_in_session(A(command="echo 'hello world' && echo 'hello again'"))
assert r.exit_code == 0
r = await runtime_with_default_session.run_in_session(
A(command="echo 'hello world' && echo 'hello again'", check=True)
)
assert r.output.splitlines() == ["hello world", "hello again"]


async def test_run_in_shell_while_loop(runtime_with_default_session: RemoteRuntime):
r = await runtime_with_default_session.run_in_session(A(command="for i in {1..3};\n do echo 'hello world';\n done"))
assert r.exit_code == 0
r = await runtime_with_default_session.run_in_session(
A(command="for i in {1..3};\n do echo 'hello world';\n done", check=True)
)
assert r.output.splitlines() == ["hello world"] * 3


async def test_run_in_shell_bashlex_errors(runtime_with_default_session: RemoteRuntime):
# One of the bugs in bashlex
r = await runtime_with_default_session.run_in_session(A(command="[[ $env == $env ]]"))
assert r.exit_code == 0
await runtime_with_default_session.run_in_session(A(command="[[ $env == $env ]]", check=True))


async def test_run_shell_check_exit_code(runtime_with_default_session: RemoteRuntime):
Expand All @@ -251,8 +242,7 @@ async def test_run_shell_check_exit_code(runtime_with_default_session: RemoteRun


async def test_with_bashlex_errors(runtime_with_default_session: RemoteRuntime):
r = await runtime_with_default_session.run_in_session(A(command="echo 'hw';A=();echo 'asdf'"))
assert r.exit_code == 0
r = await runtime_with_default_session.run_in_session(A(command="echo 'hw';A=();echo 'asdf'", check=True))
assert "hw" in r.output
assert "asdf" in r.output

Expand Down

0 comments on commit 6b274e4

Please sign in to comment.