Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions pydantic_ai_slim/pydantic_ai/_parts_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from dataclasses import dataclass, field, replace
from typing import Any, Union

from pydantic_ai._thinking_part import END_THINK_TAG, START_THINK_TAG
from pydantic_ai.exceptions import UnexpectedModelBehavior
from pydantic_ai.messages import (
ModelResponsePart,
Expand Down Expand Up @@ -72,7 +71,7 @@ def handle_text_delta(
*,
vendor_part_id: VendorId | None,
content: str,
extract_think_tags: bool = False,
thinking_tags: tuple[str, str] | None = None,
) -> ModelResponseStreamEvent | None:
"""Handle incoming text content, creating or updating a TextPart in the manager as appropriate.

Expand All @@ -85,7 +84,7 @@ def handle_text_delta(
of text. If None, a new part will be created unless the latest part is already
a TextPart.
content: The text content to append to the appropriate TextPart.
extract_think_tags: Whether to extract `<think>` tags from the text content and handle them as thinking parts.
thinking_tags: If provided, will handle content between the thinking tags as thinking parts.

Returns:
- A `PartStartEvent` if a new part was created.
Expand All @@ -110,10 +109,10 @@ def handle_text_delta(
if part_index is not None:
existing_part = self._parts[part_index]

if extract_think_tags and isinstance(existing_part, ThinkingPart):
# We may be building a thinking part instead of a text part if we had previously seen a `<think>` tag
if content == END_THINK_TAG:
# When we see `</think>`, we're done with the thinking part and the next text delta will need a new part
if thinking_tags and isinstance(existing_part, ThinkingPart):
# We may be building a thinking part instead of a text part if we had previously seen a thinking tag
if content == thinking_tags[1]:
# When we see the thinking end tag, we're done with the thinking part and the next text delta will need a new part
self._vendor_id_to_part_index.pop(vendor_part_id)
return None
else:
Expand All @@ -123,8 +122,8 @@ def handle_text_delta(
else:
raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')

if extract_think_tags and content == START_THINK_TAG:
# When we see a `<think>` tag (which is a single token), we'll build a new thinking part instead
if thinking_tags and content == thinking_tags[0]:
# When we see a thinking start tag (which is a single token), we'll build a new thinking part instead
self._vendor_id_to_part_index.pop(vendor_part_id, None)
return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')

Expand Down
19 changes: 7 additions & 12 deletions pydantic_ai_slim/pydantic_ai/_thinking_part.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,30 @@

from pydantic_ai.messages import TextPart, ThinkingPart

START_THINK_TAG = '<think>'
END_THINK_TAG = '</think>'


def split_content_into_text_and_thinking(content: str) -> list[ThinkingPart | TextPart]:
def split_content_into_text_and_thinking(content: str, thinking_tags: tuple[str, str]) -> list[ThinkingPart | TextPart]:
"""Split a string into text and thinking parts.

Some models don't return the thinking part as a separate part, but rather as a tag in the content.
This function splits the content into text and thinking parts.

We use the `<think>` tag because that's how Groq uses it in the `raw` format, so instead of using `<Thinking>` or
something else, we just match the tag to make it easier for other models that don't support the `ThinkingPart`.
"""
start_tag, end_tag = thinking_tags
parts: list[ThinkingPart | TextPart] = []

start_index = content.find(START_THINK_TAG)
start_index = content.find(start_tag)
while start_index >= 0:
before_think, content = content[:start_index], content[start_index + len(START_THINK_TAG) :]
before_think, content = content[:start_index], content[start_index + len(start_tag) :]
if before_think:
parts.append(TextPart(content=before_think))
end_index = content.find(END_THINK_TAG)
end_index = content.find(end_tag)
if end_index >= 0:
think_content, content = content[:end_index], content[end_index + len(END_THINK_TAG) :]
think_content, content = content[:end_index], content[end_index + len(end_tag) :]
parts.append(ThinkingPart(content=think_content))
else:
# We lose the `<think>` tag, but it shouldn't matter.
parts.append(TextPart(content=content))
content = ''
start_index = content.find(START_THINK_TAG)
start_index = content.find(start_tag)
if content:
parts.append(TextPart(content=content))
return parts
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def _process_response(self, response: V2ChatResponse) -> ModelResponse:
# While Cohere's API returns a list, it only does that for future proofing
# and currently only one item is being returned.
choice = response.message.content[0]
parts.extend(split_content_into_text_and_thinking(choice.text))
parts.extend(split_content_into_text_and_thinking(choice.text, self.profile.thinking_tags))
for c in response.message.tool_calls or []:
if c.function and c.function.name and c.function.arguments: # pragma: no branch
parts.append(
Expand Down
10 changes: 7 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
ToolReturnPart,
UserPromptPart,
)
from ..profiles import ModelProfileSpec
from ..profiles import ModelProfile, ModelProfileSpec
from ..providers import Provider, infer_provider
from ..settings import ModelSettings
from ..tools import ToolDefinition
Expand Down Expand Up @@ -261,7 +261,7 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
items.append(ThinkingPart(content=choice.message.reasoning))
if choice.message.content is not None:
# NOTE: The `<think>` tag is only present if `groq_reasoning_format` is set to `raw`.
items.extend(split_content_into_text_and_thinking(choice.message.content))
items.extend(split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags))
if choice.message.tool_calls is not None:
for c in choice.message.tool_calls:
items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
Expand All @@ -281,6 +281,7 @@ async def _process_streamed_response(self, response: AsyncStream[chat.ChatComple
return GroqStreamedResponse(
_response=peekable_response,
_model_name=self._model_name,
_model_profile=self.profile,
_timestamp=number_to_datetime(first_chunk.created),
)

Expand Down Expand Up @@ -400,6 +401,7 @@ class GroqStreamedResponse(StreamedResponse):
"""Implementation of `StreamedResponse` for Groq models."""

_model_name: GroqModelName
_model_profile: ModelProfile
_response: AsyncIterable[chat.ChatCompletionChunk]
_timestamp: datetime

Expand All @@ -416,7 +418,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
content = choice.delta.content
if content is not None:
maybe_event = self._parts_manager.handle_text_delta(
vendor_part_id='content', content=content, extract_think_tags=True
vendor_part_id='content',
content=content,
thinking_tags=self._model_profile.thinking_tags,
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
Expand Down
9 changes: 7 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
UserPromptPart,
VideoUrl,
)
from ..profiles import ModelProfile
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests
Expand Down Expand Up @@ -244,7 +245,7 @@ def _process_response(self, response: ChatCompletionOutput) -> ModelResponse:
items: list[ModelResponsePart] = []

if content is not None:
items.extend(split_content_into_text_and_thinking(content))
items.extend(split_content_into_text_and_thinking(content, self.profile.thinking_tags))
if tool_calls is not None:
for c in tool_calls:
items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
Expand All @@ -267,6 +268,7 @@ async def _process_streamed_response(self, response: AsyncIterable[ChatCompletio

return HuggingFaceStreamedResponse(
_model_name=self._model_name,
_model_profile=self.profile,
_response=peekable_response,
_timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc),
)
Expand Down Expand Up @@ -412,6 +414,7 @@ class HuggingFaceStreamedResponse(StreamedResponse):
"""Implementation of `StreamedResponse` for Hugging Face models."""

_model_name: str
_model_profile: ModelProfile
_response: AsyncIterable[ChatCompletionStreamOutput]
_timestamp: datetime

Expand All @@ -428,7 +431,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
content = choice.delta.content
if content:
maybe_event = self._parts_manager.handle_text_delta(
vendor_part_id='content', content=content, extract_think_tags=True
vendor_part_id='content',
content=content,
thinking_tags=self._model_profile.thinking_tags,
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def _process_response(self, response: MistralChatCompletionResponse) -> ModelRes

parts: list[ModelResponsePart] = []
if text := _map_content(content):
parts.extend(split_content_into_text_and_thinking(text))
parts.extend(split_content_into_text_and_thinking(text, self.profile.thinking_tags))

if isinstance(tool_calls, list):
for tool_call in tool_calls:
Expand Down
10 changes: 7 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
UserPromptPart,
VideoUrl,
)
from ..profiles import ModelProfileSpec
from ..profiles import ModelProfile, ModelProfileSpec
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import (
Expand Down Expand Up @@ -407,7 +407,7 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelRespons
}

if choice.message.content is not None:
items.extend(split_content_into_text_and_thinking(choice.message.content))
items.extend(split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags))
if choice.message.tool_calls is not None:
for c in choice.message.tool_calls:
part = ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id)
Expand All @@ -433,6 +433,7 @@ async def _process_streamed_response(self, response: AsyncStream[ChatCompletionC

return OpenAIStreamedResponse(
_model_name=self._model_name,
_model_profile=self.profile,
_response=peekable_response,
_timestamp=number_to_datetime(first_chunk.created),
)
Expand Down Expand Up @@ -1009,6 +1010,7 @@ class OpenAIStreamedResponse(StreamedResponse):
"""Implementation of `StreamedResponse` for OpenAI models."""

_model_name: OpenAIModelName
_model_profile: ModelProfile
_response: AsyncIterable[ChatCompletionChunk]
_timestamp: datetime

Expand All @@ -1025,7 +1027,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
content = choice.delta.content
if content:
maybe_event = self._parts_manager.handle_text_delta(
vendor_part_id='content', content=content, extract_think_tags=True
vendor_part_id='content',
content=content,
thinking_tags=self._model_profile.thinking_tags,
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
Expand Down
4 changes: 3 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ async def request_stream(

model_response = self._request(messages, model_settings, model_request_parameters)
yield TestStreamedResponse(
_model_name=self._model_name, _structured_response=model_response, _messages=messages
_model_name=self._model_name,
_structured_response=model_response,
_messages=messages,
)

@property
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class ModelProfile:
json_schema_transformer: type[JsonSchemaTransformer] | None = None
"""The transformer to use to make JSON schemas for tools and structured output compatible with the model."""

thinking_tags: tuple[str, str] = ('<think>', '</think>')
"""The tags used to indicate thinking parts in the model's output. Defaults to ('<think>', '</think>')."""

@classmethod
def from_profile(cls, profile: ModelProfile | None) -> Self:
"""Build a ModelProfile subclass instance from a ModelProfile instance."""
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/profiles/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@

def anthropic_model_profile(model_name: str) -> ModelProfile | None:
"""Get the model profile for an Anthropic model."""
return None
return ModelProfile(thinking_tags=('<thinking>', '</thinking>'))
17 changes: 9 additions & 8 deletions tests/test_parts_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,30 +82,31 @@ def test_handle_dovetailed_text_deltas():

def test_handle_text_deltas_with_think_tags():
manager = ModelResponsePartsManager()
thinking_tags = ('<think>', '</think>')

event = manager.handle_text_delta(vendor_part_id='content', content='pre-', extract_think_tags=True)
event = manager.handle_text_delta(vendor_part_id='content', content='pre-', thinking_tags=thinking_tags)
assert event == snapshot(
PartStartEvent(index=0, part=TextPart(content='pre-', part_kind='text'), event_kind='part_start')
)
assert manager.get_parts() == snapshot([TextPart(content='pre-', part_kind='text')])

event = manager.handle_text_delta(vendor_part_id='content', content='thinking', extract_think_tags=True)
event = manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags)
assert event == snapshot(
PartDeltaEvent(
index=0, delta=TextPartDelta(content_delta='thinking', part_delta_kind='text'), event_kind='part_delta'
)
)
assert manager.get_parts() == snapshot([TextPart(content='pre-thinking', part_kind='text')])

