From 9eccc58545b3592eb01fc2aa4135beee4818f651 Mon Sep 17 00:00:00 2001 From: Stephen Ge Date: Thu, 4 Dec 2025 14:56:44 -0800 Subject: [PATCH 1/7] Add prover module for Lean4 theorem proving Signed-off-by: Stephen Ge --- nemo_skills/inference/factory.py | 2 + nemo_skills/inference/lean4_utils.py | 205 +++++++++ nemo_skills/inference/prover.py | 427 ++++++++++++++++++ nemo_skills/pipeline/prover.py | 333 ++++++++++++++ ...mal-proof-deepseek-prover-v2-nemotron.yaml | 29 ++ .../lean4/goedel-prover-v2-nemotron.yaml | 28 ++ .../goedel-prover-v2-refinement-nemotron.yaml | 28 ++ .../lean4/goedel-prover-v2-refinement.yaml | 24 + .../prompt/config/lean4/goedel-prover-v2.yaml | 27 ++ 9 files changed, 1103 insertions(+) create mode 100644 nemo_skills/inference/lean4_utils.py create mode 100644 nemo_skills/inference/prover.py create mode 100644 nemo_skills/pipeline/prover.py create mode 100644 nemo_skills/prompt/config/lean4/formal-proof-deepseek-prover-v2-nemotron.yaml create mode 100644 nemo_skills/prompt/config/lean4/goedel-prover-v2-nemotron.yaml create mode 100644 nemo_skills/prompt/config/lean4/goedel-prover-v2-refinement-nemotron.yaml create mode 100644 nemo_skills/prompt/config/lean4/goedel-prover-v2-refinement.yaml create mode 100644 nemo_skills/prompt/config/lean4/goedel-prover-v2.yaml diff --git a/nemo_skills/inference/factory.py b/nemo_skills/inference/factory.py index cd29bbd2c5..1ca46fe4ab 100644 --- a/nemo_skills/inference/factory.py +++ b/nemo_skills/inference/factory.py @@ -19,10 +19,12 @@ class GenerationType(str, Enum): generate = "generate" math_judge = "math_judge" check_contamination = "check_contamination" + prover = "prover" GENERATION_MODULE_MAP = { GenerationType.generate: "nemo_skills.inference.generate", GenerationType.math_judge: "nemo_skills.inference.llm_math_judge", GenerationType.check_contamination: "nemo_skills.inference.check_contamination", + GenerationType.prover: "nemo_skills.inference.prover", } diff --git a/nemo_skills/inference/lean4_utils.py b/nemo_skills/inference/lean4_utils.py new file mode 100644 index 0000000000..58c26495c1 --- /dev/null +++ b/nemo_skills/inference/lean4_utils.py @@ -0,0 +1,205 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from nemo_skills.utils import get_logger_name + +LOG = logging.getLogger(get_logger_name(__file__)) + +# ------------------------------------------------------------------------------------------------ +# The following code is adapted from https://github.com/Goedel-LM/Goedel-Prover-V2 +# ------------------------------------------------------------------------------------------------ + + +def remove_comments(text): # remove comments + # First remove all /- ... -/ blocks + text = re.sub(r"/-.*?-/", "", text, flags=re.DOTALL) + # Then remove -- comments from each line + lines = text.split("\n") + cleaned_lines = [] + for line in lines: + cleaned_line = line.split("--", 1)[0] + if cleaned_line.strip() == "": + continue + cleaned_lines.append(cleaned_line) + # Join back together and remove excessive empty lines + cleaned_text = "\n".join(cleaned_lines) + return cleaned_text.strip() + + +def move_imports_to_beginning(input_string): + lines = input_string.split("\n") + import_lines = [line for line in lines if line.startswith("import")] + other_lines = [line for line in lines if not line.startswith("import")] + return "\n".join(import_lines + other_lines) + + +def return_theorem_to_prove(text): + # Pattern that matches from 'theorem' or 'lemma' to ':= by sorry' with any content in between + pattern = r"((?:theorem).*?:=\s*by\s*sorry)" + match = re.search(pattern, text, re.DOTALL) + return match.span() if match else None + + +def return_theorem_to_replace(text): + # Pattern that matches from 'theorem' or 'lemma' to ':= by sorry' with any content in between + pattern = r"((?:^|\s)theorem\s+.*?:=\s*by)" + match = re.search(pattern, text, re.DOTALL) + return match.span() if match else None + + +def replace_statement_in_proof(statement, proof): + if ("apply?" in proof) or ("exact?" in proof): + return "**Error**, 'apply?' or 'exact?' is used, which is not allowed." + stats_re = remove_comments(statement) + stats_span_ = return_theorem_to_prove(stats_re) + if stats_span_ is None: + error_app = "\n".join(["\n"] + ["-- " + x for x in statement.split("\n")]) + return f"**Error**, can not find 'theorem' and ':= sorry' in {error_app}" + proof_str = remove_comments(proof) + span = return_theorem_to_replace(proof_str) + if span is None: + error_app = "\n".join(["\n"] + ["-- " + x for x in proof.split("\n")]) + return f"**Error**, can not find 'theorem' and ':=' in {error_app}" + return stats_re[: stats_span_[1]].replace("sorry", "") + proof_str[span[1] :] + + +def refine_by_sorry(text): + # Define the regular expression pattern + target_pattern = r":=\s*(?:by\s*)?(?:sorry\s*)?" + replacement = ":= by sorry" # The new text we want to insert + # We construct the pattern with two capturing groups + # (group 1: the part from 'theorem' to just before our target) + # (group 2: the target pattern itself) + combined_pattern = r"(theorem.*?)(" + target_pattern + r")" + # Find the first match + match = re.search(combined_pattern, text, re.DOTALL) + if match: + # The part of the string BEFORE the target we want to replace + # We use match.start(2) which is the start of the second group (our target) + prefix = text[: match.start(2)] + # Concatenate the prefix with the replacement to get the final, truncated string + final_text = prefix + replacement + else: + final_text = text + return final_text + + +def extract_code(inputs): + import_head = ( + "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n" + ) + pattern = r"```lean4\n(.*?)\n```" + matches = re.findall(pattern, inputs, re.DOTALL) + if matches: + return import_head + matches[-1] + pattern = r"```lean4\n(.*?)```" + matches = re.findall(pattern, inputs, re.DOTALL) + if matches: + return import_head + matches[-1] + pattern = r"```lean\n(.*?)```" + matches = re.findall(pattern, inputs, re.DOTALL) + if matches: + return import_head + matches[-1] + return "None" + + +def parse_error(log_string): + # Pattern to match multiline warnings + # warning_pattern = re.compile( + # r"(/lean4/my_project/.*?:\d+:\d+: warning:.*?)(?=\n/lean4/my_project|\Z)", + # re.DOTALL, + # ) + # Pattern to match multiline errors + error_pattern = re.compile( + r"(/lean4/my_project/.*?:\d+:\d+: error:.*?)(?=\n/lean4/my_project|\Z)", + re.DOTALL, + ) + # Find all warnings and errors + # warnings = warning_pattern.findall(log_string) + errors = error_pattern.findall(log_string) + pattern = re.compile(r":(\d+):(\d+):") + error_list = [] + for error in errors: + match = pattern.search(error) + error_list.append( + { + "pos": {"line": int(match.group(1)), "column": int(match.group(2))}, + "endPos": None, + "data": error.split("error:")[1], + } + ) + + return error_list + + +def get_error_str(code, errors, error_thres=True): + err_str = "" + code_lines = code.split("\n") + # token_lengths = [len(line) + 1 for line in code_lines] + error_num_thres = 8 if error_thres else len(errors) + + for i, error in enumerate(errors[:error_num_thres]): + start_line = error["pos"]["line"] - 1 + start_col = error["pos"]["column"] + if start_line >= len(code_lines): + LOG.warning( + "Error line %d exceeds code length %d. Errors: %s, Code: %s", start_line, len(code_lines), errors, code + ) + continue + if error["endPos"] is None: + end_line = start_line + end_col = len(code_lines[start_line]) + else: + end_line = error["endPos"]["line"] - 1 + end_col = error["endPos"]["column"] + + err_str += f"\nError {i + 1}:\n" + err_str += "\nCorresponding Code:\n```lean4\n" + error_code = "" + for ii in range(-4, 0): + if start_line + ii >= 0: + error_code += f"{code_lines[start_line + ii]}\n" + if start_line != end_line: + error_code += code_lines[start_line][:start_col] + "" + code_lines[start_line][start_col:] + "\n" + if not error_thres: + for j in range(start_line + 1, end_line): + error_code += f"{code_lines[j]}\n" + else: + show_line = 6 + for j in range(start_line + 1, min(end_line, start_line + show_line)): + error_code += f"{code_lines[j]}\n" + if end_line > start_line + show_line: + leading_spaces = len(code_lines[j]) - len(code_lines[j].lstrip(" ")) + error_code += "\n" + " " * leading_spaces + "... --[Truncated]-- ...\n" + error_code += code_lines[end_line][:end_col] + "" + code_lines[end_line][end_col:] + "\n" + else: + error_code += ( + code_lines[start_line][:start_col] + + "" + + code_lines[start_line][start_col:end_col] + + "" + + code_lines[start_line][end_col:] + + "\n" + ) + if end_line + 1 < len(code_lines): + error_code += f"{code_lines[end_line + 1]}\n" + err_str += error_code + err_str += "\n```\n" + err_str += f"\nError Message: {error['data']}\n" + if len(errors) > error_num_thres: + err_str += f"\n... [Omitted {len(errors) - error_num_thres} more errors] ...\n" + return err_str diff --git a/nemo_skills/inference/prover.py b/nemo_skills/inference/prover.py new file mode 100644 index 0000000000..cad883b295 --- /dev/null +++ b/nemo_skills/inference/prover.py @@ -0,0 +1,427 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +import sys +from copy import deepcopy +from dataclasses import asdict, is_dataclass + +import hydra + +from nemo_skills.code_execution.sandbox import get_sandbox, sandbox_params +from nemo_skills.inference.model import get_model, server_params +from nemo_skills.prompt.utils import get_prompt +from nemo_skills.utils import ( + get_help_message, + get_logger_name, + nested_dataclass, + parse_reasoning, + setup_logging, +) + +from .generate import GenerateSolutionsConfig, GenerationTask +from .lean4_utils import ( + extract_code, + get_error_str, + parse_error, + refine_by_sorry, + replace_statement_in_proof, +) + +LOG = logging.getLogger(get_logger_name(__file__)) + +reasoning_effort_list = [ + "low", + "medium", + "high", +] # This is only used for adaptive reasoning with gpt-oss models + + +@nested_dataclass(kw_only=True) +class ProverConfig(GenerateSolutionsConfig): + max_tokens: int = 40960 # model max tokens + n_pass: int = 1 # number of passes to run the prover + + # Lean 4 specific parameters + nemotron_refinement: bool = False # whether to use single-turn nemotron-style refinement + refinement: bool = False # whether to refine the code + refinement_max_turns: int = 2 # maximum number of turns for refinement + refinement_prompt_config: str | None = None # prompt for multi-turn refinement feedback + # prompt for single-turn nemotron refinement (used when nemotron_refinement=True) + nemotron_refinement_prompt_config: str | None = None + adaptive_reasoning: bool = False # whether to adapt the reasoning effort + parse_generation: bool = False # whether to parse the generation + remove_cot: bool = False # whether to remove the cot from the generation + # whether to delete the wrong turns from the generation + delete_wrong_turns: bool = False + + def _post_init_validate_params(self): + """Validate that certain parameters are restricted to certain values""" + if self.prompt_format not in ["ns", "openai"]: + raise ValueError(f"prompt_format must be either 'ns' or 'openai', got '{self.prompt_format}'") + + if self.prompt_format == "openai": + assert self.prompt_config is None, "prompt_config is not supported for prompt_format == 'openai'" + else: + assert self.prompt_config is not None, "prompt_config is required when prompt_format == 'ns'" + for param, default_value in self._get_disallowed_params(): + if getattr(self, param) != default_value: + raise ValueError(f"{param} must be {default_value}") + + if self.n_pass > 32: + raise ValueError("Please consider using num_random_seeds instead") + + +cs = hydra.core.config_store.ConfigStore.instance() +cs.store(name="base_prover_config", node=ProverConfig) + + +class ProverTask(GenerationTask): + def __init__(self, cfg: ProverConfig): + """ + Class that represents a generation task. It implements a template of steps to generate solutions using LLMs. + Individual functions can be overriden to customize the behavior of the generation task. + + Args: + cfg: GenerateSolutionsConfig object with the configuration parameters or subclass. + """ + super().__init__(cfg) + if self.cfg.refinement: + self.setup_refine_prompt() + + if self.cfg.delete_wrong_turns: + assert self.cfg.remove_cot, "remove_cot is required when delete_wrong_turns is enabled" + + def log_example_prompt(self, data): + return + + def setup_llm(self): + if self.cfg.code_execution: + raise ValueError("Code execution is not supported for prover") + sandbox = get_sandbox(**self.cfg.sandbox) if self.cfg.sandbox is not None else None + server = deepcopy(self.cfg.server) + server["server_type"] = "autoformalization" + llm = get_model(**server, sandbox=sandbox) + return llm + + def setup_prompt(self): + if self.cfg.prompt_format == "openai": + return None + prompt = get_prompt( + prompt_config=self.cfg.prompt_config, + tokenizer=self.tokenizer, + code_tags=self.cfg.code_tags, + examples_type=self.cfg.examples_type, + system_message=self.cfg.system_message, + ) + LOG.info("Prompt used: %s", prompt) + return prompt + + def setup_refine_prompt(self): + assert self.cfg.refinement_prompt_config is not None, ( + "refinement_prompt_config is required when refinement is enabled. Please set refinement=False to disable refinement." + ) + self.refine_prompt = get_prompt(self.cfg.refinement_prompt_config) + + if self.cfg.nemotron_refinement: + assert self.cfg.nemotron_refinement_prompt_config is not None, ( + "nemotron_refinement_prompt_config is required when nemotron_refinement is enabled." + ) + self.nemotron_refine_prompt = get_prompt(self.cfg.nemotron_refinement_prompt_config) + + # with adaptive reasoning + async def _generate_single_completion(self, prompt: list[str], **kwargs): + if is_dataclass(self.cfg.inference): + inference_params = asdict(self.cfg.inference) + else: + # Already a dict from Hydra + inference_params = dict(self.cfg.inference) + generation_params = { + "prompt": prompt, + "stop_phrases": [self.cfg.stop_phrase] if self.cfg.stop_phrase else None, + **inference_params, + **self.extra_generate_params, + } + for key, value in kwargs.items(): + generation_params[key] = value + generation = await self.llm.generate_async(**generation_params) + if self.cfg.adaptive_reasoning: + assert generation_params["extra_body"].get("reasoning_effort", None) is not None, ( + "reasoning_effort is required when adaptive_reasoning is enabled" + ) + reasoning_effort_index = reasoning_effort_list.index( + generation_params["extra_body"].get("reasoning_effort", None) + ) + while len(generation["generation"]) == 0 and reasoning_effort_index > 0: + LOG.info( + "Reasoning effort is too high, reducing to %s", reasoning_effort_list[reasoning_effort_index - 1] + ) + reasoning_effort_index = reasoning_effort_index - 1 + generation_params["extra_body"]["reasoning_effort"] = reasoning_effort_list[reasoning_effort_index] + generation = await self.llm.generate_async(**generation_params) + if self.cfg.parse_generation: + parse_reasoning( + generation, + self.cfg.generation_key, + self.cfg.end_reasoning_string, + ) + return generation + + # factor out this part so it won't become a bottleneck. + async def _extract_and_replace_code(self, formal_statement, generation): + code = extract_code(generation) + full_code = replace_statement_in_proof(formal_statement, code) + return code, full_code + + def _transform_for_nemotron_refinement(self, proof_attempt: str, error_message: str) -> list[dict]: + """Transform multi-turn refinement into single-turn nemotron-style prompt.""" + return self.nemotron_refine_prompt.fill( + { + "proof_attempt": proof_attempt, + "error_message": error_message, + } + ) + + async def _single_data_point_generate(self, data_point, data): + formal_statement = ( + (data_point["header"].strip() + "\n") + + data_point["informal_prefix"].strip() + + ("\n" + data_point["formal_statement"].strip()) + ) + formal_statement = refine_by_sorry(formal_statement) + prompt_turn_list = self.prompt.fill({"problem": formal_statement.strip()}) + + full_prompt_turn_list = deepcopy( + prompt_turn_list + ) # We need to get a full copy of the prompt turn list for the final result in case remove_cot is enabled. This is only used to generate SFT data. + prompt_turn_list_list = [] # We need to store the prompt turn list for each turn for the final result in case delete_wrong_turns is enabled. This is only used to generate SFT data. + base_prompt_turn_list = deepcopy(prompt_turn_list) + + code_list = [] + results_dict_list = [] + assert isinstance(prompt_turn_list, list), "prompt_turn_list should be a list" + + success = False + turn_idx = 0 + last_proof_attempt = None # Track for nemotron refinement + last_error_message = None # Track for nemotron refinement + for turn_idx in range(self.cfg.refinement_max_turns): + results_dict = {} # everything will be stored in this dict + if turn_idx != 0 and self.cfg.nemotron_refinement and last_proof_attempt and last_error_message: + prepared_conversation = self._transform_for_nemotron_refinement(last_proof_attempt, last_error_message) + else: + prepared_conversation = prompt_turn_list + prefix_tokens = self.llm.tokenizer.apply_chat_template( + prepared_conversation, tokenize=True, add_generation_prompt=True + ) + num_tokens_prefix = len(prefix_tokens) + prefix = self.llm.tokenizer.apply_chat_template( + prepared_conversation, tokenize=False, add_generation_prompt=True + ) + # We need to check if the prefix is too long, if it is, we need to break the loop + if num_tokens_prefix > self.cfg.max_tokens: + break + + generation = await self._generate_single_completion( + prefix, + tokens_to_generate=min( + self.cfg.max_tokens - num_tokens_prefix, + self.cfg.inference.tokens_to_generate, + ), + ) + + new_prompt_turn_list = deepcopy(prompt_turn_list) + new_prompt_turn_list += [{"role": "assistant", "content": generation["generation"]}] + + prompt_turn_list_list.append( + new_prompt_turn_list + ) # This stores the latest turn list after each generation. + + code, full_code = await self._extract_and_replace_code(formal_statement, generation["generation"]) + last_proof_attempt = generation["generation"] # Track for nemotron refinement + code_list.append(full_code) + results_dict["code"] = code # We keep track of the uncleaned code. + if self.cfg.remove_cot and not ( + code == "None" or "**Error**" in full_code + ): # check if successfully parse the code. We do not want to delete the turn if there is a parsing error. + if self.cfg.delete_wrong_turns: + prompt_turn_list = deepcopy(base_prompt_turn_list) + [ + { + "role": "assistant", + "content": f"```lean4\n{full_code.strip()}\n```", + } + ] # only keep the latest turn + else: + prompt_turn_list += [ + { + "role": "assistant", + "content": f"```lean4\n{full_code.strip()}\n```", + } + ] + full_prompt_turn_list += [{"role": "assistant", "content": generation["generation"]}] + else: + prompt_turn_list += [{"role": "assistant", "content": generation["generation"]}] + full_prompt_turn_list += [{"role": "assistant", "content": generation["generation"]}] + + if code == "None" or "**Error**" in full_code: + if code == "None": + execution_result = { + "process_status": "failed", + "stderr": "", + "stdout": "Parsing error. Cannot parse the code from output. Please try again and write the code in the format of ```lean4\n\n```", + } + elif "**Error**" in full_code: + execution_result = { + "process_status": "failed", + "stderr": "", + "stdout": full_code, + } + else: + execution_result = { + "process_status": "failed", + "stderr": "", + "stdout": "Unknown error when parsing code.", + } + results_dict["execution_result"] = execution_result + results_dict["success"] = False + last_error_message = execution_result["stdout"] # Track for nemotron refinement + feedback = self.refine_prompt.fill({"error_message": last_error_message}) + results_dict["feedback"] = feedback[0]["content"] + else: + execution_result = await self.llm.sandbox.execute_lean4_code( + full_code, timeout=600.0, max_output_characters=1000000 + ) + results_dict["execution_result"] = execution_result + if isinstance(execution_result, dict): + if ( + execution_result["process_status"] == "completed" + and "sorry" not in execution_result["stdout"] + and "failed" not in execution_result["stdout"] + ): + results_dict["success"] = True + else: + error_list = parse_error(execution_result["stdout"]) + error_message = get_error_str(full_code, error_list, error_thres=True) + # checking for sorry + if execution_result["process_status"] == "completed": + stdout = execution_result["stdout"].lower() + stderr = execution_result["stderr"].lower() + combined = stdout + "\n" + stderr + if re.search(r"\bsorry\b", combined) is not None: + error_message += "\nThe code contains 'sorry', which means the proof is incomplete." + if error_message.strip() == "": # something in stderr indicating failure + error_message = execution_result["stderr"][:1000] + if len(execution_result["stderr"]) > 1000: + error_message += "... (truncated)" + + last_error_message = ( + "We use to signal the position of the error. \n" + error_message + ) + feedback = self.refine_prompt.fill({"error_message": last_error_message}) + results_dict["feedback"] = feedback[0]["content"] + results_dict["success"] = False + # This is only used for the case when the code execution timed out. + elif isinstance(execution_result, str): + execution_result = { + "process_status": "failed", + "stderr": "", + "stdout": execution_result, + } + results_dict["success"] = False + last_error_message = ( + "The compilation timed out. There might be a heavy computation in the code or an endless loop." + ) + feedback = self.refine_prompt.fill({"error_message": last_error_message}) + results_dict["feedback"] = feedback[0]["content"] + else: + raise ValueError(f"Unknown execution result type: {type(execution_result)}") + + results_dict_list.append(results_dict) + + if results_dict["success"]: + # This is the case when the code execution is successful. The theorem is proved. + break + else: + if self.cfg.refinement and turn_idx < self.cfg.refinement_max_turns - 1: + prompt_turn_list += feedback + full_prompt_turn_list += feedback + else: + # Proving attempt failed. + break + + if len(results_dict_list) > 0 and results_dict_list[-1]["success"]: + success = True + + # Usually only need prompt_turn_list for standard SFT, full_prompt_turn_list for SFT with remove_cot enabled, prompt_turn_list_list for SFT with delete_wrong_turns enabled. + return { + "code_list": code_list, + "results_dict_list": results_dict_list, + "prompt_turn_list": prompt_turn_list, + "turn_idx": turn_idx, + "success": success, + "full_prompt_turn_list": full_prompt_turn_list, + "prompt_turn_list_list": prompt_turn_list_list, + } + + async def pass_at_N(self, data_point, data, N=None): + if N is None: + N = self.cfg.n_pass + + new_results_dict = {"success": False} + for i in range(N): + results_dict = await self._single_data_point_generate(data_point, data) + + if results_dict["success"]: + new_results_dict["success"] = True + break + + new_results_dict["results_dict_list"] = results_dict + new_results_dict["n_pass"] = i + 1 + + return new_results_dict + + async def process_single_datapoint(self, data_point, all_data): + result = await self.pass_at_N(data_point, all_data) + result_dict = {"generation": result} + + return result_dict + + +GENERATION_TASK_CLASS = ProverTask + + +# Update the hydra main to use the class method +@hydra.main(version_base=None, config_name="base_prover_config") +def generate(cfg: ProverConfig): + cfg = ProverConfig(_init_nested=True, **cfg) + LOG.info("Config used: %s", cfg) + + task = ProverTask(cfg) + task.generate() + + +HELP_MESSAGE = get_help_message( + ProverConfig, + server_params=server_params(), + sandbox_params=sandbox_params(), +) + + +if __name__ == "__main__": + if "--help" in sys.argv or "-h" in sys.argv: + print(HELP_MESSAGE) + else: + setup_logging() + generate() diff --git a/nemo_skills/pipeline/prover.py b/nemo_skills/pipeline/prover.py new file mode 100644 index 0000000000..a70a199f94 --- /dev/null +++ b/nemo_skills/pipeline/prover.py @@ -0,0 +1,333 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import logging +import os +from typing import List + +import typer + +import nemo_skills.pipeline.utils as pipeline_utils +from nemo_skills.dataset.utils import import_from_path +from nemo_skills.inference import GENERATION_MODULE_MAP, GenerationType +from nemo_skills.pipeline.app import app, typer_unpacker +from nemo_skills.utils import ( + compute_chunk_ids, + get_logger_name, + setup_logging, + str_ids_to_list, +) + +LOG = logging.getLogger(get_logger_name(__file__)) + +# TODO: add num_jobs here for consistency with eval? + + +@app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) +@typer_unpacker +def generate( + ctx: typer.Context, + cluster: str = typer.Option( + None, + help="One of the configs inside config_dir or NEMO_SKILLS_CONFIG_DIR or ./cluster_configs. " + "Can also use NEMO_SKILLS_CONFIG instead of specifying as argument.", + ), + input_file: str = typer.Option( + None, help="Path to the input data file. Can either specify input_file or input_dir, but not both. " + ), + input_dir: str = typer.Option( + None, + help="Path to the input data directory. Can either specify input_file or input_dir, but not both. " + "If input_file is not provided, will use output-rs{{seed}}.jsonl inside input_dir as input_files. " + "In this case, the random seed parameter is used both for input and for output files, which " + "means it's a 1-1 mapping (not 1-num_random_seeds as in the case of input_file).", + ), + output_dir: str = typer.Option(..., help="Where to put results"), + expname: str = typer.Option("generate", help="Nemo run experiment name"), + generation_type: GenerationType | None = typer.Option(None, help="Type of generation to perform"), + generation_module: str = typer.Option( + None, + help="Path to the generation module to use. " + "If not specified, will use the registered generation module for the " + "generation type (which is required in this case).", + ), + model: str = typer.Option(None, help="Path to the model or model name in API"), + server_address: str = typer.Option( + None, help="Use ip:port for self-hosted models or the API url if using model providers" + ), + server_type: pipeline_utils.SupportedServers = typer.Option(..., help="Type of server to use"), + server_gpus: int = typer.Option(None, help="Number of GPUs to use if hosting the model"), + server_nodes: int = typer.Option(1, help="Number of nodes required for hosting LLM server"), + server_args: str = typer.Option("", help="Any extra arguments to pass to the server"), + server_entrypoint: str = typer.Option( + None, + help="Path to the entrypoint of the server. " + "If not specified, will use the default entrypoint for the server type.", + ), + server_container: str = typer.Option( + None, help="Override container image for the hosted server (if server_gpus is set)" + ), + dependent_jobs: int = typer.Option(0, help="Specify this to launch that number of dependent jobs"), + mount_paths: str = typer.Option(None, help="Comma separated list of paths to mount on the remote machine"), + num_random_seeds: int = typer.Option( + None, help="Specify if want to run many generations with high temperature for the same input" + ), + random_seeds: str = typer.Option( + None, + help="List of random seeds to use for generation. Separate with , or .. to specify range. " + "Can provide a list directly when using through Python", + ), + starting_seed: int = typer.Option(0, help="Starting seed for random sampling"), + num_chunks: int = typer.Option( + None, + help="Number of chunks to split the dataset into. If None, will not chunk the dataset.", + ), + chunk_ids: str = typer.Option( + None, + help="List of explicit chunk ids to run. Separate with , or .. to specify range. " + "Can provide a list directly when using through Python", + ), + preprocess_cmd: str = typer.Option(None, help="Command to run before generation"), + postprocess_cmd: str = typer.Option(None, help="Command to run after generation"), + partition: str = typer.Option( + None, help="Can specify if need interactive jobs or a specific non-default partition" + ), + time_min: str = typer.Option(None, help="If specified, will use as a time-min slurm parameter"), + eval_args: str = typer.Option( + None, help="Specify if need to run nemo_skills/evaluation/evaluate_results.py on the generation outputs" + ), + run_after: List[str] = typer.Option( + None, help="Can specify a list of expnames that need to be completed before this one starts" + ), + reuse_code: bool = typer.Option( + True, + help="If True, will reuse the code from the provided experiment. " + "If you use it from Python, by default the code will be re-used from " + "the last submitted experiment in the current Python session, so set to False to disable " + "(or provide reuse_code_exp to override).", + ), + reuse_code_exp: str = typer.Option( + None, + help="If specified, will reuse the code from this experiment. " + "Can provide an experiment name or an experiment object if running from code.", + ), + config_dir: str = typer.Option(None, help="Can customize where we search for cluster configs"), + log_dir: str = typer.Option(None, help="Can specify a custom location for slurm logs."), + exclusive: bool = typer.Option(False, help="If set will add exclusive flag to the slurm job."), + rerun_done: bool = typer.Option( + False, help="If True, will re-run jobs even if a corresponding '.done' file already exists" + ), + with_sandbox: bool = typer.Option(False, help="If True, will start a sandbox container alongside this job"), + check_mounted_paths: bool = typer.Option(False, help="Check if mounted paths are available on the remote machine"), + log_samples: bool = typer.Option( + False, + help="If True, will log random samples from the output files to wandb. " + "Requires WANDB_API_KEY to be set in the environment. " + "Use wandb_name/wandb_group/wandb_project to specify where to log.", + ), + wandb_name: str = typer.Option( + None, + help="Name of the wandb group to sync samples to. If not specified, but log_samples=True, will use expname.", + ), + wandb_group: str = typer.Option(None, help="Name of the wandb group to sync samples to."), + wandb_project: str = typer.Option( + "nemo-skills", + help="Name of the wandb project to sync samples to.", + ), + installation_command: str | None = typer.Option( + None, + help="An installation command to run before main job. Only affects main task (not server or sandbox). " + "You can use an arbitrary command here and we will run it on a single rank for each node. " + "E.g. 'pip install my_package'", + ), + skip_hf_home_check: bool = typer.Option( + False, + help="If True, skip checking that HF_HOME env var is defined in the cluster config.", + ), + dry_run: bool = typer.Option(False, help="If True, will not run the job, but will validate all arguments."), + _reuse_exp: str = typer.Option(None, help="Internal option to reuse an experiment object.", hidden=True), + _task_dependencies: List[str] = typer.Option( + None, help="Internal option to specify task dependencies.", hidden=True + ), +): + """Generate LLM completions for a given input file. + + Run `python -m nemo_skills.inference.generate --help` for other supported arguments + (need to be prefixed with ++, since we use Hydra for that script). + """ + setup_logging(disable_hydra_logs=False, use_rich=True) + extra_arguments = f"{' '.join(ctx.args)}" + LOG.info("Starting generation job") + LOG.info("Extra arguments that will be passed to the underlying script: %s", extra_arguments) + + try: + server_type = server_type.value + except AttributeError: + pass + + if log_samples: + wandb_parameters = { + "name": wandb_name or expname, + "project": wandb_project, + "group": wandb_group, + } + else: + wandb_parameters = None + + get_random_port = pipeline_utils.should_get_random_port(server_gpus, exclusive, server_type) + + if random_seeds and num_random_seeds: + raise ValueError("Cannot specify both random_seeds and num_random_seeds") + if num_random_seeds: + random_seeds = list(range(starting_seed, starting_seed + num_random_seeds)) + if isinstance(random_seeds, str): + random_seeds = str_ids_to_list(random_seeds) + + if num_chunks: + chunk_ids = compute_chunk_ids(chunk_ids, num_chunks) + if chunk_ids is None: + chunk_ids = [None] + + # Prepare cluster config and mount paths + cluster_config = pipeline_utils.get_cluster_config(cluster, config_dir) + cluster_config = pipeline_utils.resolve_mount_paths( + cluster_config, mount_paths, create_remote_dir=check_mounted_paths + ) + + if not log_dir: + log_dir = f"{output_dir}/generation-logs" + + output_dir, log_dir = pipeline_utils.check_mounts( + cluster_config, + log_dir=log_dir, + mount_map={output_dir: None}, + check_mounted_paths=check_mounted_paths, + ) + + original_server_address = server_address + + if generation_module is not None and generation_type is not None: + raise ValueError("Cannot specify both generation_module and generation_type. ") + if generation_module is None: + generation_module = GENERATION_MODULE_MAP[generation_type or GenerationType.generate] + + if os.sep in generation_module: + generation_task = import_from_path(generation_module) + else: + generation_task = importlib.import_module(generation_module) + if not hasattr(generation_task, "GENERATION_TASK_CLASS"): + raise ValueError( + f"Module {generation_module} does not have a GENERATION_TASK_CLASS attribute. " + "Please provide a valid generation module." + ) + generation_task = generation_task.GENERATION_TASK_CLASS + extra_arguments = f"{generation_task.get_generation_default_args()} {extra_arguments}" + extra_arguments_original = extra_arguments + + # Treat no random seeds as a single None seed to unify the code paths + if not random_seeds: + random_seeds = [None] + + remaining_jobs = pipeline_utils.get_remaining_jobs( + cluster_config=cluster_config, + output_dir=output_dir, + random_seeds=random_seeds, + chunk_ids=chunk_ids, + rerun_done=rerun_done, + ) + has_tasks = False + all_tasks = [] + if _task_dependencies is None: + _task_dependencies = [] + with pipeline_utils.get_exp(expname, cluster_config, _reuse_exp) as exp: + for seed_idx, (seed, chunk_ids) in enumerate(remaining_jobs.items()): + if wandb_parameters: + # no need for chunks as it will run after merging + wandb_parameters["samples_file"] = pipeline_utils.get_chunked_rs_filename( + output_dir, + random_seed=seed, + chunk_id=None, + ) + for chunk_id in chunk_ids: + has_tasks = True + server_config, server_address, extra_arguments = pipeline_utils.configure_client( + model=model, + server_type=server_type, + server_address=original_server_address, + server_gpus=server_gpus, + server_nodes=server_nodes, + server_args=server_args, + server_entrypoint=server_entrypoint, + server_container=server_container, + extra_arguments=extra_arguments_original, + get_random_port=get_random_port, + ) + cmd = pipeline_utils.get_generation_cmd( + input_file=input_file, + input_dir=input_dir, + random_seed=seed, + output_dir=output_dir, + extra_arguments=extra_arguments, + eval_args=eval_args, + chunk_id=chunk_id, + num_chunks=num_chunks, + preprocess_cmd=preprocess_cmd, + postprocess_cmd=postprocess_cmd, + wandb_parameters=wandb_parameters if seed_idx == 0 else None, + script=generation_module, + ) + prev_tasks = _task_dependencies + for _ in range(dependent_jobs + 1): + task_name = f"{expname}-rs{seed}" if seed is not None else expname + if chunk_id is not None: + task_name += f"-chunk{chunk_id}" + new_task = pipeline_utils.add_task( + exp, + cmd=pipeline_utils.wrap_python_path(cmd=cmd), + task_name=task_name, + log_dir=log_dir, + container=cluster_config["containers"]["nemo-skills"], + cluster_config=cluster_config, + partition=partition, + time_min=time_min, + server_config=server_config, + with_sandbox=with_sandbox, + sandbox_port=None if get_random_port else 6000, + run_after=run_after, + reuse_code=reuse_code, + reuse_code_exp=reuse_code_exp, + task_dependencies=( + prev_tasks if cluster_config["executor"] == "slurm" else all_tasks + _task_dependencies + ), + get_server_command=generation_task.get_server_command_fn(), + slurm_kwargs={"exclusive": exclusive} if exclusive else None, + installation_command=installation_command, + skip_hf_home_check=skip_hf_home_check, + ) + prev_tasks = [new_task] + all_tasks.append(new_task) + if has_tasks and not _reuse_exp: # if we are reusing an experiment, the tasks will run from there + pipeline_utils.run_exp(exp, cluster_config, dry_run=dry_run) + + if _reuse_exp: + return all_tasks + else: + if has_tasks: + return exp + return None + + +if __name__ == "__main__": + typer.main.get_command_name = lambda name: name + app() diff --git a/nemo_skills/prompt/config/lean4/formal-proof-deepseek-prover-v2-nemotron.yaml b/nemo_skills/prompt/config/lean4/formal-proof-deepseek-prover-v2-nemotron.yaml new file mode 100644 index 0000000000..54da2461e8 --- /dev/null +++ b/nemo_skills/prompt/config/lean4/formal-proof-deepseek-prover-v2-nemotron.yaml @@ -0,0 +1,29 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Configuration for proving formal theorems in Lean 4. +# This file is tailored for tasks that involve constructing and verifying proofs +# of theorems within the Lean 4 formal system. + +user: |- + Complete the following Lean 4 code: + + ```lean4 + {header}{informal_prefix}{formal_statement} + sorry + ``` + + First, think through your solution step-by-step. Provide a detailed proof plan outlining the main proof steps and strategies. The plan should highlight key ideas, intermediate lemmas, and proof structures that will guide the construction of the final formal proof. + + Then provide your final answer. Your final answer must be a single, complete Lean 4 markdown code block containing the completed theorem. Do NOT include any text or explanation before or after the code block. Begin with ```lean4 and end with ```. diff --git a/nemo_skills/prompt/config/lean4/goedel-prover-v2-nemotron.yaml b/nemo_skills/prompt/config/lean4/goedel-prover-v2-nemotron.yaml new file mode 100644 index 0000000000..db6adac3a5 --- /dev/null +++ b/nemo_skills/prompt/config/lean4/goedel-prover-v2-nemotron.yaml @@ -0,0 +1,28 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Configuration for proving formal theorems in Lean 4. +# This file is tailored for tasks that involve constructing and verifying proofs +# of theorems within the Lean 4 formal system. + +user: |- + Complete the following Lean 4 code: + + ```lean4 + {problem} + ``` + + First, think through your solution step-by-step. Provide a detailed proof plan outlining the main proof steps and strategies. The plan should highlight key ideas, intermediate lemmas, and proof structures that will guide the construction of the final formal proof. + + Then provide your final answer. Your final answer must be a single, complete Lean 4 markdown code block containing the completed theorem. Do NOT include any text or explanation before or after the code block. Begin with ```lean4 and end with ```. diff --git a/nemo_skills/prompt/config/lean4/goedel-prover-v2-refinement-nemotron.yaml b/nemo_skills/prompt/config/lean4/goedel-prover-v2-refinement-nemotron.yaml new file mode 100644 index 0000000000..960e4dbc8a --- /dev/null +++ b/nemo_skills/prompt/config/lean4/goedel-prover-v2-refinement-nemotron.yaml @@ -0,0 +1,28 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Configuration for proving formal theorems in Lean 4. +# This file is tailored for tasks that involve constructing and verifying proofs +# of theorems within the Lean 4 formal system. + +user: |- + Here is a proof attempt for the following theorem in Lean4. + + {proof_attempt} + + The proof is not correct. Following is the compilation error message: + + {error_message} + + Your task is to fix this proof. Before producing the Lean 4 code to formally prove the given theorem, do a detailed analysis of the error message. Your final answer must be a single, complete Lean 4 markdown code block containing the completed theorem. Do NOT include any text or explanation before or after the code block. Begin with ```lean4 and end with ```. diff --git a/nemo_skills/prompt/config/lean4/goedel-prover-v2-refinement.yaml b/nemo_skills/prompt/config/lean4/goedel-prover-v2-refinement.yaml new file mode 100644 index 0000000000..b46fa0a3cc --- /dev/null +++ b/nemo_skills/prompt/config/lean4/goedel-prover-v2-refinement.yaml @@ -0,0 +1,24 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Configuration for proving formal theorems in Lean 4. +# This file is tailored for tasks that involve constructing and verifying proofs +# of theorems within the Lean 4 formal system. + +user: |- + The proof is not correct. Following is the compilation error message: + + {error_message} + + Before producing the Lean 4 code to formally prove the given theorem, provide a detailed analysis of the error message. diff --git a/nemo_skills/prompt/config/lean4/goedel-prover-v2.yaml b/nemo_skills/prompt/config/lean4/goedel-prover-v2.yaml new file mode 100644 index 0000000000..86d18a33f1 --- /dev/null +++ b/nemo_skills/prompt/config/lean4/goedel-prover-v2.yaml @@ -0,0 +1,27 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Configuration for proving formal theorems in Lean 4. +# This file is tailored for tasks that involve constructing and verifying proofs +# of theorems within the Lean 4 formal system. + +user: |- + Complete the following Lean 4 code: + + ```lean4 + {problem} + ``` + + Before producing the Lean 4 code to formally prove the given theorem, provide a detailed proof plan outlining the main proof steps and strategies. + The plan should highlight key ideas, intermediate lemmas, and proof structures that will guide the construction of the final formal proof. From 3a545fc31ba998bd206c3ae54dab97b834058a59 Mon Sep 17 00:00:00 2001 From: Stephen Ge Date: Thu, 4 Dec 2025 15:10:48 -0800 Subject: [PATCH 2/7] remove dependence on autoformalization Signed-off-by: Stephen Ge --- nemo_skills/inference/prover.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/nemo_skills/inference/prover.py b/nemo_skills/inference/prover.py index cad883b295..24525b5945 100644 --- a/nemo_skills/inference/prover.py +++ b/nemo_skills/inference/prover.py @@ -110,10 +110,9 @@ def log_example_prompt(self, data): def setup_llm(self): if self.cfg.code_execution: raise ValueError("Code execution is not supported for prover") - sandbox = get_sandbox(**self.cfg.sandbox) if self.cfg.sandbox is not None else None - server = deepcopy(self.cfg.server) - server["server_type"] = "autoformalization" - llm = get_model(**server, sandbox=sandbox) + # Store sandbox directly on self for Lean4 code execution + self.sandbox = get_sandbox(**self.cfg.sandbox) if self.cfg.sandbox is not None else None + llm = get_model(**self.cfg.server, tokenizer=self.tokenizer) return llm def setup_prompt(self): @@ -300,7 +299,7 @@ async def _single_data_point_generate(self, data_point, data): feedback = self.refine_prompt.fill({"error_message": last_error_message}) results_dict["feedback"] = feedback[0]["content"] else: - execution_result = await self.llm.sandbox.execute_lean4_code( + execution_result = await self.sandbox.execute_lean4_code( full_code, timeout=600.0, max_output_characters=1000000 ) results_dict["execution_result"] = execution_result From 605a8980c6c53573113b4e5bd6376dc1973592b5 Mon Sep 17 00:00:00 2001 From: Stephen Ge Date: Thu, 4 Dec 2025 15:24:13 -0800 Subject: [PATCH 3/7] tokenizer Signed-off-by: Stephen Ge --- nemo_skills/inference/prover.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/nemo_skills/inference/prover.py b/nemo_skills/inference/prover.py index 24525b5945..24e7bb0a2a 100644 --- a/nemo_skills/inference/prover.py +++ b/nemo_skills/inference/prover.py @@ -19,6 +19,7 @@ from dataclasses import asdict, is_dataclass import hydra +from transformers import AutoTokenizer from nemo_skills.code_execution.sandbox import get_sandbox, sandbox_params from nemo_skills.inference.model import get_model, server_params @@ -98,6 +99,11 @@ def __init__(self, cfg: ProverConfig): cfg: GenerateSolutionsConfig object with the configuration parameters or subclass. """ super().__init__(cfg) + + # Initialize tokenizer for chat template application + tokenizer_path = self.cfg.tokenizer or self.cfg.server.get("model") + self.hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + if self.cfg.refinement: self.setup_refine_prompt() @@ -222,11 +228,11 @@ async def _single_data_point_generate(self, data_point, data): prepared_conversation = self._transform_for_nemotron_refinement(last_proof_attempt, last_error_message) else: prepared_conversation = prompt_turn_list - prefix_tokens = self.llm.tokenizer.apply_chat_template( + prefix_tokens = self.hf_tokenizer.apply_chat_template( prepared_conversation, tokenize=True, add_generation_prompt=True ) num_tokens_prefix = len(prefix_tokens) - prefix = self.llm.tokenizer.apply_chat_template( + prefix = self.hf_tokenizer.apply_chat_template( prepared_conversation, tokenize=False, add_generation_prompt=True ) # We need to check if the prefix is too long, if it is, we need to break the loop From b5d527be6c8dcc27bf4e35a86b6fb2e75badce61 Mon Sep 17 00:00:00 2001 From: Stephen Ge Date: Thu, 4 Dec 2025 15:54:07 -0800 Subject: [PATCH 4/7] fixes Signed-off-by: Stephen Ge --- nemo_skills/inference/prover.py | 74 ++++++++++++++++----------------- 1 file changed, 35 insertions(+), 39 deletions(-) diff --git a/nemo_skills/inference/prover.py b/nemo_skills/inference/prover.py index 24e7bb0a2a..d546e8ad47 100644 --- a/nemo_skills/inference/prover.py +++ b/nemo_skills/inference/prover.py @@ -23,6 +23,7 @@ from nemo_skills.code_execution.sandbox import get_sandbox, sandbox_params from nemo_skills.inference.model import get_model, server_params +from nemo_skills.inference.model.base import EndpointType from nemo_skills.prompt.utils import get_prompt from nemo_skills.utils import ( get_help_message, @@ -147,7 +148,7 @@ def setup_refine_prompt(self): self.nemotron_refine_prompt = get_prompt(self.cfg.nemotron_refinement_prompt_config) # with adaptive reasoning - async def _generate_single_completion(self, prompt: list[str], **kwargs): + async def _generate_single_completion(self, prompt: str, **kwargs): if is_dataclass(self.cfg.inference): inference_params = asdict(self.cfg.inference) else: @@ -159,6 +160,8 @@ async def _generate_single_completion(self, prompt: list[str], **kwargs): **inference_params, **self.extra_generate_params, } + # Override endpoint_type to text since we already applied the chat template + generation_params["endpoint_type"] = EndpointType.text for key, value in kwargs.items(): generation_params[key] = value generation = await self.llm.generate_async(**generation_params) @@ -305,53 +308,46 @@ async def _single_data_point_generate(self, data_point, data): feedback = self.refine_prompt.fill({"error_message": last_error_message}) results_dict["feedback"] = feedback[0]["content"] else: - execution_result = await self.sandbox.execute_lean4_code( - full_code, timeout=600.0, max_output_characters=1000000 + # execute_code returns (result_dict, session_id) tuple + execution_result, _ = await self.sandbox.execute_code( + full_code, language="lean4", timeout=600.0, max_output_characters=1000000 ) results_dict["execution_result"] = execution_result - if isinstance(execution_result, dict): - if ( - execution_result["process_status"] == "completed" - and "sorry" not in execution_result["stdout"] - and "failed" not in execution_result["stdout"] - ): - results_dict["success"] = True - else: - error_list = parse_error(execution_result["stdout"]) - error_message = get_error_str(full_code, error_list, error_thres=True) - # checking for sorry - if execution_result["process_status"] == "completed": - stdout = execution_result["stdout"].lower() - stderr = execution_result["stderr"].lower() - combined = stdout + "\n" + stderr - if re.search(r"\bsorry\b", combined) is not None: - error_message += "\nThe code contains 'sorry', which means the proof is incomplete." - if error_message.strip() == "": # something in stderr indicating failure - error_message = execution_result["stderr"][:1000] - if len(execution_result["stderr"]) > 1000: - error_message += "... (truncated)" - - last_error_message = ( - "We use to signal the position of the error. \n" + error_message - ) - feedback = self.refine_prompt.fill({"error_message": last_error_message}) - results_dict["feedback"] = feedback[0]["content"] - results_dict["success"] = False - # This is only used for the case when the code execution timed out. - elif isinstance(execution_result, str): - execution_result = { - "process_status": "failed", - "stderr": "", - "stdout": execution_result, - } + # Handle timeout (now indicated by process_status in the dict) + if execution_result.get("process_status") == "timeout": results_dict["success"] = False last_error_message = ( "The compilation timed out. There might be a heavy computation in the code or an endless loop." ) feedback = self.refine_prompt.fill({"error_message": last_error_message}) results_dict["feedback"] = feedback[0]["content"] + elif ( + execution_result["process_status"] == "completed" + and "sorry" not in execution_result["stdout"] + and "failed" not in execution_result["stdout"] + ): + results_dict["success"] = True else: - raise ValueError(f"Unknown execution result type: {type(execution_result)}") + error_list = parse_error(execution_result["stdout"]) + error_message = get_error_str(full_code, error_list, error_thres=True) + # checking for sorry + if execution_result["process_status"] == "completed": + stdout = execution_result["stdout"].lower() + stderr = execution_result["stderr"].lower() + combined = stdout + "\n" + stderr + if re.search(r"\bsorry\b", combined) is not None: + error_message += "\nThe code contains 'sorry', which means the proof is incomplete." + if error_message.strip() == "": # something in stderr indicating failure + error_message = execution_result["stderr"][:1000] + if len(execution_result["stderr"]) > 1000: + error_message += "... (truncated)" + + last_error_message = ( + "We use to signal the position of the error. \n" + error_message + ) + feedback = self.refine_prompt.fill({"error_message": last_error_message}) + results_dict["feedback"] = feedback[0]["content"] + results_dict["success"] = False results_dict_list.append(results_dict) From d8cc81b4bea49075cc677145c99540ac3bba9e33 Mon Sep 17 00:00:00 2001 From: Stephen Ge Date: Mon, 8 Dec 2025 06:29:24 -0800 Subject: [PATCH 5/7] update inference.py, remove pipline/prover.py and from factory Signed-off-by: Stephen Ge --- nemo_skills/inference/factory.py | 2 - nemo_skills/inference/prover.py | 73 +++---- nemo_skills/pipeline/prover.py | 333 ------------------------------- 3 files changed, 37 insertions(+), 371 deletions(-) delete mode 100644 nemo_skills/pipeline/prover.py diff --git a/nemo_skills/inference/factory.py b/nemo_skills/inference/factory.py index 1ca46fe4ab..cd29bbd2c5 100644 --- a/nemo_skills/inference/factory.py +++ b/nemo_skills/inference/factory.py @@ -19,12 +19,10 @@ class GenerationType(str, Enum): generate = "generate" math_judge = "math_judge" check_contamination = "check_contamination" - prover = "prover" GENERATION_MODULE_MAP = { GenerationType.generate: "nemo_skills.inference.generate", GenerationType.math_judge: "nemo_skills.inference.llm_math_judge", GenerationType.check_contamination: "nemo_skills.inference.check_contamination", - GenerationType.prover: "nemo_skills.inference.prover", } diff --git a/nemo_skills/inference/prover.py b/nemo_skills/inference/prover.py index d546e8ad47..4310698fc9 100644 --- a/nemo_skills/inference/prover.py +++ b/nemo_skills/inference/prover.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -71,19 +71,23 @@ class ProverConfig(GenerateSolutionsConfig): def _post_init_validate_params(self): """Validate that certain parameters are restricted to certain values""" - if self.prompt_format not in ["ns", "openai"]: - raise ValueError(f"prompt_format must be either 'ns' or 'openai', got '{self.prompt_format}'") - if self.prompt_format == "openai": - assert self.prompt_config is None, "prompt_config is not supported for prompt_format == 'openai'" - else: - assert self.prompt_config is not None, "prompt_config is required when prompt_format == 'ns'" + raise ValueError( + "prompt_format='openai' is not supported for lean4_prover. Use prompt_format='ns' with a prompt_config." + ) + if self.prompt_format != "ns": + raise ValueError(f"prompt_format must be 'ns', got '{self.prompt_format}'") + + assert self.prompt_config is not None, "prompt_config is required for lean4_prover" + for param, default_value in self._get_disallowed_params(): if getattr(self, param) != default_value: raise ValueError(f"{param} must be {default_value}") if self.n_pass > 32: - raise ValueError("Please consider using num_random_seeds instead") + LOG.warning( + "n_pass=%d exceeds recommended maximum of 32. Consider using num_random_seeds instead.", self.n_pass + ) cs = hydra.core.config_store.ConfigStore.instance() @@ -122,19 +126,6 @@ def setup_llm(self): llm = get_model(**self.cfg.server, tokenizer=self.tokenizer) return llm - def setup_prompt(self): - if self.cfg.prompt_format == "openai": - return None - prompt = get_prompt( - prompt_config=self.cfg.prompt_config, - tokenizer=self.tokenizer, - code_tags=self.cfg.code_tags, - examples_type=self.cfg.examples_type, - system_message=self.cfg.system_message, - ) - LOG.info("Prompt used: %s", prompt) - return prompt - def setup_refine_prompt(self): assert self.cfg.refinement_prompt_config is not None, ( "refinement_prompt_config is required when refinement is enabled. Please set refinement=False to disable refinement." @@ -147,8 +138,8 @@ def setup_refine_prompt(self): ) self.nemotron_refine_prompt = get_prompt(self.cfg.nemotron_refinement_prompt_config) - # with adaptive reasoning async def _generate_single_completion(self, prompt: str, **kwargs): + """Generate a single completion with semaphore-controlled concurrency.""" if is_dataclass(self.cfg.inference): inference_params = asdict(self.cfg.inference) else: @@ -164,21 +155,26 @@ async def _generate_single_completion(self, prompt: str, **kwargs): generation_params["endpoint_type"] = EndpointType.text for key, value in kwargs.items(): generation_params[key] = value - generation = await self.llm.generate_async(**generation_params) - if self.cfg.adaptive_reasoning: - assert generation_params["extra_body"].get("reasoning_effort", None) is not None, ( - "reasoning_effort is required when adaptive_reasoning is enabled" - ) - reasoning_effort_index = reasoning_effort_list.index( - generation_params["extra_body"].get("reasoning_effort", None) - ) - while len(generation["generation"]) == 0 and reasoning_effort_index > 0: - LOG.info( - "Reasoning effort is too high, reducing to %s", reasoning_effort_list[reasoning_effort_index - 1] + + # Use semaphore for concurrency control (inherited from GenerationTask) + async with self.semaphore: + generation = await self.llm.generate_async(**generation_params) + if self.cfg.adaptive_reasoning: + assert generation_params["extra_body"].get("reasoning_effort", None) is not None, ( + "reasoning_effort is required when adaptive_reasoning is enabled" + ) + reasoning_effort_index = reasoning_effort_list.index( + generation_params["extra_body"].get("reasoning_effort", None) ) - reasoning_effort_index = reasoning_effort_index - 1 - generation_params["extra_body"]["reasoning_effort"] = reasoning_effort_list[reasoning_effort_index] - generation = await self.llm.generate_async(**generation_params) + while len(generation["generation"]) == 0 and reasoning_effort_index > 0: + LOG.info( + "Reasoning effort is too high, reducing to %s", + reasoning_effort_list[reasoning_effort_index - 1], + ) + reasoning_effort_index = reasoning_effort_index - 1 + generation_params["extra_body"]["reasoning_effort"] = reasoning_effort_list[reasoning_effort_index] + generation = await self.llm.generate_async(**generation_params) + if self.cfg.parse_generation: parse_reasoning( generation, @@ -308,6 +304,11 @@ async def _single_data_point_generate(self, data_point, data): feedback = self.refine_prompt.fill({"error_message": last_error_message}) results_dict["feedback"] = feedback[0]["content"] else: + if self.sandbox is None: + raise RuntimeError( + "Sandbox is required for Lean4 code execution but was not configured. " + "Please provide sandbox configuration." + ) # execute_code returns (result_dict, session_id) tuple execution_result, _ = await self.sandbox.execute_code( full_code, language="lean4", timeout=600.0, max_output_characters=1000000 diff --git a/nemo_skills/pipeline/prover.py b/nemo_skills/pipeline/prover.py deleted file mode 100644 index a70a199f94..0000000000 --- a/nemo_skills/pipeline/prover.py +++ /dev/null @@ -1,333 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import importlib -import logging -import os -from typing import List - -import typer - -import nemo_skills.pipeline.utils as pipeline_utils -from nemo_skills.dataset.utils import import_from_path -from nemo_skills.inference import GENERATION_MODULE_MAP, GenerationType -from nemo_skills.pipeline.app import app, typer_unpacker -from nemo_skills.utils import ( - compute_chunk_ids, - get_logger_name, - setup_logging, - str_ids_to_list, -) - -LOG = logging.getLogger(get_logger_name(__file__)) - -# TODO: add num_jobs here for consistency with eval? - - -@app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) -@typer_unpacker -def generate( - ctx: typer.Context, - cluster: str = typer.Option( - None, - help="One of the configs inside config_dir or NEMO_SKILLS_CONFIG_DIR or ./cluster_configs. " - "Can also use NEMO_SKILLS_CONFIG instead of specifying as argument.", - ), - input_file: str = typer.Option( - None, help="Path to the input data file. Can either specify input_file or input_dir, but not both. " - ), - input_dir: str = typer.Option( - None, - help="Path to the input data directory. Can either specify input_file or input_dir, but not both. " - "If input_file is not provided, will use output-rs{{seed}}.jsonl inside input_dir as input_files. " - "In this case, the random seed parameter is used both for input and for output files, which " - "means it's a 1-1 mapping (not 1-num_random_seeds as in the case of input_file).", - ), - output_dir: str = typer.Option(..., help="Where to put results"), - expname: str = typer.Option("generate", help="Nemo run experiment name"), - generation_type: GenerationType | None = typer.Option(None, help="Type of generation to perform"), - generation_module: str = typer.Option( - None, - help="Path to the generation module to use. " - "If not specified, will use the registered generation module for the " - "generation type (which is required in this case).", - ), - model: str = typer.Option(None, help="Path to the model or model name in API"), - server_address: str = typer.Option( - None, help="Use ip:port for self-hosted models or the API url if using model providers" - ), - server_type: pipeline_utils.SupportedServers = typer.Option(..., help="Type of server to use"), - server_gpus: int = typer.Option(None, help="Number of GPUs to use if hosting the model"), - server_nodes: int = typer.Option(1, help="Number of nodes required for hosting LLM server"), - server_args: str = typer.Option("", help="Any extra arguments to pass to the server"), - server_entrypoint: str = typer.Option( - None, - help="Path to the entrypoint of the server. " - "If not specified, will use the default entrypoint for the server type.", - ), - server_container: str = typer.Option( - None, help="Override container image for the hosted server (if server_gpus is set)" - ), - dependent_jobs: int = typer.Option(0, help="Specify this to launch that number of dependent jobs"), - mount_paths: str = typer.Option(None, help="Comma separated list of paths to mount on the remote machine"), - num_random_seeds: int = typer.Option( - None, help="Specify if want to run many generations with high temperature for the same input" - ), - random_seeds: str = typer.Option( - None, - help="List of random seeds to use for generation. Separate with , or .. to specify range. " - "Can provide a list directly when using through Python", - ), - starting_seed: int = typer.Option(0, help="Starting seed for random sampling"), - num_chunks: int = typer.Option( - None, - help="Number of chunks to split the dataset into. If None, will not chunk the dataset.", - ), - chunk_ids: str = typer.Option( - None, - help="List of explicit chunk ids to run. Separate with , or .. to specify range. " - "Can provide a list directly when using through Python", - ), - preprocess_cmd: str = typer.Option(None, help="Command to run before generation"), - postprocess_cmd: str = typer.Option(None, help="Command to run after generation"), - partition: str = typer.Option( - None, help="Can specify if need interactive jobs or a specific non-default partition" - ), - time_min: str = typer.Option(None, help="If specified, will use as a time-min slurm parameter"), - eval_args: str = typer.Option( - None, help="Specify if need to run nemo_skills/evaluation/evaluate_results.py on the generation outputs" - ), - run_after: List[str] = typer.Option( - None, help="Can specify a list of expnames that need to be completed before this one starts" - ), - reuse_code: bool = typer.Option( - True, - help="If True, will reuse the code from the provided experiment. " - "If you use it from Python, by default the code will be re-used from " - "the last submitted experiment in the current Python session, so set to False to disable " - "(or provide reuse_code_exp to override).", - ), - reuse_code_exp: str = typer.Option( - None, - help="If specified, will reuse the code from this experiment. " - "Can provide an experiment name or an experiment object if running from code.", - ), - config_dir: str = typer.Option(None, help="Can customize where we search for cluster configs"), - log_dir: str = typer.Option(None, help="Can specify a custom location for slurm logs."), - exclusive: bool = typer.Option(False, help="If set will add exclusive flag to the slurm job."), - rerun_done: bool = typer.Option( - False, help="If True, will re-run jobs even if a corresponding '.done' file already exists" - ), - with_sandbox: bool = typer.Option(False, help="If True, will start a sandbox container alongside this job"), - check_mounted_paths: bool = typer.Option(False, help="Check if mounted paths are available on the remote machine"), - log_samples: bool = typer.Option( - False, - help="If True, will log random samples from the output files to wandb. " - "Requires WANDB_API_KEY to be set in the environment. " - "Use wandb_name/wandb_group/wandb_project to specify where to log.", - ), - wandb_name: str = typer.Option( - None, - help="Name of the wandb group to sync samples to. If not specified, but log_samples=True, will use expname.", - ), - wandb_group: str = typer.Option(None, help="Name of the wandb group to sync samples to."), - wandb_project: str = typer.Option( - "nemo-skills", - help="Name of the wandb project to sync samples to.", - ), - installation_command: str | None = typer.Option( - None, - help="An installation command to run before main job. Only affects main task (not server or sandbox). " - "You can use an arbitrary command here and we will run it on a single rank for each node. " - "E.g. 'pip install my_package'", - ), - skip_hf_home_check: bool = typer.Option( - False, - help="If True, skip checking that HF_HOME env var is defined in the cluster config.", - ), - dry_run: bool = typer.Option(False, help="If True, will not run the job, but will validate all arguments."), - _reuse_exp: str = typer.Option(None, help="Internal option to reuse an experiment object.", hidden=True), - _task_dependencies: List[str] = typer.Option( - None, help="Internal option to specify task dependencies.", hidden=True - ), -): - """Generate LLM completions for a given input file. - - Run `python -m nemo_skills.inference.generate --help` for other supported arguments - (need to be prefixed with ++, since we use Hydra for that script). - """ - setup_logging(disable_hydra_logs=False, use_rich=True) - extra_arguments = f"{' '.join(ctx.args)}" - LOG.info("Starting generation job") - LOG.info("Extra arguments that will be passed to the underlying script: %s", extra_arguments) - - try: - server_type = server_type.value - except AttributeError: - pass - - if log_samples: - wandb_parameters = { - "name": wandb_name or expname, - "project": wandb_project, - "group": wandb_group, - } - else: - wandb_parameters = None - - get_random_port = pipeline_utils.should_get_random_port(server_gpus, exclusive, server_type) - - if random_seeds and num_random_seeds: - raise ValueError("Cannot specify both random_seeds and num_random_seeds") - if num_random_seeds: - random_seeds = list(range(starting_seed, starting_seed + num_random_seeds)) - if isinstance(random_seeds, str): - random_seeds = str_ids_to_list(random_seeds) - - if num_chunks: - chunk_ids = compute_chunk_ids(chunk_ids, num_chunks) - if chunk_ids is None: - chunk_ids = [None] - - # Prepare cluster config and mount paths - cluster_config = pipeline_utils.get_cluster_config(cluster, config_dir) - cluster_config = pipeline_utils.resolve_mount_paths( - cluster_config, mount_paths, create_remote_dir=check_mounted_paths - ) - - if not log_dir: - log_dir = f"{output_dir}/generation-logs" - - output_dir, log_dir = pipeline_utils.check_mounts( - cluster_config, - log_dir=log_dir, - mount_map={output_dir: None}, - check_mounted_paths=check_mounted_paths, - ) - - original_server_address = server_address - - if generation_module is not None and generation_type is not None: - raise ValueError("Cannot specify both generation_module and generation_type. ") - if generation_module is None: - generation_module = GENERATION_MODULE_MAP[generation_type or GenerationType.generate] - - if os.sep in generation_module: - generation_task = import_from_path(generation_module) - else: - generation_task = importlib.import_module(generation_module) - if not hasattr(generation_task, "GENERATION_TASK_CLASS"): - raise ValueError( - f"Module {generation_module} does not have a GENERATION_TASK_CLASS attribute. " - "Please provide a valid generation module." - ) - generation_task = generation_task.GENERATION_TASK_CLASS - extra_arguments = f"{generation_task.get_generation_default_args()} {extra_arguments}" - extra_arguments_original = extra_arguments - - # Treat no random seeds as a single None seed to unify the code paths - if not random_seeds: - random_seeds = [None] - - remaining_jobs = pipeline_utils.get_remaining_jobs( - cluster_config=cluster_config, - output_dir=output_dir, - random_seeds=random_seeds, - chunk_ids=chunk_ids, - rerun_done=rerun_done, - ) - has_tasks = False - all_tasks = [] - if _task_dependencies is None: - _task_dependencies = [] - with pipeline_utils.get_exp(expname, cluster_config, _reuse_exp) as exp: - for seed_idx, (seed, chunk_ids) in enumerate(remaining_jobs.items()): - if wandb_parameters: - # no need for chunks as it will run after merging - wandb_parameters["samples_file"] = pipeline_utils.get_chunked_rs_filename( - output_dir, - random_seed=seed, - chunk_id=None, - ) - for chunk_id in chunk_ids: - has_tasks = True - server_config, server_address, extra_arguments = pipeline_utils.configure_client( - model=model, - server_type=server_type, - server_address=original_server_address, - server_gpus=server_gpus, - server_nodes=server_nodes, - server_args=server_args, - server_entrypoint=server_entrypoint, - server_container=server_container, - extra_arguments=extra_arguments_original, - get_random_port=get_random_port, - ) - cmd = pipeline_utils.get_generation_cmd( - input_file=input_file, - input_dir=input_dir, - random_seed=seed, - output_dir=output_dir, - extra_arguments=extra_arguments, - eval_args=eval_args, - chunk_id=chunk_id, - num_chunks=num_chunks, - preprocess_cmd=preprocess_cmd, - postprocess_cmd=postprocess_cmd, - wandb_parameters=wandb_parameters if seed_idx == 0 else None, - script=generation_module, - ) - prev_tasks = _task_dependencies - for _ in range(dependent_jobs + 1): - task_name = f"{expname}-rs{seed}" if seed is not None else expname - if chunk_id is not None: - task_name += f"-chunk{chunk_id}" - new_task = pipeline_utils.add_task( - exp, - cmd=pipeline_utils.wrap_python_path(cmd=cmd), - task_name=task_name, - log_dir=log_dir, - container=cluster_config["containers"]["nemo-skills"], - cluster_config=cluster_config, - partition=partition, - time_min=time_min, - server_config=server_config, - with_sandbox=with_sandbox, - sandbox_port=None if get_random_port else 6000, - run_after=run_after, - reuse_code=reuse_code, - reuse_code_exp=reuse_code_exp, - task_dependencies=( - prev_tasks if cluster_config["executor"] == "slurm" else all_tasks + _task_dependencies - ), - get_server_command=generation_task.get_server_command_fn(), - slurm_kwargs={"exclusive": exclusive} if exclusive else None, - installation_command=installation_command, - skip_hf_home_check=skip_hf_home_check, - ) - prev_tasks = [new_task] - all_tasks.append(new_task) - if has_tasks and not _reuse_exp: # if we are reusing an experiment, the tasks will run from there - pipeline_utils.run_exp(exp, cluster_config, dry_run=dry_run) - - if _reuse_exp: - return all_tasks - else: - if has_tasks: - return exp - return None - - -if __name__ == "__main__": - typer.main.get_command_name = lambda name: name - app() From 1d2e92eee3ea73d2b575f213b5cc2390cbee3026 Mon Sep 17 00:00:00 2001 From: Stephen Ge Date: Mon, 8 Dec 2025 08:18:54 -0800 Subject: [PATCH 6/7] lean4_utils into proof_utils Signed-off-by: Stephen Ge --- nemo_skills/code_execution/proof_utils.py | 184 +++++++++++++++++++ nemo_skills/inference/lean4_utils.py | 205 ---------------------- nemo_skills/inference/prover.py | 14 +- 3 files changed, 191 insertions(+), 212 deletions(-) delete mode 100644 nemo_skills/inference/lean4_utils.py diff --git a/nemo_skills/code_execution/proof_utils.py b/nemo_skills/code_execution/proof_utils.py index fdf84a87cf..33c92891db 100644 --- a/nemo_skills/code_execution/proof_utils.py +++ b/nemo_skills/code_execution/proof_utils.py @@ -14,12 +14,16 @@ """Shared utilities for proof processing and evaluation.""" +import logging import re from dataclasses import dataclass from typing import Any, Dict from nemo_skills.code_execution.utils import clean_formal_generation from nemo_skills.dataset.utils import get_lean4_header +from nemo_skills.utils import get_logger_name + +LOG = logging.getLogger(get_logger_name(__file__)) @dataclass @@ -192,3 +196,183 @@ def prepare_predicted_proof_from_line_dict( return build_lean4_proof( generation=line_dict["generation"], data_point=line_dict, config=config, answer_format=answer_format ) + + +# ------------------------------------------------------------------------------------------------ +# The following code is adapted from https://github.com/Goedel-LM/Goedel-Prover-V2 +# Used for multi-turn proof refinement in the lean4_prover workflow +# ------------------------------------------------------------------------------------------------ + + +def remove_comments(text): + # First remove all /- ... -/ blocks + text = re.sub(r"/-.*?-/", "", text, flags=re.DOTALL) + # Then remove -- comments from each line + lines = text.split("\n") + cleaned_lines = [] + for line in lines: + cleaned_line = line.split("--", 1)[0] + if cleaned_line.strip() == "": + continue + cleaned_lines.append(cleaned_line) + # Join back together and remove excessive empty lines + cleaned_text = "\n".join(cleaned_lines) + return cleaned_text.strip() + + +def move_imports_to_beginning(input_string): + lines = input_string.split("\n") + import_lines = [line for line in lines if line.startswith("import")] + other_lines = [line for line in lines if not line.startswith("import")] + return "\n".join(import_lines + other_lines) + + +def return_theorem_to_prove(text): + # Pattern that matches from 'theorem' or 'lemma' to ':= by sorry' with any content in between + pattern = r"((?:theorem).*?:=\s*by\s*sorry)" + match = re.search(pattern, text, re.DOTALL) + return match.span() if match else None + + +def return_theorem_to_replace(text): + # Pattern that matches from 'theorem' or 'lemma' to ':= by sorry' with any content in between + pattern = r"((?:^|\s)theorem\s+.*?:=\s*by)" + match = re.search(pattern, text, re.DOTALL) + return match.span() if match else None + + +def replace_statement_in_proof(statement, proof): + if ("apply?" in proof) or ("exact?" in proof): + return "**Error**, 'apply?' or 'exact?' is used, which is not allowed." + stats_re = remove_comments(statement) + stats_span_ = return_theorem_to_prove(stats_re) + if stats_span_ is None: + error_app = "\n".join(["\n"] + ["-- " + x for x in statement.split("\n")]) + return f"**Error**, can not find 'theorem' and ':= sorry' in {error_app}" + proof_str = remove_comments(proof) + span = return_theorem_to_replace(proof_str) + if span is None: + error_app = "\n".join(["\n"] + ["-- " + x for x in proof.split("\n")]) + return f"**Error**, can not find 'theorem' and ':=' in {error_app}" + return stats_re[: stats_span_[1]].replace("sorry", "") + proof_str[span[1] :] + + +def refine_by_sorry(text): + # Define the regular expression pattern + target_pattern = r":=\s*(?:by\s*)?(?:sorry\s*)?" + replacement = ":= by sorry" # The new text we want to insert + # We construct the pattern with two capturing groups + # (group 1: the part from 'theorem' to just before our target) + # (group 2: the target pattern itself) + combined_pattern = r"(theorem.*?)(" + target_pattern + r")" + # Find the first match + match = re.search(combined_pattern, text, re.DOTALL) + if match: + # The part of the string BEFORE the target we want to replace + # We use match.start(2) which is the start of the second group (our target) + prefix = text[: match.start(2)] + # Concatenate the prefix with the replacement to get the final, truncated string + final_text = prefix + replacement + else: + final_text = text + return final_text + + +def extract_code(inputs): + import_head = ( + "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n" + ) + pattern = r"```lean4\n(.*?)\n```" + matches = re.findall(pattern, inputs, re.DOTALL) + if matches: + return import_head + matches[-1] + pattern = r"```lean4\n(.*?)```" + matches = re.findall(pattern, inputs, re.DOTALL) + if matches: + return import_head + matches[-1] + pattern = r"```lean\n(.*?)```" + matches = re.findall(pattern, inputs, re.DOTALL) + if matches: + return import_head + matches[-1] + return "None" + + +def parse_error(log_string): + """Parse Lean4 compiler error messages from log output.""" + error_pattern = re.compile( + r"(/lean4/my_project/.*?:\d+:\d+: error:.*?)(?=\n/lean4/my_project|\Z)", + re.DOTALL, + ) + errors = error_pattern.findall(log_string) + pattern = re.compile(r":(\d+):(\d+):") + error_list = [] + for error in errors: + match = pattern.search(error) + error_list.append( + { + "pos": {"line": int(match.group(1)), "column": int(match.group(2))}, + "endPos": None, + "data": error.split("error:")[1], + } + ) + + return error_list + + +def get_error_str(code, errors, error_thres=True): + """Format compiler errors with code context for display.""" + err_str = "" + code_lines = code.split("\n") + error_num_thres = 8 if error_thres else len(errors) + + for i, error in enumerate(errors[:error_num_thres]): + start_line = error["pos"]["line"] - 1 + start_col = error["pos"]["column"] + if start_line >= len(code_lines): + LOG.warning( + "Error line %d exceeds code length %d. Errors: %s, Code: %s", start_line, len(code_lines), errors, code + ) + continue + if error["endPos"] is None: + end_line = start_line + end_col = len(code_lines[start_line]) + else: + end_line = error["endPos"]["line"] - 1 + end_col = error["endPos"]["column"] + + err_str += f"\nError {i + 1}:\n" + err_str += "\nCorresponding Code:\n```lean4\n" + error_code = "" + for ii in range(-4, 0): + if start_line + ii >= 0: + error_code += f"{code_lines[start_line + ii]}\n" + if start_line != end_line: + error_code += code_lines[start_line][:start_col] + "" + code_lines[start_line][start_col:] + "\n" + if not error_thres: + for j in range(start_line + 1, end_line): + error_code += f"{code_lines[j]}\n" + else: + show_line = 6 + for j in range(start_line + 1, min(end_line, start_line + show_line)): + error_code += f"{code_lines[j]}\n" + if end_line > start_line + show_line: + leading_spaces = len(code_lines[j]) - len(code_lines[j].lstrip(" ")) + error_code += "\n" + " " * leading_spaces + "... --[Truncated]-- ...\n" + error_code += code_lines[end_line][:end_col] + "" + code_lines[end_line][end_col:] + "\n" + else: + error_code += ( + code_lines[start_line][:start_col] + + "" + + code_lines[start_line][start_col:end_col] + + "" + + code_lines[start_line][end_col:] + + "\n" + ) + if end_line + 1 < len(code_lines): + error_code += f"{code_lines[end_line + 1]}\n" + err_str += error_code + err_str += "\n```\n" + err_str += f"\nError Message: {error['data']}\n" + if len(errors) > error_num_thres: + err_str += f"\n... [Omitted {len(errors) - error_num_thres} more errors] ...\n" + return err_str diff --git a/nemo_skills/inference/lean4_utils.py b/nemo_skills/inference/lean4_utils.py deleted file mode 100644 index 58c26495c1..0000000000 --- a/nemo_skills/inference/lean4_utils.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import re - -from nemo_skills.utils import get_logger_name - -LOG = logging.getLogger(get_logger_name(__file__)) - -# ------------------------------------------------------------------------------------------------ -# The following code is adapted from https://github.com/Goedel-LM/Goedel-Prover-V2 -# ------------------------------------------------------------------------------------------------ - - -def remove_comments(text): # remove comments - # First remove all /- ... -/ blocks - text = re.sub(r"/-.*?-/", "", text, flags=re.DOTALL) - # Then remove -- comments from each line - lines = text.split("\n") - cleaned_lines = [] - for line in lines: - cleaned_line = line.split("--", 1)[0] - if cleaned_line.strip() == "": - continue - cleaned_lines.append(cleaned_line) - # Join back together and remove excessive empty lines - cleaned_text = "\n".join(cleaned_lines) - return cleaned_text.strip() - - -def move_imports_to_beginning(input_string): - lines = input_string.split("\n") - import_lines = [line for line in lines if line.startswith("import")] - other_lines = [line for line in lines if not line.startswith("import")] - return "\n".join(import_lines + other_lines) - - -def return_theorem_to_prove(text): - # Pattern that matches from 'theorem' or 'lemma' to ':= by sorry' with any content in between - pattern = r"((?:theorem).*?:=\s*by\s*sorry)" - match = re.search(pattern, text, re.DOTALL) - return match.span() if match else None - - -def return_theorem_to_replace(text): - # Pattern that matches from 'theorem' or 'lemma' to ':= by sorry' with any content in between - pattern = r"((?:^|\s)theorem\s+.*?:=\s*by)" - match = re.search(pattern, text, re.DOTALL) - return match.span() if match else None - - -def replace_statement_in_proof(statement, proof): - if ("apply?" in proof) or ("exact?" in proof): - return "**Error**, 'apply?' or 'exact?' is used, which is not allowed." - stats_re = remove_comments(statement) - stats_span_ = return_theorem_to_prove(stats_re) - if stats_span_ is None: - error_app = "\n".join(["\n"] + ["-- " + x for x in statement.split("\n")]) - return f"**Error**, can not find 'theorem' and ':= sorry' in {error_app}" - proof_str = remove_comments(proof) - span = return_theorem_to_replace(proof_str) - if span is None: - error_app = "\n".join(["\n"] + ["-- " + x for x in proof.split("\n")]) - return f"**Error**, can not find 'theorem' and ':=' in {error_app}" - return stats_re[: stats_span_[1]].replace("sorry", "") + proof_str[span[1] :] - - -def refine_by_sorry(text): - # Define the regular expression pattern - target_pattern = r":=\s*(?:by\s*)?(?:sorry\s*)?" - replacement = ":= by sorry" # The new text we want to insert - # We construct the pattern with two capturing groups - # (group 1: the part from 'theorem' to just before our target) - # (group 2: the target pattern itself) - combined_pattern = r"(theorem.*?)(" + target_pattern + r")" - # Find the first match - match = re.search(combined_pattern, text, re.DOTALL) - if match: - # The part of the string BEFORE the target we want to replace - # We use match.start(2) which is the start of the second group (our target) - prefix = text[: match.start(2)] - # Concatenate the prefix with the replacement to get the final, truncated string - final_text = prefix + replacement - else: - final_text = text - return final_text - - -def extract_code(inputs): - import_head = ( - "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n" - ) - pattern = r"```lean4\n(.*?)\n```" - matches = re.findall(pattern, inputs, re.DOTALL) - if matches: - return import_head + matches[-1] - pattern = r"```lean4\n(.*?)```" - matches = re.findall(pattern, inputs, re.DOTALL) - if matches: - return import_head + matches[-1] - pattern = r"```lean\n(.*?)```" - matches = re.findall(pattern, inputs, re.DOTALL) - if matches: - return import_head + matches[-1] - return "None" - - -def parse_error(log_string): - # Pattern to match multiline warnings - # warning_pattern = re.compile( - # r"(/lean4/my_project/.*?:\d+:\d+: warning:.*?)(?=\n/lean4/my_project|\Z)", - # re.DOTALL, - # ) - # Pattern to match multiline errors - error_pattern = re.compile( - r"(/lean4/my_project/.*?:\d+:\d+: error:.*?)(?=\n/lean4/my_project|\Z)", - re.DOTALL, - ) - # Find all warnings and errors - # warnings = warning_pattern.findall(log_string) - errors = error_pattern.findall(log_string) - pattern = re.compile(r":(\d+):(\d+):") - error_list = [] - for error in errors: - match = pattern.search(error) - error_list.append( - { - "pos": {"line": int(match.group(1)), "column": int(match.group(2))}, - "endPos": None, - "data": error.split("error:")[1], - } - ) - - return error_list - - -def get_error_str(code, errors, error_thres=True): - err_str = "" - code_lines = code.split("\n") - # token_lengths = [len(line) + 1 for line in code_lines] - error_num_thres = 8 if error_thres else len(errors) - - for i, error in enumerate(errors[:error_num_thres]): - start_line = error["pos"]["line"] - 1 - start_col = error["pos"]["column"] - if start_line >= len(code_lines): - LOG.warning( - "Error line %d exceeds code length %d. Errors: %s, Code: %s", start_line, len(code_lines), errors, code - ) - continue - if error["endPos"] is None: - end_line = start_line - end_col = len(code_lines[start_line]) - else: - end_line = error["endPos"]["line"] - 1 - end_col = error["endPos"]["column"] - - err_str += f"\nError {i + 1}:\n" - err_str += "\nCorresponding Code:\n```lean4\n" - error_code = "" - for ii in range(-4, 0): - if start_line + ii >= 0: - error_code += f"{code_lines[start_line + ii]}\n" - if start_line != end_line: - error_code += code_lines[start_line][:start_col] + "" + code_lines[start_line][start_col:] + "\n" - if not error_thres: - for j in range(start_line + 1, end_line): - error_code += f"{code_lines[j]}\n" - else: - show_line = 6 - for j in range(start_line + 1, min(end_line, start_line + show_line)): - error_code += f"{code_lines[j]}\n" - if end_line > start_line + show_line: - leading_spaces = len(code_lines[j]) - len(code_lines[j].lstrip(" ")) - error_code += "\n" + " " * leading_spaces + "... --[Truncated]-- ...\n" - error_code += code_lines[end_line][:end_col] + "" + code_lines[end_line][end_col:] + "\n" - else: - error_code += ( - code_lines[start_line][:start_col] - + "" - + code_lines[start_line][start_col:end_col] - + "" - + code_lines[start_line][end_col:] - + "\n" - ) - if end_line + 1 < len(code_lines): - error_code += f"{code_lines[end_line + 1]}\n" - err_str += error_code - err_str += "\n```\n" - err_str += f"\nError Message: {error['data']}\n" - if len(errors) > error_num_thres: - err_str += f"\n... [Omitted {len(errors) - error_num_thres} more errors] ...\n" - return err_str diff --git a/nemo_skills/inference/prover.py b/nemo_skills/inference/prover.py index 4310698fc9..1ecd3abd83 100644 --- a/nemo_skills/inference/prover.py +++ b/nemo_skills/inference/prover.py @@ -21,6 +21,13 @@ import hydra from transformers import AutoTokenizer +from nemo_skills.code_execution.proof_utils import ( + extract_code, + get_error_str, + parse_error, + refine_by_sorry, + replace_statement_in_proof, +) from nemo_skills.code_execution.sandbox import get_sandbox, sandbox_params from nemo_skills.inference.model import get_model, server_params from nemo_skills.inference.model.base import EndpointType @@ -34,13 +41,6 @@ ) from .generate import GenerateSolutionsConfig, GenerationTask -from .lean4_utils import ( - extract_code, - get_error_str, - parse_error, - refine_by_sorry, - replace_statement_in_proof, -) LOG = logging.getLogger(get_logger_name(__file__)) From 761be61fe0ff50eb5462e42520fc6be93b9ee49c Mon Sep 17 00:00:00 2001 From: Stephen Ge Date: Tue, 9 Dec 2025 10:16:25 -0800 Subject: [PATCH 7/7] defer to super Signed-off-by: Stephen Ge --- nemo_skills/inference/prover.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/nemo_skills/inference/prover.py b/nemo_skills/inference/prover.py index 1ecd3abd83..848c7c70dd 100644 --- a/nemo_skills/inference/prover.py +++ b/nemo_skills/inference/prover.py @@ -28,8 +28,8 @@ refine_by_sorry, replace_statement_in_proof, ) -from nemo_skills.code_execution.sandbox import get_sandbox, sandbox_params -from nemo_skills.inference.model import get_model, server_params +from nemo_skills.code_execution.sandbox import sandbox_params +from nemo_skills.inference.model import server_params from nemo_skills.inference.model.base import EndpointType from nemo_skills.prompt.utils import get_prompt from nemo_skills.utils import ( @@ -120,11 +120,8 @@ def log_example_prompt(self, data): def setup_llm(self): if self.cfg.code_execution: - raise ValueError("Code execution is not supported for prover") - # Store sandbox directly on self for Lean4 code execution - self.sandbox = get_sandbox(**self.cfg.sandbox) if self.cfg.sandbox is not None else None - llm = get_model(**self.cfg.server, tokenizer=self.tokenizer) - return llm + raise ValueError("Code execution is not supported for prover. Use sandbox config for Lean4 execution.") + return super().setup_llm() def setup_refine_prompt(self): assert self.cfg.refinement_prompt_config is not None, (