Skip to content

Commit 56747e9

Browse files
committed
Simplify callable and generic callable validation to unify, fix bug related to assignment and validators
Signed-off-by: Tim Paine <[email protected]>
1 parent cb9d44a commit 56747e9

File tree

2 files changed

+129
-120
lines changed

2 files changed

+129
-120
lines changed

ccflow/callable.py

Lines changed: 110 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,50 @@ class _CallableModel(BaseModel, abc.ABC):
7979
)
8080
meta: MetaData = Field(default_factory=MetaData)
8181

82+
@classmethod
83+
def _check_context_type(cls, context_type):
84+
type_call_arg = _cached_signature(cls.__call__).parameters["context"].annotation
85+
86+
# If optional type, extract inner type
87+
if get_origin(type_call_arg) is Optional or (get_origin(type_call_arg) is Union and type(None) in get_args(type_call_arg)):
88+
type_call_arg = [t for t in get_args(type_call_arg) if t is not type(None)][0]
89+
90+
if (
91+
not isinstance(type_call_arg, TypeVar)
92+
and type_call_arg is not Signature.empty
93+
and (not isclass(type_call_arg) or not issubclass(type_call_arg, context_type))
94+
and (not isclass(context_type) or not issubclass(context_type, type_call_arg))
95+
):
96+
err_msg_type_mismatch = f"The context_type {context_type} must match the type of the context accepted by __call__ {type_call_arg}"
97+
raise ValueError(err_msg_type_mismatch)
98+
99+
@classmethod
100+
def _check_result_type(cls, result_type):
101+
type_call_return = _cached_signature(cls.__call__).return_annotation
102+
103+
# If union, check all types
104+
if get_origin(type_call_return) is Union and get_args(type_call_return):
105+
types_call_return = [t for t in get_args(type_call_return) if t is not type(None)]
106+
else:
107+
types_call_return = [type_call_return]
108+
109+
all_bad = True
110+
for type_call_return in types_call_return:
111+
if (
112+
not isinstance(type_call_return, TypeVar)
113+
and type_call_return is not Signature.empty
114+
and (not isclass(type_call_return) or not issubclass(type_call_return, result_type))
115+
and (not isclass(result_type) or not issubclass(result_type, type_call_return))
116+
):
117+
# Don't invert logic so that we match context above
118+
pass
119+
else:
120+
all_bad = False
121+
122+
if all_bad:
123+
err_msg_type_mismatch = f"The result_type {result_type} must match the return type of __call__ {type_call_return}"
124+
raise ValueError(err_msg_type_mismatch)
125+
82126
@model_validator(mode="after")
83127
def _check_signature(self):
84128
sig_call = _cached_signature(self.__class__.__call__)
@@ -98,50 +142,12 @@ def _check_signature(self):
98142
)
99143
raise ValueError(err_msg_type_mismatch)
100144

101-
# If context_type or result_type are overridden, ensure they match the signature
145+
# If context_type or result_type are overridden or
146+
# come from generic type, ensure they match the signature
102147
if hasattr(self, "context_type"):
103-
type_call_arg = _cached_signature(self.__class__.__call__).parameters["context"].annotation
104-
105-
# If optional type, extract inner type
106-
if get_origin(type_call_arg) is Optional or (get_origin(type_call_arg) is Union and type(None) in get_args(type_call_arg)):
107-
type_call_arg = [t for t in get_args(type_call_arg) if t is not type(None)][0]
108-
109-
if (
110-
not isinstance(type_call_arg, TypeVar)
111-
and type_call_arg is not Signature.empty
112-
and (not isclass(type_call_arg) or not issubclass(type_call_arg, self.context_type))
113-
and (not isclass(self.context_type) or not issubclass(self.context_type, type_call_arg))
114-
):
115-
err_msg_type_mismatch = (
116-
f"The context_type {self.context_type} must match the type of the context accepted by __call__ {type_call_arg}"
117-
)
118-
raise ValueError(err_msg_type_mismatch)
119-
148+
self._check_context_type(self.context_type)
120149
if hasattr(self, "result_type"):
121-
type_call_return = _cached_signature(self.__class__.__call__).return_annotation
122-
123-
# If union, check all types
124-
if get_origin(type_call_return) is Union and get_args(type_call_return):
125-
types_call_return = [t for t in get_args(type_call_return) if t is not type(None)]
126-
else:
127-
types_call_return = [type_call_return]
128-
129-
all_bad = True
130-
for type_call_return in types_call_return:
131-
if (
132-
not isinstance(type_call_return, TypeVar)
133-
and type_call_return is not Signature.empty
134-
and (not isclass(type_call_return) or not issubclass(type_call_return, self.result_type))
135-
and (not isclass(self.result_type) or not issubclass(self.result_type, type_call_return))
136-
):
137-
# Don't invert logic so that we match context above
138-
pass
139-
else:
140-
all_bad = False
141-
142-
if all_bad:
143-
err_msg_type_mismatch = f"The result_type {self.result_type} must match the return type of __call__ {type_call_return}"
144-
raise ValueError(err_msg_type_mismatch)
150+
self._check_result_type(self.result_type)
145151

