4949 is_subclass ,
5050 iter_to_set_str ,
5151 lenient_check ,
52+ lenient_check_context ,
5253 NoneType ,
5354 object_path_serializer ,
5455 ParserError ,
@@ -240,6 +241,15 @@ def parse_subclass_arg(arg_string):
240241 subclass_arg_parser .set (subparser )
241242
242243
244+ @staticmethod
245+ def discard_init_args_on_class_path_change (parser , cfg_to , cfg_from ):
246+ for action in [a for a in parser ._actions if ActionTypeHint .is_subclass_typehint (a )]:
247+ val_to = cfg_to .get (action .dest )
248+ val_from = cfg_from .get (action .dest )
249+ if is_subclass_spec (val_to ) and is_subclass_spec (val_from ):
250+ discard_init_args_on_class_path_change (action , val_to , val_from )
251+
252+
243253 @staticmethod
244254 @contextmanager
245255 def subclass_arg_context (parser ):
@@ -585,7 +595,7 @@ def adapt_typehints(val, typehint, serialize=False, instantiate_classes=False, p
585595 # Callable
586596 elif typehint_origin in callable_origin_types or typehint in callable_origin_types :
587597 if serialize :
588- if is_class_object (val ):
598+ if is_subclass_spec (val ):
589599 val = adapt_class_type (val , True , False , sub_add_kwargs )
590600 else :
591601 val = object_path_serializer (val )
@@ -600,12 +610,11 @@ def adapt_typehints(val, typehint, serialize=False, instantiate_classes=False, p
600610 else :
601611 raise ImportError (f'Unexpected import object { val_obj } ' )
602612 if isinstance (val , (dict , Namespace )):
603- if not is_class_object (val ):
613+ if not is_subclass_spec (val ):
604614 raise ImportError (f'Dict must include a class_path and optionally init_args, but got { val } ' )
605- val = Namespace (val )
606- val_class = import_object (val .class_path )
615+ val_class = import_object (val ['class_path' ])
607616 if not (inspect .isclass (val_class ) and callable_instances (val_class )):
608- raise ImportError (f'" { val . class_path } " is not a callable class.' )
617+ raise ImportError (f'{ val [ " class_path" ]!r } is not a callable class.' )
609618 val ['class_path' ] = get_import_path (val_class )
610619 val = adapt_class_type (val , False , instantiate_classes , sub_add_kwargs )
611620 except (ImportError , AttributeError , ParserError ) as ex :
@@ -619,25 +628,25 @@ def adapt_typehints(val, typehint, serialize=False, instantiate_classes=False, p
619628 return val
620629 if serialize and isinstance (val , str ):
621630 return val
622- if not (isinstance (val , str ) or is_class_object (val )):
623- raise ValueError (f'Type { typehint } expects an str or a Dict/Namespace with a class_path entry but got "{ val } "' )
631+
632+ val_input = val
633+ val = subclass_spec_as_namespace (val )
634+ if not is_subclass_spec (val ) and prev_val and 'class_path' in prev_val :
635+ if 'init_args' in val :
636+ val ['class_path' ] = prev_val ['class_path' ]
637+ else :
638+ val = Namespace (class_path = prev_val ['class_path' ], init_args = val )
639+ if not is_subclass_spec (val ):
640+ raise ValueError (
641+ f'Type { typehint } expects: a class path (str); or a dict with a class_path entry; '
642+ f'or a dict with init_args (if class path given previously). Got "{ val_input } ".'
643+ )
644+
624645 try :
625- if isinstance (val , str ):
626- val = Namespace (class_path = val )
627- elif isinstance (val , dict ):
628- val = Namespace (val )
629646 val_class = import_object (resolve_class_path_by_name (typehint , val ['class_path' ]))
630647 if not is_subclass (val_class , typehint ):
631648 raise ValueError (f'"{ val ["class_path" ]} " is not a subclass of { typehint } ' )
632- val ['class_path' ] = class_path = get_import_path (val_class )
633- if isinstance (prev_val , Namespace ) and 'class_path' in prev_val and 'init_args' not in val :
634- prev_class_path = prev_val ['class_path' ]
635- prev_init_args = prev_val .get ('init_args' )
636- if prev_class_path != class_path and prev_init_args :
637- warnings .warn (
638- f'Changing class_path to { class_path } implies discarding init_args { prev_init_args .as_dict ()} '
639- f'defined for class_path { prev_class_path } .'
640- )
649+ val ['class_path' ] = get_import_path (val_class )
641650 val = adapt_class_type (val , serialize , instantiate_classes , sub_add_kwargs , prev_val = prev_val )
642651 except (ImportError , AttributeError , AssertionError , ParserError ) as ex :
643652 class_path = val if isinstance (val , str ) else val ['class_path' ]
@@ -647,14 +656,24 @@ def adapt_typehints(val, typehint, serialize=False, instantiate_classes=False, p
647656 return val
648657
649658
650- def is_class_object (val ):
659+ def is_subclass_spec (val ):
651660 is_class = isinstance (val , (dict , Namespace )) and 'class_path' in val
652661 if is_class :
653662 keys = getattr (val , '__dict__' , val ).keys ()
654663 is_class = len (set (keys )- {'class_path' , 'init_args' , 'dict_kwargs' , '__path__' }) == 0
655664 return is_class
656665
657666
667+ def subclass_spec_as_namespace (val ):
668+ if isinstance (val , str ):
669+ return Namespace (class_path = val )
670+ if isinstance (val , dict ):
671+ val = Namespace (val )
672+ if 'init_args' in val and isinstance (val ['init_args' ], dict ):
673+ val ['init_args' ] = Namespace (val ['init_args' ])
674+ return val
675+
676+
658677class NestedArg :
659678 def __init__ (self , key , val ):
660679 self .key = key
@@ -716,8 +735,34 @@ def dump_kwargs_context(kwargs):
716735 yield
717736
718737
738+ def discard_init_args_on_class_path_change (parser_or_action , prev_val , value ):
739+ if prev_val and 'init_args' in prev_val and prev_val ['class_path' ] != value .class_path :
740+ parser = parser_or_action
741+ if isinstance (parser_or_action , ActionTypeHint ):
742+ sub_add_kwargs = getattr (parser_or_action , 'sub_add_kwargs' , {})
743+ parser = ActionTypeHint .get_class_parser (value .class_path , sub_add_kwargs )
744+ prev_val = subclass_spec_as_namespace (prev_val )
745+ del_args = {}
746+ for key , val in list (prev_val .init_args .__dict__ .items ()):
747+ action = _find_action (parser , key )
748+ if action :
749+ with lenient_check_context (lenient = False ):
750+ try :
751+ parser ._check_value_key (action , val , key , Namespace ())
752+ except Exception :
753+ action = None
754+ if not action :
755+ del_args [key ] = prev_val .init_args .pop (key )
756+ if del_args :
757+ warnings .warn (
758+ f'Due to class_path change from { prev_val .class_path !r} to { value .class_path !r} , '
759+ f'discarding init_args: { del_args } .'
760+ )
761+
762+
719763def adapt_class_type (value , serialize , instantiate_classes , sub_add_kwargs , prev_val = None ):
720- val_class = import_object (value ['class_path' ])
764+ value = subclass_spec_as_namespace (value )
765+ val_class = import_object (value .class_path )
721766 parser = ActionTypeHint .get_class_parser (val_class , sub_add_kwargs )
722767
723768 # No need to re-create the linked arg but just "inform" the corresponding parser actions that it exists upstream.
@@ -736,18 +781,18 @@ def adapt_class_type(value, serialize, instantiate_classes, sub_add_kwargs, prev
736781
737782 break
738783
739- unresolved = value .pop ('dict_kwargs' , {})
784+ discard_init_args_on_class_path_change (parser , prev_val , value )
785+
786+ dict_kwargs = value .pop ('dict_kwargs' , {})
740787 init_args = value .get ('init_args' , Namespace ())
741- if isinstance (init_args , dict ):
742- value ['init_args' ] = init_args = Namespace (init_args )
743788
744789 if instantiate_classes :
745790 init_args = parser .instantiate_classes (init_args )
746791 if not sub_add_kwargs .get ('instantiate' , True ):
747792 if init_args :
748793 value ['init_args' ] = init_args
749794 return value
750- return val_class (** {** init_args , ** unresolved })
795+ return val_class (** {** init_args , ** dict_kwargs })
751796
752797 if isinstance (init_args , NestedArg ):
753798 value ['init_args' ] = parser .parse_args (
@@ -761,29 +806,29 @@ def adapt_class_type(value, serialize, instantiate_classes, sub_add_kwargs, prev
761806 if init_args :
762807 value ['init_args' ] = load_value (parser .dump (init_args , ** dump_kwargs .get ()))
763808 else :
764- if isinstance (unresolved , str ):
765- unresolved = load_value (unresolved )
766- if isinstance (unresolved , (dict , Namespace )):
767- for key in list (unresolved .keys ()):
809+ if isinstance (dict_kwargs , str ):
810+ dict_kwargs = load_value (dict_kwargs )
811+ if isinstance (dict_kwargs , (dict , Namespace )):
812+ for key in list (dict_kwargs .keys ()):
768813 if _find_action (parser , key ):
769- init_args [key ] = unresolved .pop (key )
770- if isinstance (unresolved , Namespace ):
771- unresolved = dict (unresolved )
772- elif unresolved :
773- init_args ['dict_kwargs' ] = unresolved
774- unresolved = None
814+ init_args [key ] = dict_kwargs .pop (key )
815+ if isinstance (dict_kwargs , Namespace ):
816+ dict_kwargs = dict (dict_kwargs )
817+ elif dict_kwargs :
818+ init_args ['dict_kwargs' ] = dict_kwargs
819+ dict_kwargs = None
775820 init_args = parser .parse_object (init_args , defaults = sub_defaults .get ())
776821 if init_args :
777822 value ['init_args' ] = init_args
778- if unresolved :
779- value ['dict_kwargs' ] = unresolved
823+ if dict_kwargs :
824+ value ['dict_kwargs' ] = dict_kwargs
780825 return value
781826
782827
783828def not_append_diff (val1 , val2 ):
784829 if isinstance (val1 , list ) and isinstance (val2 , list ):
785- val1 = [x .get ('class_path' ) if is_class_object (x ) else x for x in val1 ]
786- val2 = [x .get ('class_path' ) if is_class_object (x ) else x for x in val2 ]
830+ val1 = [x .get ('class_path' ) if is_subclass_spec (x ) else x for x in val1 ]
831+ val2 = [x .get ('class_path' ) if is_subclass_spec (x ) else x for x in val2 ]
787832 return val1 != val2
788833
789834
0 commit comments