Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions agents/codex-1395.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
<!-- bootstrap for codex on issue #1395 -->
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pytest==9.0.2
pyyaml==6.0.3
pydantic==2.12.5
numpy==2.4.2
pandas==3.0.0
pandas==2.3.3
tomlkit==0.14.0

# LangChain integration
Expand Down
19 changes: 18 additions & 1 deletion scripts/langchain/structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,23 @@ def _invoke_repair_loop(
)


def invoke_repair_loop(
*,
repair: Callable[[str, str, str], str | None] | None,
attempts: int,
model: type[T],
error_detail: str,
content: str,
) -> StructuredOutputResult[T]:
return _invoke_repair_loop(
repair=repair,
attempts=attempts,
model=model,
error_detail=error_detail,
content=content,
)


def parse_structured_output(
content: str,
model: type[T],
Expand All @@ -179,7 +196,7 @@ def parse_structured_output(

if error_detail is not None:
attempts = clamp_repair_attempts(max_repair_attempts)
return _invoke_repair_loop(
return invoke_repair_loop(
repair=repair,
attempts=attempts,
model=model,
Expand Down
36 changes: 35 additions & 1 deletion tests/test_fallback_chain_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,19 +114,53 @@ def analyze_completion(
)


class BackupQualityProvider(LLMProvider):
@property
def name(self) -> str:
return "backup-quality-provider"

def is_available(self) -> bool:
return True

def analyze_completion(
self,
session_output: str,
tasks: list[str],
context: str | None = None,
quality_context: object | None = None,
) -> CompletionAnalysis:
_ = session_output
_ = tasks
_ = context
_ = quality_context
return CompletionAnalysis(
completed_tasks=[],
in_progress_tasks=[],
blocked_tasks=[],
confidence=0.7,
reasoning="backup-quality",
provider_used=self.name,
)


def test_fallback_chain_selects_expected_active_provider_and_forwards_args():
sentinel = object()
legacy_provider = LegacyProvider()
quality_provider = QualityProvider()
backup_quality_provider = BackupQualityProvider()
legacy_provider.analyze_completion = MagicMock(wraps=legacy_provider.analyze_completion)
quality_provider.analyze_completion = MagicMock(wraps=quality_provider.analyze_completion)
backup_quality_provider.analyze_completion = MagicMock(
wraps=backup_quality_provider.analyze_completion
)

chain = FallbackChainProvider([legacy_provider, quality_provider])
chain = FallbackChainProvider([legacy_provider, quality_provider, backup_quality_provider])
chain.analyze_completion("session", ["task"], "ctx", quality_context=sentinel)

assert chain._active_provider is quality_provider
assert legacy_provider.analyze_completion.call_count == 0
quality_provider.analyze_completion.assert_called_once()
assert backup_quality_provider.analyze_completion.call_count == 0
call_args = quality_provider.analyze_completion.call_args
assert call_args.args == ()
call_kwargs = call_args.kwargs
Expand Down
24 changes: 9 additions & 15 deletions tests/test_structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,16 +183,15 @@ def repair(_schema: str, _errors: str, _raw: str) -> str | None:
def test_parse_structured_output_uses_effective_repair_attempts(
input_attempts: int, expected_effective: int, monkeypatch: pytest.MonkeyPatch
) -> None:
observed = {"effective": None, "calls": 0, "kwargs": None}
original_loop = structured_output._invoke_repair_loop
observed = {"attempts": None, "invoke_calls": 0}
original_invoke = structured_output.invoke_repair_loop

def loop_spy(**kwargs: Any) -> StructuredOutputResult:
observed["effective"] = kwargs["attempts"]
observed["calls"] += 1
observed["kwargs"] = kwargs
return original_loop(**kwargs)
def invoke_spy(*, attempts: int, **kwargs: object) -> StructuredOutputResult:
observed["attempts"] = attempts
observed["invoke_calls"] += 1
return original_invoke(attempts=attempts, **kwargs)

monkeypatch.setattr(structured_output, "_invoke_repair_loop", loop_spy)
monkeypatch.setattr(structured_output, "invoke_repair_loop", invoke_spy)
repair_spy = MagicMock(return_value=None)
content = _invalid_payload()

Expand All @@ -203,13 +202,8 @@ def loop_spy(**kwargs: Any) -> StructuredOutputResult:
max_repair_attempts=input_attempts,
)

# Production rule: max_repair_attempts is clamped to [0, 1] before invoking the repair loop.
assert observed["effective"] == expected_effective
assert isinstance(observed["kwargs"]["attempts"], int)
assert observed["calls"] == 1
assert observed["kwargs"]["repair"] is repair_spy
assert observed["kwargs"]["model"] is ExampleModel
assert observed["kwargs"]["content"] == content
assert observed["attempts"] == expected_effective # Production rule: clamp to [0, 1].
assert observed["invoke_calls"] == 1
assert result.repair_attempts_used == expected_effective
if expected_effective == 0:
repair_spy.assert_not_called()
Expand Down
Loading