Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pydantic_ai.models.instrumented import InstrumentedModel

from ..exceptions import FallbackExceptionGroup, ModelHTTPError
from ..settings import merge_model_settings
from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model

if TYPE_CHECKING:
Expand Down Expand Up @@ -65,8 +66,9 @@ async def request(

for model in self.models:
customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
merged_settings = merge_model_settings(model.settings, model_settings)
try:
response = await model.request(messages, model_settings, customized_model_request_parameters)
response = await model.request(messages, merged_settings, customized_model_request_parameters)
except Exception as exc:
if self._fallback_on(exc):
exceptions.append(exc)
Expand All @@ -91,10 +93,13 @@ async def request_stream(

for model in self.models:
customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
merged_settings = merge_model_settings(model.settings, model_settings)
async with AsyncExitStack() as stack:
try:
response = await stack.enter_async_context(
model.request_stream(messages, model_settings, customized_model_request_parameters, run_context)
model.request_stream(
messages, merged_settings, customized_model_request_parameters, run_context
)
)
except Exception as exc:
if self._fallback_on(exc):
Expand Down
68 changes: 68 additions & 0 deletions tests/models/test_fallback.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from __future__ import annotations

import json
import sys
from collections.abc import AsyncIterator
from datetime import timezone
from typing import Any

import pytest
from dirty_equals import IsJson
from inline_snapshot import snapshot
from pydantic_core import to_json

from pydantic_ai import Agent, ModelHTTPError
from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, UserPromptPart
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import Usage

from ..conftest import IsNow, try_import
Expand Down Expand Up @@ -445,3 +449,67 @@ async def test_fallback_condition_tuple() -> None:

response = await agent.run('hello')
assert response.output == 'success'


async def test_fallback_model_settings_merge():
    """Verify that FallbackModel merges the wrapped model's settings with agent-level and per-run settings."""

    def return_settings(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # Echo the settings the wrapped model received as JSON text so each case can inspect them.
        settings_json = to_json(info.model_settings).decode()
        return ModelResponse(parts=[TextPart(settings_json)])

    model_level_settings = ModelSettings(temperature=0.1, max_tokens=1024)
    fallback_model = FallbackModel(FunctionModel(return_settings, settings=model_level_settings))

    # Case 1: nothing supplied by the agent or the run, so the wrapped model's own settings come through.
    bare_agent = Agent(fallback_model)
    bare_run = await bare_agent.run('Hello')
    assert bare_run.output == IsJson({'max_tokens': 1024, 'temperature': 0.1})

    # Case 2: agent-level settings are layered on top of the wrapped model's settings;
    # `temperature` is overridden while `max_tokens` is kept.
    agent_level_settings = ModelSettings(temperature=0.5, parallel_tool_calls=True)
    agent_with_settings = Agent(fallback_model, model_settings=agent_level_settings)
    agent_run = await agent_with_settings.run('Hello')
    expected_with_agent_settings = {'max_tokens': 1024, 'temperature': 0.5, 'parallel_tool_calls': True}
    assert agent_run.output == IsJson(expected_with_agent_settings)

    # Case 3: settings passed to `run()` take precedence over both the agent's and the wrapped model's.
    per_run_settings = ModelSettings(temperature=0.9, extra_headers={'runtime_setting': 'runtime_value'})
    override_run = await agent_with_settings.run('Hello', model_settings=per_run_settings)
    expected_with_run_settings = {
        'max_tokens': 1024,
        'temperature': 0.9,
        'parallel_tool_calls': True,
        'extra_headers': {
            'runtime_setting': 'runtime_value',
        },
    }
    assert override_run.output == IsJson(expected_with_run_settings)


async def test_fallback_model_settings_merge_streaming():
    """Verify that FallbackModel merges model settings when the request is streamed."""

    async def return_settings_stream(_: list[ModelMessage], info: AgentInfo):
        # Stream a single chunk: the settings the wrapped model received, serialized as JSON.
        settings_json = to_json(info.model_settings).decode()
        yield settings_json

    model_level_settings = ModelSettings(
        temperature=0.1,
        extra_headers={'anthropic-beta': 'context-1m-2025-08-07'},
    )
    streaming_model = FunctionModel(stream_function=return_settings_stream, settings=model_level_settings)
    fallback_model = FallbackModel(streaming_model)

    # Case 1: no agent-level settings, so the wrapped model's settings reach the stream function unchanged.
    bare_agent = Agent(fallback_model)
    async with bare_agent.run_stream('Hello') as stream:
        streamed_text = await stream.get_output()

    expected_base = {'extra_headers': {'anthropic-beta': 'context-1m-2025-08-07'}, 'temperature': 0.1}
    assert json.loads(streamed_text) == expected_base

    # Case 2: the agent-level `temperature` overrides the wrapped model's value; `extra_headers` is kept.
    agent_with_settings = Agent(fallback_model, model_settings=ModelSettings(temperature=0.5))
    async with agent_with_settings.run_stream('Hello') as stream:
        streamed_text = await stream.get_output()

    expected_merged = {'extra_headers': {'anthropic-beta': 'context-1m-2025-08-07'}, 'temperature': 0.5}
    assert json.loads(streamed_text) == expected_merged