Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1fe8f5a
ENH add single eval in generation loop
gwarmstrong Sep 18, 2025
3f75f42
ENH make migration path cleaner
gwarmstrong Sep 18, 2025
5b415c3
MAINT move proof construction to shared utility
gwarmstrong Sep 18, 2025
b2fa5d6
Merge branch 'main' into georgea/eval-single-in-gen-loop
gwarmstrong Sep 19, 2025
ec2ec26
Update nemo_skills/evaluation/evaluator/__init__.py
gwarmstrong Sep 23, 2025
cd5775b
Merge branch 'main' into georgea/eval-single-in-gen-loop
gwarmstrong Sep 23, 2025
9db3e4e
change name to get_evaluator_class
gwarmstrong Sep 23, 2025
56e3246
MAINT remove legacy eval methods
gwarmstrong Sep 23, 2025
cd45275
MAINT remove unused imports
gwarmstrong Sep 23, 2025
276aa9f
MAINT refactor eval_hook to semaphore level
gwarmstrong Sep 23, 2025
07dae3d
FIX apply coderabbit suggest for async processing
gwarmstrong Sep 23, 2025
5f56c4c
MAINT pipeline utils fixes
gwarmstrong Sep 23, 2025
1220850
MAINT move math to single eval pattern
gwarmstrong Sep 23, 2025
7385739
MAINT remove blanket exception
gwarmstrong Sep 23, 2025
d99060a
MAINT make lean eval use base and upgrade base with parallel
gwarmstrong Sep 23, 2025
ca25770
FIX eval base cases
gwarmstrong Sep 23, 2025
b086d88
MAINT add interleaved_eval_single_time
gwarmstrong Sep 23, 2025
af6a5e7
MAINT move sandbox instantiation into init
gwarmstrong Sep 23, 2025
701592c
ENH collect in order with eval_full
gwarmstrong Sep 23, 2025
b338ec2
MAINT fix sandbox batch execution
gwarmstrong Sep 23, 2025
cf25086
FIX tasks loop
gwarmstrong Sep 23, 2025
0a64101
MAINT add todo for dependent evaluations
gwarmstrong Sep 23, 2025
ba887a7
FIX math config instantiation
gwarmstrong Sep 23, 2025
ff28e7a
MAINT remove extra kwargs
gwarmstrong Sep 23, 2025
740511e
ENH simplify task size inference
gwarmstrong Sep 23, 2025
967e0f7
Fixes updates merge
gwarmstrong Sep 23, 2025
a4ff6e1
MAINT fix format
gwarmstrong Sep 23, 2025
a46cf46
FIX just directly fail if update isn't able to update the data
gwarmstrong Sep 23, 2025
272b962
Update nemo_skills/evaluation/evaluator/base.py
gwarmstrong Sep 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 194 additions & 0 deletions nemo_skills/code_execution/proof_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared utilities for proof processing and evaluation."""

import re
from dataclasses import dataclass
from typing import Any, Dict

from nemo_skills.code_execution.utils import clean_formal_generation
from nemo_skills.dataset.utils import get_lean4_header


@dataclass
class ProofBuildConfig:
    """Options controlling how a Lean 4 proof is assembled from a model generation.

    Attributes:
        final_answer_key: Marker string forwarded to ``clean_formal_generation``
            as its ``final_answer_key`` argument when building a "lean4-proof".
        extract_code_mode: Forwarded to ``clean_formal_generation`` as
            ``extract_code_mode``; expected to be "first" or "last".
        restate_formal_statement: When True, the data point's
            ``formal_statement`` is placed between the header and the proof.
        strip_theorem_from_proof: When True, the cleaned generation is passed
            through ``extract_proof_only`` so only the proof body is kept.
    """

    final_answer_key: str = "**FINAL ANSWER**"
    extract_code_mode: str = "last"  # accepted values: "first" or "last"
    restate_formal_statement: bool = True
    strip_theorem_from_proof: bool = True


def extract_proof_only(lean_code: str) -> str:
    """Extract only the proof part from a Lean theorem/example.

    The first line that starts (after optional indentation) with ``theorem``
    or ``example`` is treated as the start of the declaration header. The
    header ends at the first ``:=`` found on that line or a later one.
    Everything after that ``:=`` is the proof; a leading ``by`` keyword is
    removed so that the result can be appended to a statement that already
    ends in ``:= by``.

    Args:
        lean_code: The full Lean code including theorem statement.

    Returns:
        The proof part only (everything after ':=', minus a leading 'by').
        If no theorem/example header or no ':=' is found, the stripped input
        is returned unchanged. Empty or whitespace-only input yields "".
    """
    lines = lean_code.strip().splitlines()
    if not lines:
        return ""

    header_start_pattern = re.compile(r"^\s*(theorem|example)\b")

    # 1. Find where the theorem starts.
    header_start_idx = next(
        (i for i, line in enumerate(lines) if header_start_pattern.match(line)),
        None,
    )
    if header_start_idx is None:
        return lean_code.strip()

    # 2. Find where ':=' occurs, starting from the header line itself
    #    (the statement may span several lines before reaching ':=').
    header_end_idx = next(
        (i for i in range(header_start_idx, len(lines)) if ":=" in lines[i]),
        None,
    )
    if header_end_idx is None:
        return lean_code.strip()

    # 3. Keep whatever follows the first ':=' on that line.
    proof_first_line = lines[header_end_idx].partition(":=")[2].strip()

    # 4. Collect proof lines; later lines keep their original indentation.
    proof_lines = lines[header_end_idx + 1 :]
    if proof_first_line:
        proof_lines = [proof_first_line] + proof_lines

    # 5. Remove a leading 'by' (with or without indentation). Any whitespace
    #    character may separate 'by' from the first tactic, not only a space,
    #    and a 'by' line carrying only trailing whitespace is dropped entirely.
    if proof_lines:
        first = proof_lines[0].lstrip()
        if first.rstrip() == "by":
            proof_lines = proof_lines[1:]
        elif first.startswith("by") and first[2:3].isspace():
            proof_lines[0] = first[3:]  # Strip 'by' plus one separator char

    return "\n".join(proof_lines).rstrip()


def build_lean4_proof(
    generation: str, data_point: Dict[str, Any], config: ProofBuildConfig, answer_format: str = "lean4-proof"
) -> str:
    """Assemble the Lean 4 source that will be sent for compilation.

    Two formats are supported:

    * "lean4-proof": the generation is cleaned with ``clean_formal_generation``
      and concatenated as ``header + formal_statement + proof``. ``header`` and
      ``formal_statement`` are read from ``data_point`` (missing keys count as
      empty strings). The statement is omitted when
      ``config.restate_formal_statement`` is False, and the proof is reduced to
      its body via ``extract_proof_only`` when
      ``config.strip_theorem_from_proof`` is True.
    * "lean4-statement": the cleaned generation is prefixed with the result of
      ``get_lean4_header()`` and terminated with a ``sorry`` placeholder.

    Args:
        generation: The raw generation from the model.
        data_point: Dictionary that may contain ``header`` and ``formal_statement``.
        config: Proof-building options.
        answer_format: Either "lean4-proof" or "lean4-statement".

    Returns:
        The assembled Lean 4 text.

    Raises:
        ValueError: If ``answer_format`` is not one of the supported values.
    """
    if answer_format not in ("lean4-proof", "lean4-statement"):
        raise ValueError(f"Unknown answer_format: {answer_format}")

    if answer_format == "lean4-statement":
        # Statements carry no proof, so close them with `sorry`.
        statement = clean_formal_generation(generation, extract_code_mode=config.extract_code_mode)
        return get_lean4_header() + statement + "\n sorry"

    # answer_format == "lean4-proof"
    extracted = clean_formal_generation(
        generation, final_answer_key=config.final_answer_key, extract_code_mode=config.extract_code_mode
    )
    prefix = data_point.get("header", "")
    restated = ""
    if config.restate_formal_statement:
        restated = data_point.get("formal_statement", "")
    body = extract_proof_only(extracted) if config.strip_theorem_from_proof else extracted

    return prefix + restated + body


def determine_proof_status(compiler_output: Dict[str, Any]) -> str:
    """Determine proof status from compiler output.

    Args:
        compiler_output: Dictionary that may contain ``process_status``,
            ``stdout`` and ``stderr``. A missing ``process_status`` is treated
            as "unknown"; missing or ``None`` ``stdout``/``stderr`` are treated
            as empty text.

    Returns:
        * the ``process_status`` value unchanged when it is anything other
          than "completed" (for example "timeout" or "unknown");
        * "has_sorry" when the process completed but the word ``sorry``
          appears (case-insensitively) in stdout or stderr;
        * "completed" otherwise.
    """
    process_status = compiler_output.get("process_status", "unknown")

    # Every non-completed status, "timeout" included, is reported verbatim.
    if process_status != "completed":
        return process_status

    # `or ""` also covers keys that are present but set to None, which would
    # otherwise fail on `.lower()`.
    stdout = (compiler_output.get("stdout") or "").lower()
    stderr = (compiler_output.get("stderr") or "").lower()
    combined = stdout + "\n" + stderr

    # A whole-word `sorry` in the output marks an incomplete proof.
    if re.search(r"\bsorry\b", combined) is not None:
        return "has_sorry"

    # Process completed and no `sorry` was reported.
    return "completed"


def prepare_predicted_proof_from_line_dict(
    line_dict: Dict[str, Any],
    config: ProofBuildConfig,
    answer_format: str = "lean4-proof",
    use_predicted_proof_key: bool = False,
) -> str:
    """Return the ``predicted_proof`` text for one record of a batch.

    Either reuses the proof already stored on the record or rebuilds it from
    the record's ``generation`` field via ``build_lean4_proof``, passing the
    record itself as the data point.

    Args:
        line_dict: Record holding ``generation`` and any other fields.
        config: Proof-building options passed through to ``build_lean4_proof``.
        answer_format: Either "lean4-proof" or "lean4-statement".
        use_predicted_proof_key: If True, return ``line_dict["predicted_proof"]``
            as-is instead of rebuilding it.

    Returns:
        The Lean 4 proof text for this record.

    Raises:
        ValueError: If ``use_predicted_proof_key`` is True but the record has
            no ``predicted_proof`` key.
    """
    if not use_predicted_proof_key:
        # Rebuild from the raw generation.
        return build_lean4_proof(
            generation=line_dict["generation"],
            data_point=line_dict,
            config=config,
            answer_format=answer_format,
        )

    if "predicted_proof" in line_dict:
        return line_dict["predicted_proof"]

    raise ValueError("predicted_proof key not found in the line_dict. Set use_predicted_proof_key=False to re-combine")
118 changes: 24 additions & 94 deletions nemo_skills/code_execution/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import json
import logging
import os
import re
import traceback
import uuid
from collections import defaultdict
Expand All @@ -27,8 +26,11 @@
import httpx
import tqdm

from nemo_skills.code_execution.utils import clean_formal_generation
from nemo_skills.dataset.utils import get_lean4_header
from nemo_skills.code_execution.proof_utils import (
ProofBuildConfig,
determine_proof_status,
prepare_predicted_proof_from_line_dict,
)
from nemo_skills.utils import get_logger_name, python_doc_to_cmd_help

LOG = logging.getLogger(get_logger_name(__file__))
Expand All @@ -40,54 +42,6 @@ def unroll_files(input_files):
yield manifest


def extract_proof_only(lean_code: str) -> str:
    """Extract only the proof part from a Lean theorem/example.

    The first line that starts (after optional indentation) with ``theorem``
    or ``example`` is treated as the start of the declaration header. The
    header ends at the first ``:=`` found on that line or a later one.
    Everything after that ``:=`` is the proof; a leading ``by`` keyword is
    removed.

    Args:
        lean_code: The full Lean code including theorem statement.

    Returns:
        The proof part only (everything after ':=', minus a leading 'by').
        If no theorem/example header or no ':=' is found, the stripped input
        is returned unchanged. Empty or whitespace-only input yields "".
    """
    lines = lean_code.strip().splitlines()
    if not lines:
        return ""

    header_start_pattern = re.compile(r"^\s*(theorem|example)\b")

    # 1. Find where the theorem starts.
    header_start_idx = next(
        (i for i, line in enumerate(lines) if header_start_pattern.match(line)),
        None,
    )
    if header_start_idx is None:
        return lean_code.strip()

    # 2. Find where ':=' occurs, starting from the header line itself
    #    (the statement may span several lines before reaching ':=').
    header_end_idx = next(
        (i for i in range(header_start_idx, len(lines)) if ":=" in lines[i]),
        None,
    )
    if header_end_idx is None:
        return lean_code.strip()

    # 3. Keep whatever follows the first ':=' on that line.
    proof_first_line = lines[header_end_idx].partition(":=")[2].strip()

    # 4. Collect proof lines; later lines keep their original indentation.
    proof_lines = lines[header_end_idx + 1 :]
    if proof_first_line:
        proof_lines = [proof_first_line] + proof_lines

    # 5. Remove a leading 'by' (with or without indentation). Any whitespace
    #    character may separate 'by' from the first tactic, not only a space,
    #    and a 'by' line carrying only trailing whitespace is dropped entirely.
    if proof_lines:
        first = proof_lines[0].lstrip()
        if first.rstrip() == "by":
            proof_lines = proof_lines[1:]
        elif first.startswith("by") and first[2:3].isspace():
            proof_lines[0] = first[3:]  # Strip 'by' plus one separator char

    return "\n".join(proof_lines).rstrip()


class Sandbox(abc.ABC):
"""Code execution sandbox.

