@@ -79,6 +79,50 @@ class _CallableModel(BaseModel, abc.ABC):
7979 )
8080 meta : MetaData = Field (default_factory = MetaData )
8181
82+ @classmethod
83+ def _check_context_type (cls , context_type ):
84+ type_call_arg = _cached_signature (cls .__call__ ).parameters ["context" ].annotation
85+
86+ # If optional type, extract inner type
87+ if get_origin (type_call_arg ) is Optional or (get_origin (type_call_arg ) is Union and type (None ) in get_args (type_call_arg )):
88+ type_call_arg = [t for t in get_args (type_call_arg ) if t is not type (None )][0 ]
89+
90+ if (
91+ not isinstance (type_call_arg , TypeVar )
92+ and type_call_arg is not Signature .empty
93+ and (not isclass (type_call_arg ) or not issubclass (type_call_arg , context_type ))
94+ and (not isclass (context_type ) or not issubclass (context_type , type_call_arg ))
95+ ):
96+ err_msg_type_mismatch = f"The context_type { context_type } must match the type of the context accepted by __call__ { type_call_arg } "
97+ raise ValueError (err_msg_type_mismatch )
98+
99+ @classmethod
100+ def _check_result_type (cls , result_type ):
101+ type_call_return = _cached_signature (cls .__call__ ).return_annotation
102+
103+ # If union, check all types
104+ if get_origin (type_call_return ) is Union and get_args (type_call_return ):
105+ types_call_return = [t for t in get_args (type_call_return ) if t is not type (None )]
106+ else :
107+ types_call_return = [type_call_return ]
108+
109+ all_bad = True
110+ for type_call_return in types_call_return :
111+ if (
112+ not isinstance (type_call_return , TypeVar )
113+ and type_call_return is not Signature .empty
114+ and (not isclass (type_call_return ) or not issubclass (type_call_return , result_type ))
115+ and (not isclass (result_type ) or not issubclass (result_type , type_call_return ))
116+ ):
117+ # Don't invert logic so that we match context above
118+ pass
119+ else :
120+ all_bad = False
121+
122+ if all_bad :
123+ err_msg_type_mismatch = f"The result_type { result_type } must match the return type of __call__ { type_call_return } "
124+ raise ValueError (err_msg_type_mismatch )
125+
82126 @model_validator (mode = "after" )
83127 def _check_signature (self ):
84128 sig_call = _cached_signature (self .__class__ .__call__ )
@@ -98,50 +142,12 @@ def _check_signature(self):
98142 )
99143 raise ValueError (err_msg_type_mismatch )
100144
101- # If context_type or result_type are overridden, ensure they match the signature
145+ # If context_type or result_type are overridden or
146+ # come from generic type, ensure they match the signature
102147 if hasattr (self , "context_type" ):
103- type_call_arg = _cached_signature (self .__class__ .__call__ ).parameters ["context" ].annotation
104-
105- # If optional type, extract inner type
106- if get_origin (type_call_arg ) is Optional or (get_origin (type_call_arg ) is Union and type (None ) in get_args (type_call_arg )):
107- type_call_arg = [t for t in get_args (type_call_arg ) if t is not type (None )][0 ]
108-
109- if (
110- not isinstance (type_call_arg , TypeVar )
111- and type_call_arg is not Signature .empty
112- and (not isclass (type_call_arg ) or not issubclass (type_call_arg , self .context_type ))
113- and (not isclass (self .context_type ) or not issubclass (self .context_type , type_call_arg ))
114- ):
115- err_msg_type_mismatch = (
116- f"The context_type { self .context_type } must match the type of the context accepted by __call__ { type_call_arg } "
117- )
118- raise ValueError (err_msg_type_mismatch )
119-
148+ self ._check_context_type (self .context_type )
120149 if hasattr (self , "result_type" ):
121- type_call_return = _cached_signature (self .__class__ .__call__ ).return_annotation
122-
123- # If union, check all types
124- if get_origin (type_call_return ) is Union and get_args (type_call_return ):
125- types_call_return = [t for t in get_args (type_call_return ) if t is not type (None )]
126- else :
127- types_call_return = [type_call_return ]
128-
129- all_bad = True
130- for type_call_return in types_call_return :
131- if (
132- not isinstance (type_call_return , TypeVar )
133- and type_call_return is not Signature .empty
134- and (not isclass (type_call_return ) or not issubclass (type_call_return , self .result_type ))
135- and (not isclass (self .result_type ) or not issubclass (self .result_type , type_call_return ))
136- ):
137- # Don't invert logic so that we match context above
138- pass
139- else :
140- all_bad = False
141-
142- if all_bad :
143- err_msg_type_mismatch = f"The result_type { self .result_type } must match the return type of __call__ { type_call_return } "
144- raise ValueError (err_msg_type_mismatch )
150+ self ._check_result_type (self .result_type )
145151
146152 return self
147153
@@ -548,13 +554,15 @@ def context_type(self) -> Type[ContextType]:
548554 """
549555 typ = _cached_signature (self .__class__ .__call__ ).parameters ["context" ].annotation
550556 if typ is Signature .empty :
551- raise TypeError ("Must either define a type annotation for context on __call__ or implement 'context_type'" )
552-
553- self ._check_context_type (typ )
554- return typ
557+ if isinstance (self , CallableModelGenericType ) and hasattr (self , "context_generic_type" ):
558+ typ = self .context_generic_type
559+ else :
560+ raise TypeError ("Must either define a type annotation for context on __call__ or implement 'context_type'" )
561+ elif isinstance (self , CallableModelGenericType ) and hasattr (self , "context_generic_type" ) and not issubclass (typ , self .context_generic_type ):
562+ raise TypeError (
563+ f"Context type annotation { typ } on __call__ does not match context_type { self .context_generic_type } defined by CallableModelGenericType"
564+ )
555565
556- @staticmethod
557- def _check_context_type (typ ):
558566 # If optional type, extract inner type
559567 if get_origin (typ ) is Optional or (get_origin (typ ) is Union and type (None ) in get_args (typ )):
560568 type_to_check = [t for t in get_args (typ ) if t is not type (None )][0 ]
@@ -565,6 +573,8 @@ def _check_context_type(typ):
565573 if not isclass (type_to_check ) or not issubclass (type_to_check , ContextBase ):
566574 raise TypeError (f"Context type declared in signature of __call__ must be a subclass of ContextBase. Received { type_to_check } ." )
567575
576+ return typ
577+
568578 @property
569579 def result_type (self ) -> Type [ResultType ]:
570580 """Return the result type for the model.
@@ -574,13 +584,29 @@ def result_type(self) -> Type[ResultType]:
574584 """
575585 typ = _cached_signature (self .__class__ .__call__ ).return_annotation
576586 if typ is Signature .empty :
577- raise TypeError ("Must either define a return type annotation on __call__ or implement 'result_type'" )
578-
579- self ._check_result_type (typ )
580- return typ
587+ if isinstance (self , CallableModelGenericType ) and hasattr (self , "result_generic_type" ):
588+ typ = self .result_generic_type
589+ else :
590+ raise TypeError ("Must either define a return type annotation on __call__ or implement 'result_type'" )
591+ elif isinstance (self , CallableModelGenericType ) and hasattr (self , "result_generic_type" ):
592+ if get_origin (typ ) is Union and get_origin (self .result_generic_type ) is Union :
593+ if set (get_args (typ )) != set (get_args (self .result_generic_type )):
594+ raise TypeError (
595+ f"Return type annotation { typ } on __call__ does not match result_type { self .result_generic_type } defined by CallableModelGenericType"
596+ )
597+ elif get_origin (typ ) is Union :
598+ raise NotImplementedError (
599+ "Return type annotation on __call__ is a Union, but result_type defined by CallableModelGenericType is not a Union. This case is not yet supported."
600+ )
601+ elif get_origin (self .result_generic_type ) is Union :
602+ raise NotImplementedError (
603+ "Return type annotation on __call__ is not a Union, but result_type defined by CallableModelGenericType is a Union. This case is not yet supported."
604+ )
605+ elif not issubclass (typ , self .result_generic_type ):
606+ raise TypeError (
607+ f"Return type annotation { typ } on __call__ does not match result_type { self .result_generic_type } defined by CallableModelGenericType"
608+ )
581609
582- @staticmethod
583- def _check_result_type (typ ):
584610 # If union type, extract inner type
585611 if get_origin (typ ) is Union :
586612 raise TypeError (
@@ -590,6 +616,7 @@ def _check_result_type(typ):
590616 # Ensure subclass of ResultBase
591617 if not isclass (typ ) or not issubclass (typ , ResultBase ):
592618 raise TypeError (f"Return type declared in signature of __call__ must be a subclass of ResultBase (i.e. GenericResult). Received { typ } ." )
619+ return typ
593620
594621 @Flow .deps
595622 def __deps__ (
@@ -625,51 +652,56 @@ def result_type(self) -> Type[ResultType]:
625652 return self .model .result_type
626653
627654
628- class CallableModelGenericType (CallableModel , Generic [ContextType , ResultType ]):
655+ class CallableModelGeneric (CallableModel , Generic [ContextType , ResultType ]):
629656 """Special type of callable model that provides context and result via
630657 a generic type instead of annotations on __call__.
631658 """
632659
633- _context_type : ClassVar [Type [ContextType ]]
634- _result_type : ClassVar [Type [ResultType ]]
660+ _context_generic_type : ClassVar [Type [ContextType ]]
661+ _result_generic_type : ClassVar [Type [ResultType ]]
635662
636663 @property
637- def context_type (self ) -> Type [ContextType ]:
638- return self ._context_type
664+ def context_generic_type (self ) -> Type [ContextType ]:
665+ return self ._context_generic_type
639666
640667 @property
641- def result_type (self ) -> Type [ResultType ]:
642- return self ._result_type
668+ def result_generic_type (self ) -> Type [ResultType ]:
669+ return self ._result_generic_type
643670
644671 def __setstate__ (self , state ):
645- super ().__setstate__ (state )
646672 self ._determine_context_result ()
673+ super ().__setstate__ (state )
674+
675+ @classmethod
676+ def __pydantic_init_subclass__ (cls , ** kwargs ):
677+ super ().__pydantic_init_subclass__ (** kwargs )
678+ cls ._determine_context_result ()
647679
648680 @classmethod
649681 def _determine_context_result (cls ):
650682 # Extract the generic types from the class definition
651- if not hasattr (cls , "_context_type " ) or not hasattr (cls , "_result_type " ):
683+ if not hasattr (cls , "_context_generic_type " ) or not hasattr (cls , "_result_generic_type " ):
652684 new_context_type = None
653685 new_result_type = None
654686
655687 for base in cls .__mro__ :
656- if issubclass (base , CallableModelGenericType ):
688+ if issubclass (base , CallableModelGeneric ):
657689 # Found the generic base class, it should
658690 # have either generic parameters or context/result
659- if new_context_type is None and hasattr (base , "_context_type " ) and issubclass (base ._context_type , ContextBase ):
660- new_context_type = base ._context_type
691+ if new_context_type is None and hasattr (base , "_context_generic_type " ) and issubclass (base ._context_generic_type , ContextBase ):
692+ new_context_type = base ._context_generic_type
661693 if (
662694 new_result_type is None
663- and hasattr (base , "_result_type " )
695+ and hasattr (base , "_result_generic_type " )
664696 and (
665- issubclass (base ._result_type , ResultBase )
697+ issubclass (base ._result_generic_type , ResultBase )
666698 or (
667- get_origin (base ._result_type ) is Union
668- and all (isclass (t ) and issubclass (t , ResultBase ) for t in get_args (base ._result_type ))
699+ get_origin (base ._result_generic_type ) is Union
700+ and all (isclass (t ) and issubclass (t , ResultBase ) for t in get_args (base ._result_generic_type ))
669701 )
670702 )
671703 ):
672- new_result_type = base ._result_type
704+ new_result_type = base ._result_generic_type
673705 if base .__pydantic_generic_metadata__ ["args" ]:
674706 if len (base .__pydantic_generic_metadata__ ["args" ]) >= 2 :
675707 # Assume order is ContextType, ResultType
@@ -696,56 +728,12 @@ def _determine_context_result(cls):
696728 break
697729
698730 if new_context_type is not None :
699- # Validate that the model's context_type match
700- annotation_context_type = _cached_signature (cls .__call__ ).parameters ["context" ].annotation
701- if get_origin (annotation_context_type ) is Optional or (
702- get_origin (annotation_context_type ) is Union and type (None ) in get_args (annotation_context_type )
703- ):
704- annotation_context_type = [t for t in get_args (annotation_context_type ) if t is not type (None )][0 ]
705- if (
706- annotation_context_type is not Signature .empty
707- and not isinstance (annotation_context_type , TypeVar )
708- and not issubclass (annotation_context_type , new_context_type )
709- ):
710- raise TypeError (
711- f"Context type annotation { annotation_context_type } on __call__ does not match context_type { new_context_type } defined by CallableModelGenericType"
712- )
713- elif isclass (annotation_context_type ) and issubclass (annotation_context_type , new_context_type ):
714- new_context_type = annotation_context_type
715-
716731 # Set on class
717- cls ._context_type = new_context_type
732+ cls ._context_generic_type = new_context_type
718733
719734 if new_result_type is not None :
720- # Validate that the model's result_type match
721- annotation_result_type = _cached_signature (cls .__call__ ).return_annotation
722- if annotation_result_type is Signature .empty :
723- ...
724- elif isinstance (annotation_result_type , TypeVar ):
725- ...
726- elif get_origin (annotation_result_type ) is Union and get_origin (new_result_type ) is Union :
727- raise TypeError (
728- f"Return type annotation for __call__ cannot be union on a CallableModelGenericType with union `result_type`. Received { annotation_result_type } "
729- )
730- elif get_origin (annotation_result_type ) is Union :
731- if not any (issubclass (new_result_type , union_type ) for union_type in get_args (annotation_result_type )):
732- raise TypeError (
733- f"Return type annotation { annotation_result_type } on __call__ does not match result_type { new_result_type } defined by CallableModelGenericType"
734- )
735- elif get_origin (new_result_type ) is Union :
736- if not any (issubclass (annotation_result_type , union_type ) for union_type in get_args (new_result_type )):
737- raise TypeError (
738- f"Return type annotation { annotation_result_type } on __call__ does not match result_type { new_result_type } defined by CallableModelGenericType"
739- )
740- elif not issubclass (annotation_result_type , new_result_type ):
741- raise TypeError (
742- f"Return type annotation { annotation_result_type } on __call__ does not match result_type { new_result_type } defined by CallableModelGenericType"
743- )
744- elif isclass (annotation_result_type ) and issubclass (annotation_result_type , new_result_type ):
745- new_result_type = annotation_result_type
746-
747735 # Set on class
748- cls ._result_type = new_result_type
736+ cls ._result_generic_type = new_result_type
749737
750738 @model_validator (mode = "wrap" )
751739 def _validate_callable_model_generic_type (cls , m , handler , info ):
@@ -756,7 +744,8 @@ def _validate_callable_model_generic_type(cls, m, handler, info):
756744
757745 if isinstance (m , dict ):
758746 m = handler (m )
759- cls ._determine_context_result ()
747+ elif isinstance (m , cls ):
748+ m = handler (m )
760749
761750 # Raise ValueError (not TypeError) as per https://docs.pydantic.dev/latest/errors/errors/
762751 if not isinstance (m , CallableModel ):
@@ -768,3 +757,6 @@ def _validate_callable_model_generic_type(cls, m, handler, info):
768757 TypeAdapter (Type [subtypes [1 ]]).validate_python (m .result_type )
769758
770759 return m
760+
761+
762+ CallableModelGenericType = CallableModelGeneric
0 commit comments