@@ -744,12 +744,12 @@ def _get_regional_property(
744744
745745
746746class  JumpStartBenchmarkStat (JumpStartDataHolderType ):
747-     """Data class JumpStart benchmark stats .""" 
747+     """Data class JumpStart benchmark stat .""" 
748748
749749    __slots__  =  ["name" , "value" , "unit" ]
750750
751751    def  __init__ (self , spec : Dict [str , Any ]):
752-         """Initializes a JumpStartBenchmarkStat object 
752+         """Initializes a JumpStartBenchmarkStat object.  
753753
754754        Args: 
755755            spec (Dict[str, Any]): Dictionary representation of benchmark stat. 
@@ -858,7 +858,7 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
858858        "model_subscription_link" ,
859859    ]
860860
861-     def  __init__ (self , fields : Optional [ Dict [str , Any ] ]):
861+     def  __init__ (self , fields : Dict [str , Any ]):
862862        """Initializes a JumpStartMetadataFields object. 
863863
864864        Args: 
@@ -877,7 +877,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
877877        self .version : str  =  json_obj .get ("version" )
878878        self .min_sdk_version : str  =  json_obj .get ("min_sdk_version" )
879879        self .incremental_training_supported : bool  =  bool (
880-             json_obj .get ("incremental_training_supported" )
880+             json_obj .get ("incremental_training_supported" ,  False )
881881        )
882882        self .hosting_ecr_specs : Optional [JumpStartECRSpecs ] =  (
883883            JumpStartECRSpecs (json_obj ["hosting_ecr_specs" ])
@@ -1038,7 +1038,7 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields):
10381038
10391039    __slots__  =  slots  +  JumpStartMetadataBaseFields .__slots__ 
10401040
1041-     def  __init__ (   # pylint: disable=super-init-not-called 
1041+     def  __init__ (
10421042        self ,
10431043        component_name : str ,
10441044        component : Optional [Dict [str , Any ]],
@@ -1049,7 +1049,10 @@ def __init__(  # pylint: disable=super-init-not-called
10491049            component_name (str): Name of the component. 
10501050            component (Dict[str, Any]): 
10511051                Dictionary representation of the config component. 
1052+         Raises: 
1053+             ValueError: If the component field is invalid. 
10521054        """ 
1055+         super ().__init__ (component )
10531056        self .component_name  =  component_name 
10541057        self .from_json (component )
10551058
@@ -1080,7 +1083,7 @@ def __init__(
10801083        self ,
10811084        base_fields : Dict [str , Any ],
10821085        config_components : Dict [str , JumpStartConfigComponent ],
1083-         benchmark_metrics : Dict [str , JumpStartBenchmarkStat ],
1086+         benchmark_metrics : Dict [str , List [ JumpStartBenchmarkStat ] ],
10841087    ):
10851088        """Initializes a JumpStartMetadataConfig object from its json representation. 
10861089
@@ -1089,12 +1092,12 @@ def __init__(
10891092                The default base fields that are used to construct the final resolved config. 
10901093            config_components (Dict[str, JumpStartConfigComponent]): 
10911094                The list of components that are used to construct the resolved config. 
1092-             benchmark_metrics (Dict[str, JumpStartBenchmarkStat]): 
1095+             benchmark_metrics (Dict[str, List[ JumpStartBenchmarkStat] ]): 
10931096                The dictionary of benchmark metrics with name being the key. 
10941097        """ 
10951098        self .base_fields  =  base_fields 
10961099        self .config_components : Dict [str , JumpStartConfigComponent ] =  config_components 
1097-         self .benchmark_metrics : Dict [str , JumpStartBenchmarkStat ] =  benchmark_metrics 
1100+         self .benchmark_metrics : Dict [str , List [ JumpStartBenchmarkStat ] ] =  benchmark_metrics 
10981101        self .resolved_metadata_config : Optional [Dict [str , Any ]] =  None 
10991102
11001103    def  to_json (self ) ->  Dict [str , Any ]:
@@ -1104,7 +1107,7 @@ def to_json(self) -> Dict[str, Any]:
11041107
11051108    @property  
11061109    def  resolved_config (self ) ->  Dict [str , Any ]:
1107-         """Returns the final config that is resolved from the list of  components. 
1110+         """Returns the final config that is resolved from the components map . 
11081111
11091112        Construct the final config by applying the list of configs from list index, 
11101113        and apply to the base default fields in the current model specs. 
@@ -1139,7 +1142,7 @@ def __init__(
11391142
11401143        Args: 
11411144            configs (Dict[str, JumpStartMetadataConfig]): 
1142-                 List  of configs that the current model has . 
1145+                 The map  of JumpStartMetadataConfig object, with config name being the key . 
11431146            config_rankings (JumpStartConfigRanking): 
11441147                Config ranking class represents the ranking of the configs in the model. 
11451148            scope (JumpStartScriptScope): 
@@ -1158,19 +1161,30 @@ def get_top_config_from_ranking(
11581161        self ,
11591162        ranking_name : str  =  JumpStartConfigRankingName .DEFAULT ,
11601163        instance_type : Optional [str ] =  None ,
1161-     ) ->  JumpStartMetadataConfig :
1162-         """Gets the best the config based on config ranking.""" 
1164+     ) ->  Optional [JumpStartMetadataConfig ]:
1165+         """Gets the best the config based on config ranking. 
1166+ 
1167+         Args: 
1168+             ranking_name (str): 
1169+                 The ranking name that config priority is based on. 
1170+             instance_type (Optional[str]): 
1171+                 The instance type which the config selection is based on. 
1172+ 
1173+         Raises: 
1174+             ValueError: If the config exists but missing config ranking. 
1175+             NotImplementedError: If the scope is unrecognized. 
1176+         """ 
11631177        if  self .configs  and  (
11641178            not  self .config_rankings  or  not  self .config_rankings .get (ranking_name )
11651179        ):
1166-             raise  ValueError ("Config exists but missing config ranking." )
1180+             raise  ValueError (f "Config exists but missing config ranking  { ranking_name } 
11671181
11681182        if  self .scope  ==  JumpStartScriptScope .INFERENCE :
11691183            instance_type_attribute  =  "supported_inference_instance_types" 
11701184        elif  self .scope  ==  JumpStartScriptScope .TRAINING :
11711185            instance_type_attribute  =  "supported_training_instance_types" 
11721186        else :
1173-             raise  ValueError (f"Unknown script scope { self .scope }  )
1187+             raise  NotImplementedError (f"Unknown script scope { self .scope }  )
11741188
11751189        rankings  =  self .config_rankings .get (ranking_name )
11761190        for  config_name  in  rankings .rankings :
@@ -1198,12 +1212,13 @@ class JumpStartModelSpecs(JumpStartMetadataBaseFields):
11981212
11991213    __slots__  =  JumpStartMetadataBaseFields .__slots__  +  slots 
12001214
1201-     def  __init__ (self , spec : Dict [str , Any ]):   # pylint: disable=super-init-not-called 
1215+     def  __init__ (self , spec : Dict [str , Any ]):
12021216        """Initializes a JumpStartModelSpecs object from its json representation. 
12031217
12041218        Args: 
12051219            spec (Dict[str, Any]): Dictionary representation of spec. 
12061220        """ 
1221+         super ().__init__ (spec )
12071222        self .from_json (spec )
12081223        if  self .inference_configs  and  self .inference_configs .get_top_config_from_ranking ():
12091224            super ().from_json (self .inference_configs .get_top_config_from_ranking ().resolved_config )
@@ -1245,8 +1260,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12451260                    ),
12461261                    (
12471262                        {
1248-                             stat_name : JumpStartBenchmarkStat (stat )
1249-                             for  stat_name , stat  in  config .get ("benchmark_metrics" ).items ()
1263+                             stat_name : [ JumpStartBenchmarkStat (stat )  for   stat   in   stats ] 
1264+                             for  stat_name , stats  in  config .get ("benchmark_metrics" ).items ()
12501265                        }
12511266                        if  config  and  config .get ("benchmark_metrics" )
12521267                        else  None 
@@ -1297,8 +1312,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
12971312                        ),
12981313                        (
12991314                            {
1300-                                 stat_name : JumpStartBenchmarkStat (stat )
1301-                                 for  stat_name , stat  in  config .get ("benchmark_metrics" ).items ()
1315+                                 stat_name : [ JumpStartBenchmarkStat (stat )  for   stat   in   stats ] 
1316+                                 for  stat_name , stats  in  config .get ("benchmark_metrics" ).items ()
13021317                            }
13031318                            if  config  and  config .get ("benchmark_metrics" )
13041319                            else  None 
@@ -1330,13 +1345,26 @@ def set_config(
13301345            config_name (str): Name of the config. 
13311346            scope (JumpStartScriptScope, optional): 
13321347                Scope of the config. Defaults to JumpStartScriptScope.INFERENCE. 
1348+ 
1349+         Raises: 
1350+             ValueError: If the scope is not supported, or cannot find config name. 
13331351        """ 
13341352        if  scope  ==  JumpStartScriptScope .INFERENCE :
1335-             super (). from_json ( self .inference_configs . configs [ config_name ]. resolved_config ) 
1353+             metadata_configs   =   self .inference_configs 
13361354        elif  scope  ==  JumpStartScriptScope .TRAINING  and  self .training_supported :
1337-             super (). from_json ( self .training_configs . configs [ config_name ]. resolved_config ) 
1355+             metadata_configs   =   self .training_configs 
13381356        else :
1339-             raise  ValueError (f"Unknown Jumpstart Script scope { scope }  )
1357+             raise  ValueError (f"Unknown Jumpstart script scope { scope }  )
1358+ 
1359+         config_object  =  metadata_configs .configs .get (config_name )
1360+         if  not  config_object :
1361+             error_msg  =  f"Cannot find Jumpstart config name { config_name }  
1362+             config_names  =  list (metadata_configs .configs .keys ())
1363+             if  config_names :
1364+                 error_msg  +=  f"List of config names that is supported by the model: { config_names }  
1365+             raise  ValueError (error_msg )
1366+ 
1367+         super ().from_json (config_object .resolved_config )
13401368
13411369    def  supports_prepacked_inference (self ) ->  bool :
13421370        """Returns True if the model has a prepacked inference artifact.""" 
0 commit comments