diff --git a/dockerfiles/Dockerfile.sandbox b/dockerfiles/Dockerfile.sandbox index 516df05bd5..f097209335 100644 --- a/dockerfiles/Dockerfile.sandbox +++ b/dockerfiles/Dockerfile.sandbox @@ -67,3 +67,6 @@ ARG UWSGI_PROCESSES ENV UWSGI_PROCESSES=$UWSGI_PROCESSES ENV LISTEN_PORT=6000 + + +RUN echo "uwsgi_read_timeout 14400s;" > /etc/nginx/conf.d/custom_timeout.conf \ No newline at end of file diff --git a/nemo_skills/code_execution/local_sandbox/local_sandbox_server.py b/nemo_skills/code_execution/local_sandbox/local_sandbox_server.py index 2d5938cd40..e0a4b311e5 100644 --- a/nemo_skills/code_execution/local_sandbox/local_sandbox_server.py +++ b/nemo_skills/code_execution/local_sandbox/local_sandbox_server.py @@ -22,6 +22,7 @@ import tempfile import signal from io import StringIO +import psutil from flask import Flask, request @@ -29,6 +30,41 @@ MEM_LIMIT_BYTES = int(os.environ.get('NEMO_SKILLS_SANDBOX_MEM_LIMIT', 10 * 1024 ** 3)) # 10 GiB default +# Code to kill the process tree for lean4 code execution +def kill_process_tree(proc): + """ + Safely and aggressively kills a process and all its descendants. + This is the recommended approach for ensuring cleanup. + """ + try: + parent = psutil.Process(proc.pid) + # Get all children of the process, recursively. + children = parent.children(recursive=True) + # Add the parent to the list of processes to be killed. + all_processes = children + [parent] + + # Kill all processes in the tree. + for p in all_processes: + try: + # SIGKILL is a forceful, non-ignorable kill signal. + p.kill() + except psutil.NoSuchProcess: + # The process might have already died, which is fine. + pass + + # Wait for all processes to be terminated. + gone, alive = psutil.wait_procs(all_processes, timeout=3) + if alive: + # If any processes are still alive, they are likely zombies + # or in an unkillable state. This is a last resort. + for p in alive: + print(f"Warning: Process {p.pid} could not be killed.") + except psutil.NoSuchProcess: + # The main process already died before we could kill it. + pass + except Exception as e: + print(f"Error in kill_process_tree: {e}") + def set_limits(mem_bytes: int = MEM_LIMIT_BYTES) -> None: """ Apply RLIMITs and start a new session for the child process. @@ -79,33 +115,54 @@ def execute_python(generated_code, std_input, timeout, language): def execute_lean4(generated_code, timeout): temp_file_name = None + proc = None # <-- Keep track of the process object try: project_path = "/lean4/my_project" + # Use a with statement for the temp file to ensure it's closed with tempfile.NamedTemporaryFile(dir=project_path, delete=False, suffix=".lean") as temp_file: temp_file_name = temp_file.name temp_file.write(generated_code.encode('utf-8')) + temp_file.flush() # Ensure data is written to disk - result = subprocess.run( + # Use subprocess.Popen for more control + proc = subprocess.Popen( ['lake', 'env', '--dir', project_path, 'lean', temp_file_name], stdout=subprocess.PIPE, stderr=subprocess.PIPE, - timeout=timeout, - cwd=project_path, # Ensure we are in the correct working directory + cwd=project_path, + preexec_fn=os.setsid ) - if result.returncode == 0: + # Communicate with the process, which waits for it to finish + # This will raise TimeoutExpired if the timeout is reached + stdout, stderr = proc.communicate(timeout=timeout) + + if proc.returncode == 0: process_status = "completed" else: process_status = "failed" return { "process_status": process_status, - "stdout": result.stdout.decode('utf-8'), - "stderr": result.stderr.decode('utf-8'), + "stdout": stdout.decode('utf-8'), + "stderr": stderr.decode('utf-8'), } except subprocess.TimeoutExpired: - return {"process_status": "timeout", "stdout": "", "stderr": "Timed out\n"} + + # kill the process tree + kill_process_tree(proc) + + # Now we can safely get any output that was generated before the kill. + stdout, stderr = proc.communicate() + + final_stderr = stderr.decode('utf-8') + "Timed out\n" + return { + "process_status": "timeout", + "stdout": stdout.decode('utf-8'), + "stderr": final_stderr, + } + except Exception as e: print(f"Error: {str(e)}") return {"process_status": "error", "stdout": "", "stderr": str(e) + "\n"} @@ -149,4 +206,4 @@ def execute(): if __name__ == '__main__': log = logging.getLogger('werkzeug') log.setLevel(logging.WARNING) - app.run(port=6000) + app.run(port=6000) \ No newline at end of file diff --git a/requirements/code_execution.txt b/requirements/code_execution.txt index 546492f7fe..4095a94b07 100644 --- a/requirements/code_execution.txt +++ b/requirements/code_execution.txt @@ -23,3 +23,4 @@ pandas scipy sympy tqdm +psutil \ No newline at end of file