@@ -586,27 +586,15 @@ def _normalize_arch(
586586
587587 return architecture
588588
589- def _normalize_archs (
590- self ,
591- architectures : list [str ],
592- model_config : ModelConfig ,
593- ) -> list [str ]:
594- if not architectures :
595- logger .warning ("No model architectures are specified" )
596-
597- return [
598- self ._normalize_arch (arch , model_config ) for arch in architectures
599- ]
600-
601589 def inspect_model_cls (
602590 self ,
603591 architectures : Union [str , list [str ]],
604592 model_config : ModelConfig ,
605593 ) -> tuple [_ModelInfo , str ]:
606594 if isinstance (architectures , str ):
607595 architectures = [architectures ]
608-
609- normalized_archs = self . _normalize_archs ( architectures , model_config )
596+ if not architectures :
597+ raise ValueError ( "No model architectures are specified" )
610598
611599 # Require transformers impl
612600 if model_config .model_impl == ModelImpl .TRANSFORMERS :
@@ -617,13 +605,26 @@ def inspect_model_cls(
617605 if model_info is not None :
618606 return (model_info , arch )
619607
620- for arch , normalized_arch in zip (architectures , normalized_archs ):
608+ # Fallback to transformers impl (after resolving convert_type)
609+ if (all (arch not in self .models for arch in architectures )
610+ and model_config .model_impl == ModelImpl .AUTO
611+ and getattr (model_config , "convert_type" , "none" ) == "none" ):
612+ arch = self ._try_resolve_transformers (architectures [0 ],
613+ model_config )
614+ if arch is not None :
615+ model_info = self ._try_inspect_model_cls (arch )
616+ if model_info is not None :
617+ return (model_info , arch )
618+
619+ for arch in architectures :
620+ normalized_arch = self ._normalize_arch (arch , model_config )
621621 model_info = self ._try_inspect_model_cls (normalized_arch )
622622 if model_info is not None :
623623 return (model_info , arch )
624624
625- # Fallback to transformers impl
626- if model_config .model_impl in (ModelImpl .AUTO , ModelImpl .TRANSFORMERS ):
625+ # Fallback to transformers impl (before resolving runner_type)
626+ if (all (arch not in self .models for arch in architectures )
627+ and model_config .model_impl == ModelImpl .AUTO ):
627628 arch = self ._try_resolve_transformers (architectures [0 ],
628629 model_config )
629630 if arch is not None :
@@ -640,8 +641,8 @@ def resolve_model_cls(
640641 ) -> tuple [type [nn .Module ], str ]:
641642 if isinstance (architectures , str ):
642643 architectures = [architectures ]
643-
644- normalized_archs = self . _normalize_archs ( architectures , model_config )
644+ if not architectures :
645+ raise ValueError ( "No model architectures are specified" )
645646
646647 # Require transformers impl
647648 if model_config .model_impl == ModelImpl .TRANSFORMERS :
@@ -652,13 +653,26 @@ def resolve_model_cls(
652653 if model_cls is not None :
653654 return (model_cls , arch )
654655
655- for arch , normalized_arch in zip (architectures , normalized_archs ):
656+ # Fallback to transformers impl (after resolving convert_type)
657+ if (all (arch not in self .models for arch in architectures )
658+ and model_config .model_impl == ModelImpl .AUTO
659+ and getattr (model_config , "convert_type" , "none" ) == "none" ):
660+ arch = self ._try_resolve_transformers (architectures [0 ],
661+ model_config )
662+ if arch is not None :
663+ model_cls = self ._try_load_model_cls (arch )
664+ if model_cls is not None :
665+ return (model_cls , arch )
666+
667+ for arch in architectures :
668+ normalized_arch = self ._normalize_arch (arch , model_config )
656669 model_cls = self ._try_load_model_cls (normalized_arch )
657670 if model_cls is not None :
658671 return (model_cls , arch )
659672
660- # Fallback to transformers impl
661- if model_config .model_impl in (ModelImpl .AUTO , ModelImpl .TRANSFORMERS ):
673+ # Fallback to transformers impl (before resolving runner_type)
674+ if (all (arch not in self .models for arch in architectures )
675+ and model_config .model_impl == ModelImpl .AUTO ):
662676 arch = self ._try_resolve_transformers (architectures [0 ],
663677 model_config )
664678 if arch is not None :
0 commit comments