Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 17 additions & 67 deletions ccflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@

import collections.abc
import copy
import inspect
import logging
import pathlib
import platform
import sys
import warnings
from types import GenericAlias, MappingProxyType
from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar

from omegaconf import DictConfig
from packaging import version
from pydantic import (
BaseModel as PydanticBaseModel,
ConfigDict,
Expand Down Expand Up @@ -89,66 +86,7 @@ def get_registry_dependencies(self, types: Optional[Tuple["ModelType"]] = None)
return deps


# Pydantic 2 has different handling of serialization.
# This requires some workarounds at the moment until the feature is added to easily get a mode that
# is compatible with Pydantic 1
# This is done by adjusting annotations via a MetaClass for any annotation that includes a BaseModel,
# such that the new annotation contains SerializeAsAny
# https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing
# https://github.com/pydantic/pydantic/issues/6423
# https://github.com/pydantic/pydantic-core/pull/740
# See https://github.com/pydantic/pydantic/issues/6381 for inspiration on implementation
# NOTE: For this logic to be removed, require https://github.com/pydantic/pydantic-core/pull/1478
from pydantic._internal._model_construction import ModelMetaclass # noqa: E402

_IS_PY39 = version.parse(platform.python_version()) < version.parse("3.10")


def _adjust_annotations(annotation):
    """Return ``annotation`` with pydantic model types wrapped in ``SerializeAsAny``.

    The checks are applied in this order:

    1. A ``types.GenericAlias`` instance, or a class that subclasses the pydantic
       ``BaseModel``, is returned as ``SerializeAsAny[annotation]``.
    2. A parameterized annotation (non-empty origin and args) is handled by origin:
       ``type`` or a ``Generic`` subclass is returned unchanged; ``ClassVar`` is
       rebuilt around its adjusted single argument; any other origin is
       re-subscripted with every argument adjusted recursively. On Python 3.10+
       a ``types.UnionType`` origin (``X | Y``) is rebuilt as ``typing.Union``.
    3. Anything else is returned unchanged.

    Args:
        annotation: The type annotation to adjust.

    Returns:
        The adjusted annotation, or ``annotation`` itself when nothing applies.

    Raises:
        TypeError: If the origin cannot be subscripted with the adjusted
            arguments. The underlying error is chained as the cause.
    """
    origin = get_origin(annotation)
    args = get_args(annotation)
    if not _IS_PY39:
        # types.UnionType only exists on Python 3.10+, hence the guarded local import.
        from types import UnionType

        if origin is UnionType:
            # UnionType itself is not subscriptable; rebuild through typing.Union.
            origin = Union

    # Keep the short-circuit: issubclass() must not run on a GenericAlias.
    if isinstance(annotation, GenericAlias) or (inspect.isclass(annotation) and issubclass(annotation, PydanticBaseModel)):
        return SerializeAsAny[annotation]

    if not (origin and args):
        return annotation

    # Filter out typing.Type and generic types
    if origin is type or (inspect.isclass(origin) and issubclass(origin, Generic)):
        return annotation

    if origin is ClassVar:  # ClassVar doesn't accept a tuple of length 1 in py39
        return ClassVar[_adjust_annotations(args[0])]

    try:
        return origin[tuple(_adjust_annotations(arg) for arg in args)]
    except TypeError as exc:
        # Chain the cause so the failing subscription stays visible in the traceback.
        raise TypeError(f"Could not adjust annotations for {origin}") from exc


class _SerializeAsAnyMeta(ModelMetaclass):
    """Metaclass that rewrites field annotations before the model class is built.

    Every annotation whose name does not start with ``__`` is passed through
    ``_adjust_annotations`` and stored back into the class namespace.
    """

    def __new__(self, name: str, bases: Tuple[type], namespaces: Dict[str, Any], **kwargs):
        """Adjust the annotations in ``namespaces`` and delegate class creation to the parent metaclass."""
        annotations: dict = namespaces.get("__annotations__", {})

        # Merge in the pydantic BaseModel annotations once for every base that
        # has the pydantic BaseModel somewhere in its MRO.
        for parent in bases:
            if any(klass is PydanticBaseModel for klass in parent.__mro__):
                annotations.update(PydanticBaseModel.__annotations__)

        # Only existing keys are reassigned, so iterating the dict while updating is safe.
        for field_name in annotations:
            if field_name.startswith("__"):
                continue
            annotations[field_name] = _adjust_annotations(annotations[field_name])

        namespaces["__annotations__"] = annotations
        return super().__new__(self, name, bases, namespaces, **kwargs)


class BaseModel(PydanticBaseModel, _RegistryMixin, metaclass=_SerializeAsAnyMeta):
class BaseModel(PydanticBaseModel, _RegistryMixin):
"""BaseModel is a base class for all pydantic models within the cubist flow framework.

This gives us a way to add functionality to the framework, including
Expand Down Expand Up @@ -182,6 +120,17 @@ def type_(self) -> PyObjectPath:
ser_json_timedelta="float",
)

# https://docs.pydantic.dev/latest/concepts/serialization/#overriding-the-serialize_as_any-default-false
def model_dump(self, **kwargs) -> dict[str, Any]:
    """Dump the model to a dict with ``serialize_as_any`` switched on.

    A missing or falsy ``serialize_as_any`` keyword is replaced by ``True``;
    a truthy value supplied by the caller is passed through unchanged. All
    other keyword arguments are forwarded to the parent ``model_dump``.
    """
    kwargs["serialize_as_any"] = kwargs.get("serialize_as_any") or True
    return super().model_dump(**kwargs)

def model_dump_json(self, **kwargs) -> str:
    """Dump the model to a JSON string with ``serialize_as_any`` switched on.

    A missing or falsy ``serialize_as_any`` keyword is replaced by ``True``;
    a truthy value supplied by the caller is passed through unchanged. All
    other keyword arguments are forwarded to the parent ``model_dump_json``.
    """
    kwargs["serialize_as_any"] = kwargs.get("serialize_as_any") or True
    return super().model_dump_json(**kwargs)

def __str__(self):
# Because the standard string representation does not include class name
return repr(self)
Expand Down Expand Up @@ -251,7 +200,7 @@ def _base_model_validator(cls, v, handler, info):

if isinstance(v, PydanticBaseModel):
# Coerce from one BaseModel type to another (because it worked automatically in v1)
v = v.model_dump(exclude={"type_"})
v = v.model_dump(serialize_as_any=True, exclude={"type_"})

return handler(v)

Expand Down Expand Up @@ -376,7 +325,8 @@ def _validate_name(cls, v):
@model_serializer(mode="wrap")
def _registry_serializer(self, handler):
    """Serialize the registry, storing a dump of every registered model under ``"models"``.

    Args:
        handler: The wrapped serializer; called with ``self`` to produce the
            base mapping of values.

    Returns:
        The mapping produced by ``handler`` with its ``"models"`` entry set to
        ``{name: model.model_dump(serialize_as_any=True, by_alias=True)}`` for
        each item of ``self._models``.
    """
    values = handler(self)
    # Store dumped models only; the raw ``self._models`` assignment that used to
    # precede this was immediately overwritten and has been dropped.
    values["models"] = {name: model.model_dump(serialize_as_any=True, by_alias=True) for name, model in self._models.items()}
    return values

@property
Expand Down
43 changes: 42 additions & 1 deletion ccflow/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,18 @@
from inspect import Signature, isclass, signature
from typing import Any, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar

from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator
from pydantic import (
BaseModel as PydanticBaseModel,
ConfigDict,
Field,
InstanceOf,
PrivateAttr,
SerializerFunctionWrapHandler,
TypeAdapter,
field_validator,
model_serializer,
model_validator,
)
from typing_extensions import override

from .base import (
Expand Down Expand Up @@ -426,6 +437,36 @@ def __call__(self) -> ResultType:
else:
return fn(self.context)

# When serialize_as_any=True, pydantic may detect repeated object ids in nested graphs
# (e.g., shared default lists) and raise a circular reference error during serialization.
# For computing cache keys, fall back to a minimal, stable representation if such an error occurs.
# This is similar to how the pydantic docs here:
# https://docs.pydantic.dev/latest/concepts/forward_annotations/#cyclic-references
# handle cyclic references during serialization.
@model_serializer(mode="wrap")
def _serialize_model_evaluation_context(self, handler: SerializerFunctionWrapHandler):
    """Serialize via ``handler``, falling back to a minimal dict on circular-reference errors.

    If ``handler(self)`` raises a ``ValueError`` whose message contains
    ``"Circular reference"`` or ``"id repeated"``, a dict with the keys ``fn``,
    ``model``, ``context`` and ``options`` is returned instead. ``model`` and
    ``context`` are dumped with ``model_dump``; if that dump raises, their
    ``repr`` is used. Any other ``ValueError`` is re-raised unchanged.
    """
    try:
        return handler(self)
    except ValueError as exc:
        text = str(exc)
        if not any(marker in text for marker in ("Circular reference", "id repeated")):
            raise

        # Minimal, stable representation sufficient for cache-key tokenization
        reprs = {}
        for attr in ("model", "context"):
            try:
                reprs[attr] = getattr(self, attr).model_dump(mode="python", serialize_as_any=True, by_alias=True)
            except Exception:
                # Best-effort fallback: a plain repr when even the nested dump fails.
                reprs[attr] = repr(getattr(self, attr))

        return {
            "fn": self.fn,
            "model": reprs["model"],
            "context": reprs["context"],
            "options": dict(self.options),
        }


class EvaluatorBase(_CallableModel, abc.ABC):
"""Base class for evaluators, which are higher-order models that evaluate ModelAndContext.
Expand Down
43 changes: 1 addition & 42 deletions ccflow/tests/test_base_serialize.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import pickle
import platform
import unittest
from typing import Annotated, ClassVar, Dict, List, Optional, Type, Union
from typing import Annotated, Optional

import numpy as np
from packaging import version
from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, ValidationError

from ccflow import BaseModel, NDArray
Expand Down Expand Up @@ -213,45 +211,6 @@ class C(PydanticBaseModel):
# C implements the normal pydantic BaseModel which should allow extra fields.
_ = C(extra_field1=1)

def test_serialize_as_any(self):
    """Check that model annotations are rewritten to use ``SerializeAsAny``.

    Compares the string form of each annotation on a nested model against the
    expected, adjusted annotation.
    """
    # https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing
    # https://github.com/pydantic/pydantic/issues/6423
    # This test could be removed once there is a different solution to the issue above
    from pydantic import SerializeAsAny
    from pydantic.types import constr

    # The ``X | Y`` syntax is only evaluated on Python 3.10+.
    supports_pipe_union = version.parse(platform.python_version()) >= version.parse("3.10")
    pipe_union = (A | int) if supports_pipe_union else Union[A, int]

    class MyNestedModel(BaseModel):
        a1: A
        a2: Optional[Union[A, int]]
        a3: Dict[str, Optional[List[A]]]
        a4: ClassVar[A]
        a5: Type[A]
        a6: constr(min_length=1)
        a7: pipe_union

    expected = {
        "a1": SerializeAsAny[A],
        "a2": Optional[Union[SerializeAsAny[A], int]],
        "a3": dict[str, Optional[list[SerializeAsAny[A]]]],
        "a4": ClassVar[SerializeAsAny[A]],
        "a5": Type[A],
        "a6": constr(min_length=1),  # Uses Annotation
        "a7": Union[SerializeAsAny[A], int],
    }
    actual = MyNestedModel.__annotations__
    for field_name in ("a1", "a2", "a3", "a4", "a5", "a6", "a7"):
        self.assertEqual(str(actual[field_name]), str(expected[field_name]))

def test_pickle_consistency(self):
model = MultiAttributeModel(z=1, y="test", x=3.14, w=True)
serialized = pickle.dumps(model)
Expand Down
83 changes: 83 additions & 0 deletions ccflow/tests/test_evaluation_context_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import json
from datetime import date

from ccflow import DateContext
from ccflow.callable import ModelEvaluationContext
from ccflow.evaluators import GraphEvaluator, LoggingEvaluator, MultiEvaluator
from ccflow.tests.evaluators.util import NodeModel

# NOTE: for these tests, round-tripping via JSON does not work
# because the ModelEvaluationContext just has an InstanceOf validation check
# and so we do not actually construct a full MEC on load.


def _make_nested_mec(model):
    """Return the evaluation context for calling ``model`` on a fixed 2022-01-01 ``DateContext``.

    Asserts that both the returned object and its ``context`` attribute are
    ``ModelEvaluationContext`` instances, i.e. that the context is nested.
    """
    date_context = DateContext(date=date(2022, 1, 1))
    nested = model.__call__.get_evaluation_context(model, date_context)
    assert isinstance(nested, ModelEvaluationContext)
    # ensure nested: outer model is an evaluator, inner is a ModelEvaluationContext
    assert isinstance(nested.context, ModelEvaluationContext)
    return nested


def test_mec_model_dump_basic():
    """A nested evaluation context for a single node dumps to dict and to JSON."""
    mec = _make_nested_mec(NodeModel())

    dumped = mec.model_dump()
    assert isinstance(dumped, dict)
    assert all(key in dumped for key in ("fn", "model", "context", "options"))

    from_json = json.loads(mec.model_dump_json())
    assert from_json["fn"] == dumped["fn"]

    # Also verify mode-specific dumps
    python_dump = mec.model_dump(mode="python")
    assert isinstance(python_dump, dict)
    json_dump = mec.model_dump(mode="json")
    assert isinstance(json_dump, dict)
    # The json-mode dump must be encodable by the standard json module.
    json.dumps(json_dump)


def test_mec_model_dump_diamond_graph():
    """A diamond-shaped dependency graph (one node shared by two parents) can be dumped."""
    shared = NodeModel()
    left = NodeModel(deps_model=[shared])
    right = NodeModel(deps_model=[shared])
    top = NodeModel(deps_model=[left, right])

    mec = _make_nested_mec(top)

    dumped = mec.model_dump()
    assert isinstance(dumped, dict)
    assert {"fn", "model", "context", "options"}.issubset(dumped.keys())

    json.loads(mec.model_dump_json())

    # verify mode dumps
    python_dump = mec.model_dump(mode="python")
    assert isinstance(python_dump, dict)
    json_dump = mec.model_dump(mode="json")
    assert isinstance(json_dump, dict)
    # The json-mode dump must be encodable by the standard json module.
    json.dumps(json_dump)


def test_mec_model_dump_with_multi_evaluator():
    """An evaluation context whose model is a ``MultiEvaluator`` can be dumped to dict and JSON."""
    node = NodeModel()
    _ = LoggingEvaluator()  # ensure import/validation
    multi = MultiEvaluator(evaluators=[LoggingEvaluator(), GraphEvaluator()])

    # Simulate how Flow builds evaluation context with a custom evaluator
    date_context = DateContext(date=date(2022, 1, 1))
    inner = node.__call__.get_evaluation_context(node, date_context)
    mec = ModelEvaluationContext(model=multi, context=inner)

    dumped = mec.model_dump()
    assert isinstance(dumped, dict)
    assert all(key in dumped for key in ("fn", "model", "context"))

    json.loads(mec.model_dump_json())

    # verify mode dumps
    python_dump = mec.model_dump(mode="python")
    assert isinstance(python_dump, dict)
    json_dump = mec.model_dump(mode="json")
    assert isinstance(json_dump, dict)
    # The json-mode dump must be encodable by the standard json module.
    json.dumps(json_dump)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dependencies = [
"orjson",
"pandas",
"pyarrow",
"pydantic>=2.6,<3",
"pydantic>=2.12,<3",
"smart_open",
"tenacity",
]
Expand Down
Loading