diff --git a/nemo_skills/inference/prover.py b/nemo_skills/inference/prover.py index 848c7c70dd..711b5fe9aa 100644 --- a/nemo_skills/inference/prover.py +++ b/nemo_skills/inference/prover.py @@ -195,6 +195,62 @@ def _transform_for_nemotron_refinement(self, proof_attempt: str, error_message: } ) + def _parse_gpt_oss_output(self, content: str) -> tuple[str, str | None]: + """Parse gpt-oss model output to extract thinking and final content. + + gpt-oss models output in the format: + <|channel|>analysis<|message|>...thinking...<|end|><|start|>assistant<|channel|>final<|message|>...final...<|return|> + + The chat template expects analysis content in 'thinking' field and final content in 'content' field. + + Returns: + tuple of (final_content, thinking_content or None) + """ + import re + + # Check if the content contains gpt-oss channel tags + if "<|channel|>" not in content: + return content, None + + thinking = None + final_content = content + + # Extract analysis/thinking content: between <|channel|>analysis<|message|> and <|end|> + analysis_pattern = r"<\|channel\|>analysis[^<]*<\|message\|>(.*?)<\|end\|>" + analysis_match = re.search(analysis_pattern, content, re.DOTALL) + if analysis_match: + thinking = analysis_match.group(1).strip() + + # Extract final content: after <|channel|>final<|message|> until <|return|> or end + final_pattern = r"<\|channel\|>final<\|message\|>(.*?)(?:<\|return\|>|$)" + final_match = re.search(final_pattern, content, re.DOTALL) + if final_match: + final_content = final_match.group(1).strip() + else: + # If no final channel found, try to strip all channel tags and use what remains + # This handles cases where the format might be slightly different + final_content = re.sub(r"<\|[^|]+\|>", "", content).strip() + + return final_content, thinking + + def _make_assistant_message(self, content: str, reasoning_content: str | None = None) -> dict: + """Create an assistant message dict, optionally with thinking/reasoning content. + + Some models (e.g., gpt-oss) output <|channel|> tags that need to be in a separate + 'thinking' field rather than in 'content' for the chat template to work correctly. + + If reasoning_content is not provided, attempts to parse it from content if the content + contains gpt-oss channel tags. + """ + # If reasoning_content not provided, try to parse from content + if reasoning_content is None: + content, reasoning_content = self._parse_gpt_oss_output(content) + + message = {"role": "assistant", "content": content} + if reasoning_content: + message["thinking"] = reasoning_content + return message + async def _single_data_point_generate(self, data_point, data): formal_statement = ( (data_point["header"].strip() + "\n") @@ -243,8 +299,11 @@ async def _single_data_point_generate(self, data_point, data): ), ) + # Get reasoning_content if available (e.g., from gpt-oss models) + reasoning_content = generation.get("reasoning_content") + new_prompt_turn_list = deepcopy(prompt_turn_list) - new_prompt_turn_list += [{"role": "assistant", "content": generation["generation"]}] + new_prompt_turn_list.append(self._make_assistant_message(generation["generation"], reasoning_content)) prompt_turn_list_list.append( new_prompt_turn_list @@ -259,22 +318,15 @@ async def _single_data_point_generate(self, data_point, data): ): # check if successfully parse the code. We do not want to delete the turn if there is a parsing error. if self.cfg.delete_wrong_turns: prompt_turn_list = deepcopy(base_prompt_turn_list) + [ - { - "role": "assistant", - "content": f"```lean4\n{full_code.strip()}\n```", - } + self._make_assistant_message(f"```lean4\n{full_code.strip()}\n```") ] # only keep the latest turn else: - prompt_turn_list += [ - { - "role": "assistant", - "content": f"```lean4\n{full_code.strip()}\n```", - } - ] - full_prompt_turn_list += [{"role": "assistant", "content": generation["generation"]}] + prompt_turn_list.append(self._make_assistant_message(f"```lean4\n{full_code.strip()}\n```")) + full_prompt_turn_list.append(self._make_assistant_message(generation["generation"], reasoning_content)) else: - prompt_turn_list += [{"role": "assistant", "content": generation["generation"]}] - full_prompt_turn_list += [{"role": "assistant", "content": generation["generation"]}] + assistant_msg = self._make_assistant_message(generation["generation"], reasoning_content) + prompt_turn_list.append(assistant_msg) + full_prompt_turn_list.append(assistant_msg) if code == "None" or "**Error**" in full_code: if code == "None":