event = manager.handle_text_delta(vendor_part_id='content', content='<think>', extract_think_tags=True)
event = manager.handle_text_delta(vendor_part_id='content', content='<think>', thinking_tags=thinking_tags)
assert event == snapshot(
PartStartEvent(index=1, part=ThinkingPart(content='', part_kind='thinking'), event_kind='part_start')
)
assert manager.get_parts() == snapshot(
[TextPart(content='pre-thinking', part_kind='text'), ThinkingPart(content='', part_kind='thinking')]
)

event = manager.handle_text_delta(vendor_part_id='content', content='thinking', extract_think_tags=True)
event = manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags)
assert event == snapshot(
PartDeltaEvent(
index=1,
Expand All @@ -117,7 +118,7 @@ def test_handle_text_deltas_with_think_tags():
[TextPart(content='pre-thinking', part_kind='text'), ThinkingPart(content='thinking', part_kind='thinking')]
)

event = manager.handle_text_delta(vendor_part_id='content', content=' more', extract_think_tags=True)
event = manager.handle_text_delta(vendor_part_id='content', content=' more', thinking_tags=thinking_tags)
assert event == snapshot(
PartDeltaEvent(
index=1, delta=ThinkingPartDelta(content_delta=' more', part_delta_kind='thinking'), event_kind='part_delta'
Expand All @@ -130,10 +131,10 @@ def test_handle_text_deltas_with_think_tags():
]
)

event = manager.handle_text_delta(vendor_part_id='content', content='</think>', extract_think_tags=True)
event = manager.handle_text_delta(vendor_part_id='content', content='</think>', thinking_tags=thinking_tags)
assert event is None

event = manager.handle_text_delta(vendor_part_id='content', content='post-', extract_think_tags=True)
event = manager.handle_text_delta(vendor_part_id='content', content='post-', thinking_tags=thinking_tags)
assert event == snapshot(
PartStartEvent(index=2, part=TextPart(content='post-', part_kind='text'), event_kind='part_start')
)
Expand All @@ -145,7 +146,7 @@ def test_handle_text_deltas_with_think_tags():
]
)

event = manager.handle_text_delta(vendor_part_id='content', content='thinking', extract_think_tags=True)
event = manager.handle_text_delta(vendor_part_id='content', content='thinking', thinking_tags=thinking_tags)
assert event == snapshot(
PartDeltaEvent(
index=2, delta=TextPartDelta(content_delta='thinking', part_delta_kind='text'), event_kind='part_delta'
Expand Down
Loading