Skip to content

Commit

Permalink
fix: support json safe serialization for basemodel subclasses (#727)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-app[bot] authored and stainless-bot committed Nov 4, 2024
1 parent 62bb863 commit 5be855e
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 21 deletions.
6 changes: 4 additions & 2 deletions src/anthropic/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
from datetime import date, datetime
from typing_extensions import Self
from typing_extensions import Self, Literal

import pydantic
from pydantic.fields import FieldInfo
Expand Down Expand Up @@ -137,9 +137,11 @@ def model_dump(
exclude_unset: bool = False,
exclude_defaults: bool = False,
warnings: bool = True,
mode: Literal["json", "python"] = "python",
) -> dict[str, Any]:
if PYDANTIC_V2:
if PYDANTIC_V2 or hasattr(model, "model_dump"):
return model.model_dump(
mode=mode,
exclude=exclude,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
Expand Down
9 changes: 6 additions & 3 deletions src/anthropic/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
PropertyInfo,
is_list,
is_given,
json_safe,
lru_cache,
is_mapping,
parse_date,
Expand Down Expand Up @@ -279,8 +280,8 @@ def model_dump(
Returns:
A dictionary representation of the model.
"""
if mode != "python":
raise ValueError("mode is only supported in Pydantic v2")
if mode not in {"json", "python"}:
raise ValueError("mode must be either 'json' or 'python'")
if round_trip != False:
raise ValueError("round_trip is only supported in Pydantic v2")
if warnings != True:
Expand All @@ -289,7 +290,7 @@ def model_dump(
raise ValueError("context is only supported in Pydantic v2")
if serialize_as_any != False:
raise ValueError("serialize_as_any is only supported in Pydantic v2")
return super().dict( # pyright: ignore[reportDeprecated]
dumped = super().dict( # pyright: ignore[reportDeprecated]
include=include,
exclude=exclude,
by_alias=by_alias,
Expand All @@ -298,6 +299,8 @@ def model_dump(
exclude_none=exclude_none,
)

return cast(dict[str, Any], json_safe(dumped)) if mode == "json" else dumped

@override
def model_dump_json(
self,
Expand Down
1 change: 1 addition & 0 deletions src/anthropic/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
is_list as is_list,
is_given as is_given,
is_tuple as is_tuple,
json_safe as json_safe,
lru_cache as lru_cache,
is_mapping as is_mapping,
is_tuple_t as is_tuple_t,
Expand Down
4 changes: 2 additions & 2 deletions src/anthropic/_utils/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _transform_recursive(
return data

if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True)
return model_dump(data, exclude_unset=True, mode="json")

annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
Expand Down Expand Up @@ -329,7 +329,7 @@ async def _async_transform_recursive(
return data

if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True)
return model_dump(data, exclude_unset=True, mode="json")

annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
Expand Down
17 changes: 17 additions & 0 deletions src/anthropic/_utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
overload,
)
from pathlib import Path
from datetime import date, datetime
from typing_extensions import TypeGuard

import sniffio
Expand Down Expand Up @@ -395,3 +396,19 @@ def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
maxsize=maxsize,
)
return cast(Any, wrapper) # type: ignore[no-any-return]


def json_safe(data: object) -> object:
"""Translates a mapping / sequence recursively in the same fashion
as `pydantic` v2's `model_dump(mode="json")`.
"""
if is_mapping(data):
return {json_safe(key): json_safe(value) for key, value in data.items()}

if is_iterable(data) and not isinstance(data, (str, bytes, bytearray)):
return [json_safe(item) for item in data]

if isinstance(data, (datetime, date)):
return data.isoformat()

return data
21 changes: 7 additions & 14 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,19 +520,15 @@ class Model(BaseModel):
assert m3.to_dict(exclude_none=True) == {}
assert m3.to_dict(exclude_defaults=True) == {}

if PYDANTIC_V2:

class Model2(BaseModel):
created_at: datetime
class Model2(BaseModel):
created_at: datetime

time_str = "2024-03-21T11:39:01.275859"
m4 = Model2.construct(created_at=time_str)
assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)}
assert m4.to_dict(mode="json") == {"created_at": time_str}
else:
with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"):
m.to_dict(mode="json")
time_str = "2024-03-21T11:39:01.275859"
m4 = Model2.construct(created_at=time_str)
assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)}
assert m4.to_dict(mode="json") == {"created_at": time_str}

if not PYDANTIC_V2:
with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
m.to_dict(warnings=False)

Expand All @@ -558,9 +554,6 @@ class Model(BaseModel):
assert m3.model_dump(exclude_none=True) == {}

if not PYDANTIC_V2:
with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"):
m.model_dump(mode="json")

with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"):
m.model_dump(round_trip=True)

Expand Down
15 changes: 15 additions & 0 deletions tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,32 @@ class DateDict(TypedDict, total=False):
foo: Annotated[date, PropertyInfo(format="iso8601")]


class DatetimeModel(BaseModel):
foo: datetime


class DateModel(BaseModel):
foo: Optional[date]


@parametrize
@pytest.mark.asyncio
async def test_iso8601_format(use_async: bool) -> None:
dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
tz = "Z" if PYDANTIC_V2 else "+00:00"
assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap]
assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692" + tz} # type: ignore[comparison-overlap]

dt = dt.replace(tzinfo=None)
assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap]
assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap]

assert await transform({"foo": None}, DateDict, use_async) == {"foo": None} # type: ignore[comparison-overlap]
assert await transform(DateModel(foo=None), Any, use_async) == {"foo": None} # type: ignore
assert await transform({"foo": date.fromisoformat("2023-02-23")}, DateDict, use_async) == {"foo": "2023-02-23"} # type: ignore[comparison-overlap]
assert await transform(DateModel(foo=date.fromisoformat("2023-02-23")), DateDict, use_async) == {
"foo": "2023-02-23"
} # type: ignore[comparison-overlap]


@parametrize
Expand Down

0 comments on commit 5be855e

Please sign in to comment.