diff --git a/nemo_skills/code_execution/sandbox.py b/nemo_skills/code_execution/sandbox.py index cc83b4e7bf..891e61111c 100644 --- a/nemo_skills/code_execution/sandbox.py +++ b/nemo_skills/code_execution/sandbox.py @@ -46,6 +46,10 @@ class Sandbox(abc.ABC): Can also be specified through NEMO_SKILLS_SSH_SERVER env var. ssh_key_path: Optional[str] = None - Path to the ssh key for tunneling. Can also be specified through NEMO_SKILLS_SSH_KEY_PATH env var. + disable_session_restore: bool = False - When True, skip replaying session history + after a sandbox worker restarts. The current code executes on a fresh session + and the model receives a warning in stderr. + Can also be specified through NEMO_SKILLS_DISABLE_SESSION_RESTORE env var. """ def __init__( @@ -54,6 +58,7 @@ def __init__( port: Optional[str] = os.getenv("NEMO_SKILLS_SANDBOX_PORT", "6000"), ssh_server: Optional[str] = None, ssh_key_path: Optional[str] = None, + disable_session_restore: bool = False, ): self.host = host self.port = port @@ -63,6 +68,9 @@ def __init__( ) self.ssh_server = os.getenv("NEMO_SKILLS_SSH_SERVER", ssh_server) self.ssh_key_path = os.getenv("NEMO_SKILLS_SSH_KEY_PATH", ssh_key_path) + self.disable_session_restore = disable_session_restore or os.getenv( + "NEMO_SKILLS_DISABLE_SESSION_RESTORE", "" + ).lower() not in ("", "0", "false") self.session_histories = defaultdict(list) # session_id -> list of generated_code async def close(self): @@ -171,7 +179,12 @@ async def execute_code( # NOTE: Only cells that completed successfully are stored, so we intentionally omit re-running cells that errored # or timed out. This means restoration **can diverge** from the original interactive session in those cases, but # avoids re-triggering side effects from failing cells while keeping the replay simple. - if session_id is not None and new_session_created and output.get("process_status") != "timeout": + if ( + session_id is not None + and new_session_created + and output.get("process_status") != "timeout" + and not self.disable_session_restore + ): history = list(self.session_histories.get(session_id_str, [])) if request_session_id_str is not None: try: @@ -236,8 +249,24 @@ async def execute_code( except httpx.TimeoutException: output = {"process_status": "timeout", "stdout": "", "stderr": "Client timed out\n"} - # Append to history if successful execution (process_status == 'completed') - if output.get("process_status") == "completed" and request_session_id_str is not None: + elif session_id is not None and new_session_created and self.disable_session_restore: + # Session was recreated but restore is disabled — clear stale history and warn the model. + self.session_histories.pop(session_id_str, None) + self.session_histories.pop(request_session_id_str, None) + LOG.warning("Session %s was recreated but restore is disabled; history cleared", session_id) + output["stderr"] = ( + "RuntimeError: Sandbox state restoration failed after the execution worker restarted. " + "The interactive session history has been cleared; " + "please re-run the last code block without relying on prior state.\n" + ) + output.get("stderr", "") + + # Append to history if successful execution (process_status == 'completed'). + # Skip when restore is disabled — history will never be replayed and would leak memory. + if ( + output.get("process_status") == "completed" + and request_session_id_str is not None + and not self.disable_session_restore + ): self.session_histories[request_session_id_str].append(generated_code) output.pop("new_session_created", None) diff --git a/nemo_skills/mcp/servers/python_tool.py b/nemo_skills/mcp/servers/python_tool.py index 44ed7cdecf..cce3d9c301 100644 --- a/nemo_skills/mcp/servers/python_tool.py +++ b/nemo_skills/mcp/servers/python_tool.py @@ -68,6 +68,12 @@ async def stateful_python_code_exec( def main(): parser = argparse.ArgumentParser(description="MCP server for executing Python code in a sandbox") + parser.add_argument( + "--disable-session-restore", + action="store_true", + default=False, + help="Skip replaying session history after sandbox worker restarts (overrides config)", + ) add_config_args(parser) args = parser.parse_args() @@ -83,6 +89,9 @@ def main(): global sandbox sandbox_cfg = OmegaConf.to_container(cfg.sandbox, resolve=True) + if args.disable_session_restore: + sandbox_cfg["disable_session_restore"] = True + sandbox = get_sandbox(**sandbox_cfg) # Initialize and run the server mcp.run(transport="stdio")