Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dockerfiles/Dockerfile.sandbox
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,6 @@ ARG UWSGI_PROCESSES
ENV UWSGI_PROCESSES=$UWSGI_PROCESSES

ENV LISTEN_PORT=6000


RUN echo "uwsgi_read_timeout 14400s;" > /etc/nginx/conf.d/custom_timeout.conf
73 changes: 65 additions & 8 deletions nemo_skills/code_execution/local_sandbox/local_sandbox_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,49 @@
import tempfile
import signal
from io import StringIO
import psutil

from flask import Flask, request

app = Flask(__name__)

MEM_LIMIT_BYTES = int(os.environ.get('NEMO_SKILLS_SANDBOX_MEM_LIMIT', 10 * 1024 ** 3)) # 10 GiB default

# Code to kill the process tree for lean4 code execution
def kill_process_tree(proc):
"""
Safely and aggressively kills a process and all its descendants.
This is the recommended approach for ensuring cleanup.
"""
try:
parent = psutil.Process(proc.pid)
# Get all children of the process, recursively.
children = parent.children(recursive=True)
# Add the parent to the list of processes to be killed.
all_processes = children + [parent]

# Kill all processes in the tree.
for p in all_processes:
try:
# SIGKILL is a forceful, non-ignorable kill signal.
p.kill()
except psutil.NoSuchProcess:
# The process might have already died, which is fine.
pass

# Wait for all processes to be terminated.
gone, alive = psutil.wait_procs(all_processes, timeout=3)
if alive:
# If any processes are still alive, they are likely zombies
# or in an unkillable state. This is a last resort.
for p in alive:
print(f"Warning: Process {p.pid} could not be killed.")
except psutil.NoSuchProcess:
# The main process already died before we could kill it.
pass
except Exception as e:
print(f"Error in kill_process_tree: {e}")

def set_limits(mem_bytes: int = MEM_LIMIT_BYTES) -> None:
"""
Apply RLIMITs and start a new session for the child process.
Expand Down Expand Up @@ -79,33 +115,54 @@ def execute_python(generated_code, std_input, timeout, language):

def execute_lean4(generated_code, timeout):
temp_file_name = None
proc = None # <-- Keep track of the process object
try:
project_path = "/lean4/my_project"
# Use a with statement for the temp file to ensure it's closed
with tempfile.NamedTemporaryFile(dir=project_path, delete=False, suffix=".lean") as temp_file:
temp_file_name = temp_file.name
temp_file.write(generated_code.encode('utf-8'))
temp_file.flush() # Ensure data is written to disk

result = subprocess.run(
# Use subprocess.Popen for more control
proc = subprocess.Popen(
['lake', 'env', '--dir', project_path, 'lean', temp_file_name],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
timeout=timeout,
cwd=project_path, # Ensure we are in the correct working directory
cwd=project_path,
preexec_fn=os.setsid
)

if result.returncode == 0:
# Communicate with the process, which waits for it to finish
# This will raise TimeoutExpired if the timeout is reached
stdout, stderr = proc.communicate(timeout=timeout)

if proc.returncode == 0:
process_status = "completed"
else:
process_status = "failed"

return {
"process_status": process_status,
"stdout": result.stdout.decode('utf-8'),
"stderr": result.stderr.decode('utf-8'),
"stdout": stdout.decode('utf-8'),
"stderr": stderr.decode('utf-8'),
}

except subprocess.TimeoutExpired:
return {"process_status": "timeout", "stdout": "", "stderr": "Timed out\n"}

# kill the process tree
kill_process_tree(proc)

# Now we can safely get any output that was generated before the kill.
stdout, stderr = proc.communicate()

final_stderr = stderr.decode('utf-8') + "Timed out\n"
return {
"process_status": "timeout",
"stdout": stdout.decode('utf-8'),
"stderr": final_stderr,
}

except Exception as e:
print(f"Error: {str(e)}")
return {"process_status": "error", "stdout": "", "stderr": str(e) + "\n"}
Expand Down Expand Up @@ -149,4 +206,4 @@ def execute():
if __name__ == '__main__':
log = logging.getLogger('werkzeug')
log.setLevel(logging.WARNING)
app.run(port=6000)
app.run(port=6000)
1 change: 1 addition & 0 deletions requirements/code_execution.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ pandas
scipy
sympy
tqdm
psutil