Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions nemo_skills/code_execution/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ class Sandbox(abc.ABC):
Can also be specified through NEMO_SKILLS_SSH_SERVER env var.
ssh_key_path: Optional[str] = None - Path to the ssh key for tunneling.
Can also be specified through NEMO_SKILLS_SSH_KEY_PATH env var.
disable_session_restore: bool = False - When True, skip replaying session history
after a sandbox worker restarts. The current code executes on a fresh session
and the model receives a warning in stderr.
Can also be specified through NEMO_SKILLS_DISABLE_SESSION_RESTORE env var.
"""

def __init__(
Expand All @@ -54,6 +58,7 @@ def __init__(
port: Optional[str] = os.getenv("NEMO_SKILLS_SANDBOX_PORT", "6000"),
ssh_server: Optional[str] = None,
ssh_key_path: Optional[str] = None,
disable_session_restore: bool = False,
):
self.host = host
self.port = port
Expand All @@ -63,6 +68,9 @@ def __init__(
)
self.ssh_server = os.getenv("NEMO_SKILLS_SSH_SERVER", ssh_server)
self.ssh_key_path = os.getenv("NEMO_SKILLS_SSH_KEY_PATH", ssh_key_path)
self.disable_session_restore = disable_session_restore or os.getenv(
"NEMO_SKILLS_DISABLE_SESSION_RESTORE", ""
).lower() not in ("", "0", "false")
self.session_histories = defaultdict(list) # session_id -> list of generated_code

async def close(self):
Expand Down Expand Up @@ -171,7 +179,12 @@ async def execute_code(
# NOTE: Only cells that completed successfully are stored, so we intentionally omit re-running cells that errored
# or timed out. This means restoration **can diverge** from the original interactive session in those cases, but
# avoids re-triggering side effects from failing cells while keeping the replay simple.
if session_id is not None and new_session_created and output.get("process_status") != "timeout":
if (
session_id is not None
and new_session_created
and output.get("process_status") != "timeout"
and not self.disable_session_restore
):
history = list(self.session_histories.get(session_id_str, []))
if request_session_id_str is not None:
try:
Expand Down Expand Up @@ -236,8 +249,24 @@ async def execute_code(
except httpx.TimeoutException:
output = {"process_status": "timeout", "stdout": "", "stderr": "Client timed out\n"}

# Append to history if successful execution (process_status == 'completed')
if output.get("process_status") == "completed" and request_session_id_str is not None:
elif session_id is not None and new_session_created and self.disable_session_restore:
# Session was recreated but restore is disabled — clear stale history and warn the model.
self.session_histories.pop(session_id_str, None)
self.session_histories.pop(request_session_id_str, None)
LOG.warning("Session %s was recreated but restore is disabled; history cleared", session_id)
output["stderr"] = (
"RuntimeError: Sandbox state restoration failed after the execution worker restarted. "
"The interactive session history has been cleared; "
"please re-run the last code block without relying on prior state.\n"
) + output.get("stderr", "")
Comment on lines +252 to +261
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic bug: the elif doesn't cover the case where session restore is disabled but new_session_created is False

When disable_session_restore=True and new_session_created=False (normal execution without restart), the code falls through without executing anything. The warning message and history clearing should still happen when a restart is detected, but successful code history should never be stored when restore is disabled (which is handled correctly on line 268).

However, the logic structure with elif prevents the code block on lines 243-250 from executing when disable_session_restore=True and the session wasn't recreated. This means normal code execution might be skipped.

Suggested change
elif session_id is not None and new_session_created and self.disable_session_restore:
# Session was recreated but restore is disabled — clear stale history and warn the model.
self.session_histories.pop(session_id_str, None)
self.session_histories.pop(request_session_id_str, None)
LOG.warning("Session %s was recreated but restore is disabled; history cleared", session_id)
output["stderr"] = (
"RuntimeError: Sandbox state restoration failed after the execution worker restarted. "
"The interactive session history has been cleared; "
"please re-run the last code block without relying on prior state.\n"
) + output.get("stderr", "")
elif session_id is not None and new_session_created and self.disable_session_restore:
# Session was recreated but restore is disabled — clear stale history and warn the model.
self.session_histories.pop(session_id_str, None)
self.session_histories.pop(request_session_id_str, None)
LOG.warning("Session %s was recreated but restore is disabled; history cleared", session_id)
output["stderr"] = (
"RuntimeError: Sandbox state restoration failed after the execution worker restarted. "
"The interactive session history has been cleared; "
"please re-run the last code block without relying on prior state.\n"
) + output.get("stderr", "")
# Execute the new code on the fresh session
exec_request = self._prepare_request(
generated_code, timeout, language, std_input, max_output_characters, traceback_verbosity
)
exec_request["session_id"] = request_session_id_str
try:
output = await self._send_request(exec_request, timeout)
except httpx.TimeoutException:
output = {"process_status": "timeout", "stdout": "", "stderr": "Client timed out\n"}


# Append to history if successful execution (process_status == 'completed').
# Skip when restore is disabled — history will never be replayed and would leak memory.
if (
output.get("process_status") == "completed"
and request_session_id_str is not None
and not self.disable_session_restore
):
self.session_histories[request_session_id_str].append(generated_code)

output.pop("new_session_created", None)
Expand Down
9 changes: 9 additions & 0 deletions nemo_skills/mcp/servers/python_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ async def stateful_python_code_exec(

def main():
parser = argparse.ArgumentParser(description="MCP server for executing Python code in a sandbox")
parser.add_argument(
"--disable-session-restore",
action="store_true",
default=False,
help="Skip replaying session history after sandbox worker restarts (overrides config)",
)
add_config_args(parser)
args = parser.parse_args()

Expand All @@ -83,6 +89,9 @@ def main():

global sandbox
sandbox_cfg = OmegaConf.to_container(cfg.sandbox, resolve=True)
if args.disable_session_restore:
sandbox_cfg["disable_session_restore"] = True

sandbox = get_sandbox(**sandbox_cfg)
# Initialize and run the server
mcp.run(transport="stdio")
Expand Down