Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 108 additions & 122 deletions ccflow/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,50 @@ class _CallableModel(BaseModel, abc.ABC):
)
meta: MetaData = Field(default_factory=MetaData)

@classmethod
def _check_context_type(cls, context_type):
type_call_arg = _cached_signature(cls.__call__).parameters["context"].annotation

# If optional type, extract inner type
if get_origin(type_call_arg) is Optional or (get_origin(type_call_arg) is Union and type(None) in get_args(type_call_arg)):
type_call_arg = [t for t in get_args(type_call_arg) if t is not type(None)][0]

if (
not isinstance(type_call_arg, TypeVar)
and type_call_arg is not Signature.empty
and (not isclass(type_call_arg) or not issubclass(type_call_arg, context_type))
and (not isclass(context_type) or not issubclass(context_type, type_call_arg))
):
err_msg_type_mismatch = f"The context_type {context_type} must match the type of the context accepted by __call__ {type_call_arg}"
raise ValueError(err_msg_type_mismatch)

@classmethod
def _check_result_type(cls, result_type):
type_call_return = _cached_signature(cls.__call__).return_annotation

# If union, check all types
if get_origin(type_call_return) is Union and get_args(type_call_return):
types_call_return = [t for t in get_args(type_call_return) if t is not type(None)]
else:
types_call_return = [type_call_return]

all_bad = True
for type_call_return in types_call_return:
if (
not isinstance(type_call_return, TypeVar)
and type_call_return is not Signature.empty
and (not isclass(type_call_return) or not issubclass(type_call_return, result_type))
and (not isclass(result_type) or not issubclass(result_type, type_call_return))
):
# Don't invert logic so that we match context above
pass
else:
all_bad = False

if all_bad:
err_msg_type_mismatch = f"The result_type {result_type} must match the return type of __call__ {type_call_return}"
raise ValueError(err_msg_type_mismatch)

@model_validator(mode="after")
def _check_signature(self):
sig_call = _cached_signature(self.__class__.__call__)
Expand All @@ -98,50 +142,12 @@ def _check_signature(self):
)
raise ValueError(err_msg_type_mismatch)

# If context_type or result_type are overridden, ensure they match the signature
# If context_type or result_type are overridden or
# come from generic type, ensure they match the signature
if hasattr(self, "context_type"):
type_call_arg = _cached_signature(self.__class__.__call__).parameters["context"].annotation

# If optional type, extract inner type
if get_origin(type_call_arg) is Optional or (get_origin(type_call_arg) is Union and type(None) in get_args(type_call_arg)):
type_call_arg = [t for t in get_args(type_call_arg) if t is not type(None)][0]

if (
not isinstance(type_call_arg, TypeVar)
and type_call_arg is not Signature.empty
and (not isclass(type_call_arg) or not issubclass(type_call_arg, self.context_type))
and (not isclass(self.context_type) or not issubclass(self.context_type, type_call_arg))
):
err_msg_type_mismatch = (
f"The context_type {self.context_type} must match the type of the context accepted by __call__ {type_call_arg}"
)
raise ValueError(err_msg_type_mismatch)

self._check_context_type(self.context_type)
if hasattr(self, "result_type"):
type_call_return = _cached_signature(self.__class__.__call__).return_annotation

# If union, check all types
if get_origin(type_call_return) is Union and get_args(type_call_return):
types_call_return = [t for t in get_args(type_call_return) if t is not type(None)]
else:
types_call_return = [type_call_return]

all_bad = True
for type_call_return in types_call_return:
if (
not isinstance(type_call_return, TypeVar)
and type_call_return is not Signature.empty
and (not isclass(type_call_return) or not issubclass(type_call_return, self.result_type))
and (not isclass(self.result_type) or not issubclass(self.result_type, type_call_return))
):
# Don't invert logic so that we match context above
pass
else:
all_bad = False

if all_bad:
err_msg_type_mismatch = f"The result_type {self.result_type} must match the return type of __call__ {type_call_return}"
raise ValueError(err_msg_type_mismatch)
self._check_result_type(self.result_type)

return self