Expand Down Expand Up @@ -250,14 +204,7 @@ async def is_proof_correct(self, pred_output, timeout=30.0):
output = await self._send_request(request, timeout)
except httpx.TimeoutException:
return "timeout"
if output["process_status"] == "completed":
stdout = output["stdout"].lower()
stderr = output["stderr"].lower()
combined = stdout + "\n" + stderr
if re.search(r"\bsorry\b", combined) is not None:
return "has_sorry"
return "completed"
return output["process_status"]
return determine_proof_status(output)

async def batch_evaluate_results(
self,
Expand All @@ -284,38 +231,20 @@ async def process_line(line_data):
if not line_dict:
return line_data

# Prepare predicted_proof based on format
if answer_format == "lean4-proof":
if not use_predicted_proof_key:
generation = clean_formal_generation(
line_dict["generation"], final_answer_key=final_answer_key, extract_code_mode=extract_code_mode
)
line_dict["predicted_proof"] = (
line_dict["header"]
+ (line_dict["formal_statement"] if restate_formal_statement else "")
+ extract_proof_only(generation)
if strip_theorem_from_proof
else generation
)
else:
if "predicted_proof" not in line_dict:
raise ValueError(
"predicted_proof key not found in the line_dict. "
"Set use_predicted_proof_key=False to re-combine"
)
elif answer_format == "lean4-statement":
if not use_predicted_proof_key:
generation = clean_formal_generation(line_dict["generation"], extract_code_mode=extract_code_mode)
header = get_lean4_header()
line_dict["predicted_proof"] = header + generation + "\n sorry"
else:
if "predicted_proof" not in line_dict:
raise ValueError(
"predicted_proof key not found in the line_dict. "
"Set use_predicted_proof_key=False to re-combine"
)
else:
raise ValueError(f"Unknown answer_format: {answer_format}")
# Prepare predicted_proof using shared utility
config = ProofBuildConfig(
final_answer_key=final_answer_key,
extract_code_mode=extract_code_mode,
restate_formal_statement=restate_formal_statement,
strip_theorem_from_proof=strip_theorem_from_proof,
)

line_dict["predicted_proof"] = prepare_predicted_proof_from_line_dict(
line_dict=line_dict,
config=config,
answer_format=answer_format,
use_predicted_proof_key=use_predicted_proof_key,
)

# Evaluate proof with concurrency control
async with semaphore:
Expand All @@ -333,9 +262,10 @@ async def process_line(line_data):
# Process lines concurrently with progress bar
print(f"Processing {input_file}...")
processed_lines = []
for line in tqdm.tqdm(lines):
result = await process_line(line.rstrip("\n"))
processed_lines.append(result)
tasks = [asyncio.create_task(process_line(line.rstrip("\n"))) for line in lines]
processed_lines = []
for coro in tqdm.tqdm(tasks, total=len(tasks)):
processed_lines.append(await coro)

# Write to temp file then replace original
temp_file = input_file + "-tmp"
Expand Down
Loading