1 change: 1 addition & 0 deletions agents/codex-1371.md
@@ -0,0 +1 @@
<!-- bootstrap for codex on issue #1371 -->
23 changes: 12 additions & 11 deletions requirements.txt
@@ -1,15 +1,16 @@
# Core runtime dependencies
requests==2.32.5
pyyaml>=6.0.3
pydantic>=2.0.0
numpy<2.0
pandas
tomlkit>=0.13.0
pytest==9.0.2
pyyaml==6.0.3
pydantic==2.12.5
numpy==1.26.4
pandas==3.0.0
tomlkit==0.14.0

# LangChain integration
langchain>=1.2,<1.3
langchain-core>=1.2,<1.3
langchain-community>=0.4,<0.5
langchain-openai>=1.1,<1.2
langchain-anthropic>=0.3,<0.4
faiss-cpu>=1.8,<2.0; python_version < "3.12"
langchain==1.2.0
langchain-core==1.2.8
langchain-community==0.4.0
langchain-openai==1.1.0
langchain-anthropic==0.3.21
faiss-cpu==1.8.0.post1; python_version < "3.12"
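
A note on the last pin: the trailing ; python_version < "3.12" is a standard PEP 508 environment marker, so faiss-cpu is skipped entirely on 3.12+ interpreters. A minimal sketch of how such a marker evaluates, using the packaging library (the marker string is copied from the pin above; the snippet itself is illustrative, not part of this change):

from packaging.markers import Marker

marker = Marker('python_version < "3.12"')
print(marker.evaluate())                            # checks the running interpreter
print(marker.evaluate({"python_version": "3.12"}))  # False: the pin would be skipped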
107 changes: 66 additions & 41 deletions scripts/langchain/structured_output.py
@@ -95,6 +95,65 @@ def clamp_repair_attempts(max_repair_attempts: int) -> int:
)


def _invoke_repair_loop(
*,
repair: Callable[[str, str, str], str | None] | None,
attempts: int,
model: type[T],
error_detail: str,
content: str,
) -> StructuredOutputResult[T]:
if repair is None or attempts == 0:
return StructuredOutputResult(
payload=None,
raw_content=None,
error_stage="validation",
error_detail=error_detail,
repair_attempts_used=0,
)
repaired = repair(schema_json(model), error_detail, content)
if not repaired:
return StructuredOutputResult(
payload=None,
raw_content=None,
error_stage="repair_unavailable",
error_detail=error_detail,
repair_attempts_used=1,
)
try:
payload = model.model_validate_json(repaired)
return StructuredOutputResult(
payload=payload,
raw_content=repaired,
error_stage=None,
error_detail=None,
repair_attempts_used=1,
)
except ValidationError as repair_exc:
repair_detail = format_validation_errors(repair_exc)
except Exception as repair_exc:
repair_detail = format_non_validation_error(repair_exc)
else:
repair_detail = None

if repair_detail is not None:
return StructuredOutputResult(
payload=None,
raw_content=None,
error_stage="repair_validation",
error_detail=repair_detail,
repair_attempts_used=1,
)

return StructuredOutputResult(
payload=None,
raw_content=None,
error_stage="validation",
error_detail="Unknown validation error.",
repair_attempts_used=0,
)


def parse_structured_output(
content: str,
model: type[T],
@@ -120,47 +179,13 @@ def parse_structured_output(

if error_detail is not None:
attempts = clamp_repair_attempts(max_repair_attempts)
if repair is None or attempts == 0:
return StructuredOutputResult(
payload=None,
raw_content=None,
error_stage="validation",
error_detail=error_detail,
repair_attempts_used=0,
)
repaired = repair(schema_json(model), error_detail, content)
if not repaired:
return StructuredOutputResult(
payload=None,
raw_content=None,
error_stage="repair_unavailable",
error_detail=error_detail,
repair_attempts_used=1,
)
try:
payload = model.model_validate_json(repaired)
return StructuredOutputResult(
payload=payload,
raw_content=repaired,
error_stage=None,
error_detail=None,
repair_attempts_used=1,
)
except ValidationError as repair_exc:
repair_detail = format_validation_errors(repair_exc)
except Exception as repair_exc:
repair_detail = format_non_validation_error(repair_exc)
else:
repair_detail = None

if repair_detail is not None:
return StructuredOutputResult(
payload=None,
raw_content=None,
error_stage="repair_validation",
error_detail=repair_detail,
repair_attempts_used=1,
)
return _invoke_repair_loop(
repair=repair,
attempts=attempts,
model=model,
error_detail=error_detail,
content=content,
)

return StructuredOutputResult(
payload=None,
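
For orientation, extracting _invoke_repair_loop leaves parse_structured_output with the single early-exit call shown in the hunk above. A minimal usage sketch of the public entry point (the Item model and fix_json callback are illustrative stand-ins, not part of this diff):

from pydantic import BaseModel

from scripts.langchain.structured_output import parse_structured_output


class Item(BaseModel):  # illustrative schema, not from the diff
    name: str
    count: int


def fix_json(schema: str, errors: str, raw: str) -> str | None:
    # Matches the Callable[[str, str, str], str | None] shape expected by
    # the repair parameter; a real callback would re-prompt an LLM with the
    # schema and the validation errors. Returning None signals that repair
    # is unavailable.
    return None


result = parse_structured_output(
    '{"name": "widget"}',  # "count" is missing, so validation fails
    Item,
    repair=fix_json,
    max_repair_attempts=1,
)
assert result.payload is None
assert result.error_stage == "repair_unavailable"
assert result.repair_attempts_used == 1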
55 changes: 37 additions & 18 deletions tests/test_anthropic_provider.py
@@ -3,25 +3,22 @@

import pytest

from tools.llm_provider import AnthropicProvider, SessionQualityContext


def test_anthropic_provider_forwards_quality_context_to_client_invoke():
sentinel = SessionQualityContext(
has_agent_messages=False,
has_work_evidence=False,
file_change_count=0,
successful_command_count=0,
estimated_effort_score=0,
data_quality="high",
analysis_text_length=250,
)
from tools.llm_provider import AnthropicProvider, CompletionAnalysis, GitHubModelsProvider


def test_anthropic_provider_forwards_quality_context_to_client_invoke(
monkeypatch: pytest.MonkeyPatch,
):
sentinel = object()

class DummyClient:
def __init__(self) -> None:
self.invoke = MagicMock(side_effect=self._invoke)
self.invoke = MagicMock(wraps=self.invoke)

def _invoke(self, _prompt: str, **kwargs):
def invoke(self, *args, **kwargs):
return self._invoke(*args, **kwargs)

def _invoke(self, *_args, **_kwargs):
return SimpleNamespace(content="""
{
"completed": ["task1"],
@@ -35,20 +32,42 @@ def _invoke(self, _prompt: str, **kwargs):
client = DummyClient()
provider = AnthropicProvider()
provider._get_client = MagicMock(return_value=client)
monkeypatch.setattr(
GitHubModelsProvider,
"_parse_response",
MagicMock(
return_value=CompletionAnalysis(
completed_tasks=["task1"],
in_progress_tasks=[],
blocked_tasks=[],
confidence=0.8,
reasoning="Task 1 done.",
provider_used="anthropic",
)
),
)

provider.analyze_completion("output", ["task1"], quality_context=sentinel)

assert client.invoke.call_args is not None
client.invoke.assert_called_once()
assert client.invoke.call_args.kwargs["quality_context"] is sentinel


def test_anthropic_provider_propagates_invoke_errors():
class DummyClient:
def invoke(self, _prompt: str, **_kwargs):
def __init__(self) -> None:
self.invoke = MagicMock(wraps=self.invoke)

def invoke(self, *args, **kwargs):
return self._invoke(*args, **kwargs)

def _invoke(self, *_args, **_kwargs):
raise TimeoutError("boom")

provider = AnthropicProvider()
provider._get_client = MagicMock(return_value=DummyClient())
client = DummyClient()
provider._get_client = MagicMock(return_value=client)

with pytest.raises(TimeoutError):
provider.analyze_completion("output", ["task1"])
client.invoke.assert_called_once()
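
Both DummyClient variants use the same idiom: assigning MagicMock(wraps=self.invoke) in __init__ shadows the class method with an instance attribute that records calls while still running the real implementation. A stripped-down sketch of the pattern (the Greeter class is hypothetical):

from unittest.mock import MagicMock


class Greeter:
    def __init__(self) -> None:
        # On the right-hand side, self.greet still resolves to the bound
        # class method; the assignment then shadows it with a recording wrapper.
        self.greet = MagicMock(wraps=self.greet)

    def greet(self, name: str) -> str:
        return f"hello {name}"


g = Greeter()
assert g.greet("codex") == "hello codex"  # real behavior is preserved
g.greet.assert_called_once_with("codex")  # and the call is recorded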
9 changes: 8 additions & 1 deletion tests/test_fallback_chain_provider.py
@@ -49,7 +49,13 @@ def test_fallback_chain_forwards_quality_context_to_active_provider():
chain.analyze_completion("session", ["task"], quality_context=sentinel)

provider.analyze_completion.assert_called_once()
assert provider.analyze_completion.call_args.kwargs["quality_context"] is sentinel
call_args = provider.analyze_completion.call_args
assert call_args.args == ()
call_kwargs = call_args.kwargs
assert call_kwargs["session_output"] == "session"
assert call_kwargs["tasks"] == ["task"]
assert call_kwargs["quality_context"] is sentinel
assert sentinel not in call_args.args


class LegacyProvider(LLMProvider):
@@ -128,3 +134,4 @@ def test_fallback_chain_selects_expected_active_provider_and_forwards_args():
assert call_kwargs["tasks"] == ["task"]
assert call_kwargs["context"] == "ctx"
assert call_kwargs["quality_context"] is sentinel
assert sentinel not in call_args.args
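
The empty-args assertion works because unittest.mock splits every recorded call into positional and keyword parts; an empty args tuple proves the chain forwarded everything by keyword. A minimal sketch (fn is a generic stand-in, not the provider under test):

from unittest.mock import MagicMock

fn = MagicMock()
fn(session_output="session", tasks=["task"])

assert fn.call_args.args == ()  # nothing was passed positionally
assert fn.call_args.kwargs == {"session_output": "session", "tasks": ["task"]}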
49 changes: 26 additions & 23 deletions tests/test_structured_output.py
@@ -1,15 +1,16 @@
import json
from typing import Any
from unittest.mock import MagicMock

import pytest
from pydantic import BaseModel, Field, ValidationError

from scripts.langchain import structured_output
from scripts.langchain.structured_output import (
DEFAULT_REPAIR_PROMPT,
StructuredOutputResult,
build_repair_callback,
build_repair_prompt,
clamp_repair_attempts,
format_non_validation_error,
format_validation_errors,
parse_structured_output,
@@ -175,42 +176,44 @@ def repair(_schema: str, _errors: str, _raw: str) -> str | None:
assert result.repair_attempts_used == 1


@pytest.mark.parametrize(
("input_attempts", "expected"),
[
(0, 0),
(1, 1),
(2, 1),
(10, 1),
],
)
def test_clamp_repair_attempts_clamps_bounds(input_attempts: int, expected: int):
assert clamp_repair_attempts(input_attempts) == expected


@pytest.mark.parametrize(
("input_attempts", "expected_effective"),
[(0, 0), (1, 1), (2, 1), (10, 1)],
)
def test_parse_structured_output_clamps_repair_attempts(
input_attempts: int, expected_effective: int
def test_parse_structured_output_uses_effective_repair_attempts(
input_attempts: int, expected_effective: int, monkeypatch: pytest.MonkeyPatch
) -> None:
repair_calls = {"count": 0}
observed = {"effective": None, "calls": 0, "kwargs": None}
original_loop = structured_output._invoke_repair_loop

def _repair(_schema: str, _errors: str, _raw: str) -> str | None:
repair_calls["count"] += 1
return None
def loop_spy(**kwargs: Any) -> StructuredOutputResult:
observed["effective"] = kwargs["attempts"]
observed["calls"] += 1
observed["kwargs"] = kwargs
return original_loop(**kwargs)

monkeypatch.setattr(structured_output, "_invoke_repair_loop", loop_spy)
repair_spy = MagicMock(return_value=None)
content = _invalid_payload()

result = parse_structured_output(
_invalid_payload(),
content,
ExampleModel,
repair=_repair,
repair=repair_spy,
max_repair_attempts=input_attempts,
)

assert repair_calls["count"] == expected_effective
# Production rule: max_repair_attempts is clamped to [0, 1] before invoking the repair loop.
assert observed["effective"] == expected_effective
assert isinstance(observed["kwargs"]["attempts"], int)
assert observed["calls"] == 1
assert observed["kwargs"]["repair"] is repair_spy
assert observed["kwargs"]["model"] is ExampleModel
assert observed["kwargs"]["content"] == content
assert result.repair_attempts_used == expected_effective
if expected_effective == 0:
repair_spy.assert_not_called()
assert result.error_stage == "validation"
else:
assert repair_spy.call_count == expected_effective
assert result.error_stage == "repair_unavailable"