146152
return self
147153

@@ -548,13 +554,15 @@ def context_type(self) -> Type[ContextType]:
548554
"""
549555
typ = _cached_signature(self.__class__.__call__).parameters["context"].annotation
550556
if typ is Signature.empty:
551-
raise TypeError("Must either define a type annotation for context on __call__ or implement 'context_type'")
552-
553-
self._check_context_type(typ)
554-
return typ
557+
if isinstance(self, CallableModelGenericType) and hasattr(self, "context_generic_type"):
558+
typ = self.context_generic_type
559+
else:
560+
raise TypeError("Must either define a type annotation for context on __call__ or implement 'context_type'")
561+
elif isinstance(self, CallableModelGenericType) and hasattr(self, "context_generic_type") and not issubclass(typ, self.context_generic_type):
562+
raise TypeError(
563+
f"Context type annotation {typ} on __call__ does not match context_type {self.context_generic_type} defined by CallableModelGenericType"
564+
)
555565

556-
@staticmethod
557-
def _check_context_type(typ):
558566
# If optional type, extract inner type
559567
if get_origin(typ) is Optional or (get_origin(typ) is Union and type(None) in get_args(typ)):
560568
type_to_check = [t for t in get_args(typ) if t is not type(None)][0]
@@ -565,6 +573,8 @@ def _check_context_type(typ):
565573
if not isclass(type_to_check) or not issubclass(type_to_check, ContextBase):
566574
raise TypeError(f"Context type declared in signature of __call__ must be a subclass of ContextBase. Received {type_to_check}.")
567575

576+
return typ
577+
568578
@property
569579
def result_type(self) -> Type[ResultType]:
570580
"""Return the result type for the model.
@@ -574,13 +584,29 @@ def result_type(self) -> Type[ResultType]:
574584
"""
575585
typ = _cached_signature(self.__class__.__call__).return_annotation
576586
if typ is Signature.empty:
577-
raise TypeError("Must either define a return type annotation on __call__ or implement 'result_type'")
578-
579-
self._check_result_type(typ)
580-
return typ
587+
if isinstance(self, CallableModelGenericType) and hasattr(self, "result_generic_type"):
588+
typ = self.result_generic_type
589+
else:
590+
raise TypeError("Must either define a return type annotation on __call__ or implement 'result_type'")
591+
elif isinstance(self, CallableModelGenericType) and hasattr(self, "result_generic_type"):
592+
if get_origin(typ) is Union and get_origin(self.result_generic_type) is Union:
593+
if set(get_args(typ)) != set(get_args(self.result_generic_type)):
594+
raise TypeError(
595+
f"Return type annotation {typ} on __call__ does not match result_type {self.result_generic_type} defined by CallableModelGenericType"
596+
)
597+
elif get_origin(typ) is Union:
598+
raise NotImplementedError(
599+
"Return type annotation on __call__ is a Union, but result_type defined by CallableModelGenericType is not a Union. This case is not yet supported."
600+
)
601+
elif get_origin(self.result_generic_type) is Union:
602+
raise NotImplementedError(
603+
"Return type annotation on __call__ is not a Union, but result_type defined by CallableModelGenericType is a Union. This case is not yet supported."
604+
)
605+
elif not issubclass(typ, self.result_generic_type):
606+
raise TypeError(
607+
f"Return type annotation {typ} on __call__ does not match result_type {self.result_generic_type} defined by CallableModelGenericType"
608+
)
581609

582-
@staticmethod
583-
def _check_result_type(typ):
584610
# If union type, extract inner type
585611
if get_origin(typ) is Union:
586612
raise TypeError(
@@ -590,6 +616,7 @@ def _check_result_type(typ):
590616
# Ensure subclass of ResultBase
591617
if not isclass(typ) or not issubclass(typ, ResultBase):
592618
raise TypeError(f"Return type declared in signature of __call__ must be a subclass of ResultBase (i.e. GenericResult). Received {typ}.")
619+
return typ
593620

594621
@Flow.deps
595622
def __deps__(
@@ -625,51 +652,56 @@ def result_type(self) -> Type[ResultType]:
625652
return self.model.result_type
626653

627654

628-
class CallableModelGenericType(CallableModel, Generic[ContextType, ResultType]):
655+
class CallableModelGeneric(CallableModel, Generic[ContextType, ResultType]):
629656
"""Special type of callable model that provides context and result via
630657
a generic type instead of annotations on __call__.
631658
"""
632659

633-
_context_type: ClassVar[Type[ContextType]]
634-
_result_type: ClassVar[Type[ResultType]]
660+
_context_generic_type: ClassVar[Type[ContextType]]
661+
_result_generic_type: ClassVar[Type[ResultType]]
635662

636663
@property
637-
def context_type(self) -> Type[ContextType]:
638-
return self._context_type
664+
def context_generic_type(self) -> Type[ContextType]:
665+
return self._context_generic_type
639666

640667
@property
641-
def result_type(self) -> Type[ResultType]:
642-
return self._result_type
668+
def result_generic_type(self) -> Type[ResultType]:
669+
return self._result_generic_type
643670

644671
def __setstate__(self, state):
645-
super().__setstate__(state)
646672
self._determine_context_result()
673+
super().__setstate__(state)
674+
675+
@classmethod
676+
def __pydantic_init_subclass__(cls, **kwargs):
677+
super().__pydantic_init_subclass__(**kwargs)
678+
cls._determine_context_result()
647679

648680
@classmethod
649681
def _determine_context_result(cls):
650682
# Extract the generic types from the class definition
651-
if not hasattr(cls, "_context_type") or not hasattr(cls, "_result_type"):
683+
if not hasattr(cls, "_context_generic_type") or not hasattr(cls, "_result_generic_type"):
652684
new_context_type = None
653685
new_result_type = None
654686

655687
for base in cls.__mro__:
656-
if issubclass(base, CallableModelGenericType):
688+
if issubclass(base, CallableModelGeneric):
657689
# Found the generic base class, it should
658690
# have either generic parameters or context/result
659-
if new_context_type is None and hasattr(base, "_context_type") and issubclass(base._context_type, ContextBase):
660-
new_context_type = base._context_type
691+
if new_context_type is None and hasattr(base, "_context_generic_type") and issubclass(base._context_generic_type, ContextBase):
692+
new_context_type = base._context_generic_type
661693
if (
662694
new_result_type is None
663-
and hasattr(base, "_result_type")
695+
and hasattr(base, "_result_generic_type")
664696
and (
665-
issubclass(base._result_type, ResultBase)
697+
issubclass(base._result_generic_type, ResultBase)
666698
or (
667-
get_origin(base._result_type) is Union
668-
and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(base._result_type))
699+
get_origin(base._result_generic_type) is Union
700+
and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(base._result_generic_type))
669701
)
670702
)
671703
):
672-
new_result_type = base._result_type
704+
new_result_type = base._result_generic_type
673705
if base.__pydantic_generic_metadata__["args"]:
674706
if len(base.__pydantic_generic_metadata__["args"]) >= 2:
675707
# Assume order is ContextType, ResultType
@@ -696,56 +728,12 @@ def _determine_context_result(cls):
696728
break
697729

698730
if new_context_type is not None:
699-
# Validate that the model's context_type match
700-
annotation_context_type = _cached_signature(cls.__call__).parameters["context"].annotation
701-
if get_origin(annotation_context_type) is Optional or (
702-
get_origin(annotation_context_type) is Union and type(None) in get_args(annotation_context_type)
703-
):
704-
annotation_context_type = [t for t in get_args(annotation_context_type) if t is not type(None)][0]
705-
if (
706-
annotation_context_type is not Signature.empty
707-
and not isinstance(annotation_context_type, TypeVar)
708-
and not issubclass(annotation_context_type, new_context_type)
709-
):
710-
raise TypeError(
711-
f"Context type annotation {annotation_context_type} on __call__ does not match context_type {new_context_type} defined by CallableModelGenericType"
712-
)
713-
elif isclass(annotation_context_type) and issubclass(annotation_context_type, new_context_type):
714-
new_context_type = annotation_context_type
715-
716731
# Set on class
717-
cls._context_type = new_context_type
732+
cls._context_generic_type = new_context_type
718733

719734
if new_result_type is not None:
720-
# Validate that the model's result_type match
721-
annotation_result_type = _cached_signature(cls.__call__).return_annotation
722-
if annotation_result_type is Signature.empty:
723-
...
724-
elif isinstance(annotation_result_type, TypeVar):
725-
...
726-
elif get_origin(annotation_result_type) is Union and get_origin(new_result_type) is Union:
727-
raise TypeError(
728-
f"Return type annotation for __call__ cannot be union on a CallableModelGenericType with union `result_type`. Received {annotation_result_type}"
729-
)
730-
elif get_origin(annotation_result_type) is Union:
731-
if not any(issubclass(new_result_type, union_type) for union_type in get_args(annotation_result_type)):
732-
raise TypeError(
733-
f"Return type annotation {annotation_result_type} on __call__ does not match result_type {new_result_type} defined by CallableModelGenericType"
734-
)
735-
elif get_origin(new_result_type) is Union:
736-
if not any(issubclass(annotation_result_type, union_type) for union_type in get_args(new_result_type)):
737-
raise TypeError(
738-
f"Return type annotation {annotation_result_type} on __call__ does not match result_type {new_result_type} defined by CallableModelGenericType"
739-
)
740-
elif not issubclass(annotation_result_type, new_result_type):
741-
raise TypeError(
742-
f"Return type annotation {annotation_result_type} on __call__ does not match result_type {new_result_type} defined by CallableModelGenericType"
743-
)
744-
elif isclass(annotation_result_type) and issubclass(annotation_result_type, new_result_type):
745-
new_result_type = annotation_result_type
746-
747735
# Set on class
748-
cls._result_type = new_result_type
736+
cls._result_generic_type = new_result_type
749737

750738
@model_validator(mode="wrap")
751739
def _validate_callable_model_generic_type(cls, m, handler, info):
@@ -756,7 +744,8 @@ def _validate_callable_model_generic_type(cls, m, handler, info):
756744

757745
if isinstance(m, dict):
758746
m = handler(m)
759-
cls._determine_context_result()
747+
elif isinstance(m, cls):
748+
m = handler(m)
760749

761750
# Raise ValueError (not TypeError) as per https://docs.pydantic.dev/latest/errors/errors/
762751
if not isinstance(m, CallableModel):
@@ -768,3 +757,6 @@ def _validate_callable_model_generic_type(cls, m, handler, info):
768757
TypeAdapter(Type[subtypes[1]]).validate_python(m.result_type)
769758

770759
return m
760+
761+
762+
CallableModelGenericType = CallableModelGeneric

ccflow/tests/test_callable.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,11 @@ def __call__(self, context: NullContext) -> Union[AResult, BResult]:
294294
return AResult(a=1)
295295

296296

297-
class UnionReturnGeneric(CallableModelGenericType[NullContext, AResult]):
297+
class UnionReturnGeneric(CallableModelGenericType[NullContext, Union[AResult, BResult]]):
298+
@property
299+
def result_type(self) -> Type[ResultType]:
300+
return AResult
301+
298302
@Flow.call
299303
def __call__(self, context: NullContext) -> Union[AResult, BResult]:
300304
# Return one branch of the Union
@@ -644,7 +648,7 @@ def test_types_generic(self):
644648
error = "Return type annotation <class 'ccflow.tests.test_callable.MyResult'> on __call__ does not match result_type <class 'ccflow.result.generic.GenericResult'> defined by CallableModelGenericType"
645649
self.assertRaisesRegex(TypeError, error, BadModelGenericMismatchedResultAndCall)
646650

647-
error = "Return type annotation for __call__ cannot be union on a CallableModelGenericType with union `result_type`"
651+
error = "Model __call__ signature result type cannot be a Union type without a concrete property. Please define a property 'result_type' on the model."
648652
self.assertRaisesRegex(TypeError, error, BadModelUnionReturnGeneric)
649653

650654
def test_union_return_generic(self):
@@ -653,6 +657,19 @@ def test_union_return_generic(self):
653657
self.assertIsInstance(result, AResult)
654658
self.assertEqual(result.a, 1)
655659

660+
def test_generic_validates_assignment(self):
661+
class MyCallable(CallableModelGenericType[NullContext, GenericResult[int]]):
662+
x: int = 1
663+
664+
@Flow.call
665+
def __call__(self, context: NullContext) -> GenericResult[int]:
666+
self.x = 5
667+
assert self.x == 5
668+
return GenericResult[float](value=self.x)
669+
670+
m = MyCallable()
671+
self.assertEqual(m(NullContext()).value, 5)
672+
656673

657674
class TestCallableModelDeps(TestCase):
658675
def test_basic(self):

0 commit comments

Comments
 (0)