Expand Down Expand Up @@ -548,13 +554,17 @@ def context_type(self) -> Type[ContextType]:
"""
typ = _cached_signature(self.__class__.__call__).parameters["context"].annotation
if typ is Signature.empty:
raise TypeError("Must either define a type annotation for context on __call__ or implement 'context_type'")

self._check_context_type(typ)
return typ
if isinstance(self, CallableModelGenericType) and hasattr(self, "_context_generic_type"):
typ = self._context_generic_type
else:
raise TypeError("Must either define a type annotation for context on __call__ or implement 'context_type'")
elif (
isinstance(self, CallableModelGenericType) and hasattr(self, "_context_generic_type") and not issubclass(typ, self._context_generic_type)
):
raise TypeError(
f"Context type annotation {typ} on __call__ does not match context_type {self._context_generic_type} defined by CallableModelGenericType"
)

@staticmethod
def _check_context_type(typ):
# If optional type, extract inner type
if get_origin(typ) is Optional or (get_origin(typ) is Union and type(None) in get_args(typ)):
type_to_check = [t for t in get_args(typ) if t is not type(None)][0]
Expand All @@ -565,6 +575,8 @@ def _check_context_type(typ):
if not isclass(type_to_check) or not issubclass(type_to_check, ContextBase):
raise TypeError(f"Context type declared in signature of __call__ must be a subclass of ContextBase. Received {type_to_check}.")

return typ

@property
def result_type(self) -> Type[ResultType]:
"""Return the result type for the model.
Expand All @@ -574,13 +586,29 @@ def result_type(self) -> Type[ResultType]:
"""
typ = _cached_signature(self.__class__.__call__).return_annotation
if typ is Signature.empty:
raise TypeError("Must either define a return type annotation on __call__ or implement 'result_type'")

self._check_result_type(typ)
return typ
if isinstance(self, CallableModelGenericType) and hasattr(self, "_result_generic_type"):
typ = self._result_generic_type
else:
raise TypeError("Must either define a return type annotation on __call__ or implement 'result_type'")
elif isinstance(self, CallableModelGenericType) and hasattr(self, "_result_generic_type"):
if get_origin(typ) is Union and get_origin(self._result_generic_type) is Union:
if set(get_args(typ)) != set(get_args(self._result_generic_type)):
raise TypeError(
f"Return type annotation {typ} on __call__ does not match result_type {self._result_generic_type} defined by CallableModelGenericType"
)
elif get_origin(typ) is Union:
raise NotImplementedError(
"Return type annotation on __call__ is a Union, but result_type defined by CallableModelGenericType is not a Union. This case is not yet supported."
)
elif get_origin(self._result_generic_type) is Union:
raise NotImplementedError(
"Return type annotation on __call__ is not a Union, but result_type defined by CallableModelGenericType is a Union. This case is not yet supported."
)
elif not issubclass(typ, self._result_generic_type):
raise TypeError(
f"Return type annotation {typ} on __call__ does not match result_type {self._result_generic_type} defined by CallableModelGenericType"
)

@staticmethod
def _check_result_type(typ):
# If union type, extract inner type
if get_origin(typ) is Union:
raise TypeError(
Expand All @@ -590,6 +618,7 @@ def _check_result_type(typ):
# Ensure subclass of ResultBase
if not isclass(typ) or not issubclass(typ, ResultBase):
raise TypeError(f"Return type declared in signature of __call__ must be a subclass of ResultBase (i.e. GenericResult). Received {typ}.")
return typ

@Flow.deps
def __deps__(
Expand Down Expand Up @@ -625,51 +654,48 @@ def result_type(self) -> Type[ResultType]:
return self.model.result_type


class CallableModelGenericType(CallableModel, Generic[ContextType, ResultType]):
class CallableModelGeneric(CallableModel, Generic[ContextType, ResultType]):
"""Special type of callable model that provides context and result via
a generic type instead of annotations on __call__.
"""

_context_type: ClassVar[Type[ContextType]]
_result_type: ClassVar[Type[ResultType]]

@property
def context_type(self) -> Type[ContextType]:
return self._context_type

@property
def result_type(self) -> Type[ResultType]:
return self._result_type
_context_generic_type: ClassVar[Type[ContextType]]
_result_generic_type: ClassVar[Type[ResultType]]

def __setstate__(self, state):
super().__setstate__(state)
self._determine_context_result()
super().__setstate__(state)

@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
super().__pydantic_init_subclass__(**kwargs)
cls._determine_context_result()

@classmethod
def _determine_context_result(cls):
# Extract the generic types from the class definition
if not hasattr(cls, "_context_type") or not hasattr(cls, "_result_type"):
if not hasattr(cls, "_context_generic_type") or not hasattr(cls, "_result_generic_type"):
new_context_type = None
new_result_type = None

for base in cls.__mro__:
if issubclass(base, CallableModelGenericType):
if issubclass(base, CallableModelGeneric):
# Found the generic base class, it should
# have either generic parameters or context/result
if new_context_type is None and hasattr(base, "_context_type") and issubclass(base._context_type, ContextBase):
new_context_type = base._context_type
if new_context_type is None and hasattr(base, "_context_generic_type") and issubclass(base._context_generic_type, ContextBase):
new_context_type = base._context_generic_type
if (
new_result_type is None
and hasattr(base, "_result_type")
and hasattr(base, "_result_generic_type")
and (
issubclass(base._result_type, ResultBase)
issubclass(base._result_generic_type, ResultBase)
or (
get_origin(base._result_type) is Union
and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(base._result_type))
get_origin(base._result_generic_type) is Union
and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(base._result_generic_type))
)
)
):
new_result_type = base._result_type
new_result_type = base._result_generic_type
if base.__pydantic_generic_metadata__["args"]:
if len(base.__pydantic_generic_metadata__["args"]) >= 2:
# Assume order is ContextType, ResultType
Expand All @@ -696,56 +722,12 @@ def _determine_context_result(cls):
break

if new_context_type is not None:
# Validate that the model's context_type match
annotation_context_type = _cached_signature(cls.__call__).parameters["context"].annotation
if get_origin(annotation_context_type) is Optional or (
get_origin(annotation_context_type) is Union and type(None) in get_args(annotation_context_type)
):
annotation_context_type = [t for t in get_args(annotation_context_type) if t is not type(None)][0]
if (
annotation_context_type is not Signature.empty
and not isinstance(annotation_context_type, TypeVar)
and not issubclass(annotation_context_type, new_context_type)
):
raise TypeError(
f"Context type annotation {annotation_context_type} on __call__ does not match context_type {new_context_type} defined by CallableModelGenericType"
)
elif isclass(annotation_context_type) and issubclass(annotation_context_type, new_context_type):
new_context_type = annotation_context_type

# Set on class
cls._context_type = new_context_type
cls._context_generic_type = new_context_type

if new_result_type is not None:
# Validate that the model's result_type match
annotation_result_type = _cached_signature(cls.__call__).return_annotation
if annotation_result_type is Signature.empty:
...
elif isinstance(annotation_result_type, TypeVar):
...
elif get_origin(annotation_result_type) is Union and get_origin(new_result_type) is Union:
raise TypeError(
f"Return type annotation for __call__ cannot be union on a CallableModelGenericType with union `result_type`. Received {annotation_result_type}"
)
elif get_origin(annotation_result_type) is Union:
if not any(issubclass(new_result_type, union_type) for union_type in get_args(annotation_result_type)):
raise TypeError(
f"Return type annotation {annotation_result_type} on __call__ does not match result_type {new_result_type} defined by CallableModelGenericType"
)
elif get_origin(new_result_type) is Union:
if not any(issubclass(annotation_result_type, union_type) for union_type in get_args(new_result_type)):
raise TypeError(
f"Return type annotation {annotation_result_type} on __call__ does not match result_type {new_result_type} defined by CallableModelGenericType"
)
elif not issubclass(annotation_result_type, new_result_type):
raise TypeError(
f"Return type annotation {annotation_result_type} on __call__ does not match result_type {new_result_type} defined by CallableModelGenericType"
)
elif isclass(annotation_result_type) and issubclass(annotation_result_type, new_result_type):
new_result_type = annotation_result_type

# Set on class
cls._result_type = new_result_type
cls._result_generic_type = new_result_type

@model_validator(mode="wrap")
def _validate_callable_model_generic_type(cls, m, handler, info):
Expand All @@ -756,7 +738,8 @@ def _validate_callable_model_generic_type(cls, m, handler, info):

if isinstance(m, dict):
m = handler(m)
cls._determine_context_result()
elif isinstance(m, cls):
m = handler(m)

# Raise ValueError (not TypeError) as per https://docs.pydantic.dev/latest/errors/errors/
if not isinstance(m, CallableModel):
Expand All @@ -768,3 +751,6 @@ def _validate_callable_model_generic_type(cls, m, handler, info):
TypeAdapter(Type[subtypes[1]]).validate_python(m.result_type)

return m


CallableModelGenericType = CallableModelGeneric
21 changes: 19 additions & 2 deletions ccflow/tests/test_callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,11 @@ def __call__(self, context: NullContext) -> Union[AResult, BResult]:
return AResult(a=1)


class UnionReturnGeneric(CallableModelGenericType[NullContext, AResult]):
class UnionReturnGeneric(CallableModelGenericType[NullContext, Union[AResult, BResult]]):
@property
def result_type(self) -> Type[ResultType]:
return AResult

@Flow.call
def __call__(self, context: NullContext) -> Union[AResult, BResult]:
# Return one branch of the Union
Expand Down Expand Up @@ -644,7 +648,7 @@ def test_types_generic(self):
error = "Return type annotation <class 'ccflow.tests.test_callable.MyResult'> on __call__ does not match result_type <class 'ccflow.result.generic.GenericResult'> defined by CallableModelGenericType"
self.assertRaisesRegex(TypeError, error, BadModelGenericMismatchedResultAndCall)

error = "Return type annotation for __call__ cannot be union on a CallableModelGenericType with union `result_type`"
error = "Model __call__ signature result type cannot be a Union type without a concrete property. Please define a property 'result_type' on the model."
self.assertRaisesRegex(TypeError, error, BadModelUnionReturnGeneric)

def test_union_return_generic(self):
Expand All @@ -653,6 +657,19 @@ def test_union_return_generic(self):
self.assertIsInstance(result, AResult)
self.assertEqual(result.a, 1)

def test_generic_validates_assignment(self):
class MyCallable(CallableModelGenericType[NullContext, GenericResult[int]]):
x: int = 1

@Flow.call
def __call__(self, context: NullContext) -> GenericResult[int]:
self.x = 5
assert self.x == 5
return GenericResult[float](value=self.x)

m = MyCallable()
self.assertEqual(m(NullContext()).value, 5)


class TestCallableModelDeps(TestCase):
def test_basic(self):
Expand Down