Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 66 additions & 14 deletions nemo_skills/inference/prover.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,62 @@ def _transform_for_nemotron_refinement(self, proof_attempt: str, error_message:
}
)

def _parse_gpt_oss_output(self, content: str) -> tuple[str, str | None]:
"""Parse gpt-oss model output to extract thinking and final content.

gpt-oss models output in the format:
<|channel|>analysis<|message|>...thinking...<|end|><|start|>assistant<|channel|>final<|message|>...final...<|return|>

The chat template expects analysis content in 'thinking' field and final content in 'content' field.

Returns:
tuple of (final_content, thinking_content or None)
"""
import re

# Check if the content contains gpt-oss channel tags
if "<|channel|>" not in content:
return content, None

thinking = None
final_content = content

# Extract analysis/thinking content: between <|channel|>analysis<|message|> and <|end|>
analysis_pattern = r"<\|channel\|>analysis[^<]*<\|message\|>(.*?)<\|end\|>"
analysis_match = re.search(analysis_pattern, content, re.DOTALL)
if analysis_match:
thinking = analysis_match.group(1).strip()

# Extract final content: after <|channel|>final<|message|> until <|return|> or end
final_pattern = r"<\|channel\|>final<\|message\|>(.*?)(?:<\|return\|>|$)"
final_match = re.search(final_pattern, content, re.DOTALL)
if final_match:
final_content = final_match.group(1).strip()
else:
# If no final channel found, try to strip all channel tags and use what remains
# This handles cases where the format might be slightly different
final_content = re.sub(r"<\|[^|]+\|>", "", content).strip()

return final_content, thinking

def _make_assistant_message(self, content: str, reasoning_content: str | None = None) -> dict:
"""Create an assistant message dict, optionally with thinking/reasoning content.

Some models (e.g., gpt-oss) output <|channel|> tags that need to be in a separate
'thinking' field rather than in 'content' for the chat template to work correctly.

If reasoning_content is not provided, attempts to parse it from content if the content
contains gpt-oss channel tags.
"""
# If reasoning_content not provided, try to parse from content
if reasoning_content is None:
content, reasoning_content = self._parse_gpt_oss_output(content)

message = {"role": "assistant", "content": content}
if reasoning_content:
message["thinking"] = reasoning_content
return message

async def _single_data_point_generate(self, data_point, data):
formal_statement = (
(data_point["header"].strip() + "\n")
Expand Down Expand Up @@ -243,8 +299,11 @@ async def _single_data_point_generate(self, data_point, data):
),
)

# Get reasoning_content if available (e.g., from gpt-oss models)
reasoning_content = generation.get("reasoning_content")

new_prompt_turn_list = deepcopy(prompt_turn_list)
new_prompt_turn_list += [{"role": "assistant", "content": generation["generation"]}]
new_prompt_turn_list.append(self._make_assistant_message(generation["generation"], reasoning_content))

prompt_turn_list_list.append(
new_prompt_turn_list
Expand All @@ -259,22 +318,15 @@ async def _single_data_point_generate(self, data_point, data):
): # check if successfully parse the code. We do not want to delete the turn if there is a parsing error.
if self.cfg.delete_wrong_turns:
prompt_turn_list = deepcopy(base_prompt_turn_list) + [
{
"role": "assistant",
"content": f"```lean4\n{full_code.strip()}\n```",
}
self._make_assistant_message(f"```lean4\n{full_code.strip()}\n```")
] # only keep the latest turn
else:
prompt_turn_list += [
{
"role": "assistant",
"content": f"```lean4\n{full_code.strip()}\n```",
}
]
full_prompt_turn_list += [{"role": "assistant", "content": generation["generation"]}]
prompt_turn_list.append(self._make_assistant_message(f"```lean4\n{full_code.strip()}\n```"))
full_prompt_turn_list.append(self._make_assistant_message(generation["generation"], reasoning_content))
else:
prompt_turn_list += [{"role": "assistant", "content": generation["generation"]}]
full_prompt_turn_list += [{"role": "assistant", "content": generation["generation"]}]
assistant_msg = self._make_assistant_message(generation["generation"], reasoning_content)
prompt_turn_list.append(assistant_msg)
full_prompt_turn_list.append(assistant_msg)

if code == "None" or "**Error**" in full_code:
if code == "None":
Expand Down