@@ -222,6 +222,7 @@ class InferenceEndpointModelConfig:
     should_reuse_existing: bool = False
     add_special_tokens: bool = True
     revision: str = "main"
+    namespace: str = None  # The namespace under which to launch the endpoint. Defaults to the current user's namespace
 
     def get_dtype_args(self) -> Dict[str, str]:
         model_dtype = self.model_dtype.lower()
@@ -235,6 +236,15 @@ def get_dtype_args(self) -> Dict[str, str]:
             return {"DTYPE": model_dtype}
         return {}
 
+    @staticmethod
+    def nullable_keys() -> list[str]:
+        """
+        Returns the list of optional keys in an endpoint model configuration. By default, the code requires that all the
+        keys be specified in the configuration in order to launch the endpoint. This function returns the list of keys
+        that are not required and can remain None.
+        """
+        return ["namespace"]
+
 
 def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]) -> BaseModelConfig:  # noqa: C901
     """
@@ -259,76 +269,84 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]
 
         return BaseModelConfig(**args_dict)
 
-    with open(args.model_config_path, "r") as f:
-        config = yaml.safe_load(f)["model"]
+    if args.model_config:
+        config = args.model_config["model"]
+    else:
+        with open(args.model_config_path, "r") as f:
+            config = yaml.safe_load(f)["model"]
+
+    if config["type"] == "tgi":
+        return TGIModelConfig(
+            inference_server_address=args["instance"]["inference_server_address"],
+            inference_server_auth=args["instance"]["inference_server_auth"],
+        )
 
-    if config["type"] == "tgi":
-        return TGIModelConfig(
-            inference_server_address=args["instance"]["inference_server_address"],
-            inference_server_auth=args["instance"]["inference_server_auth"],
+    if config["type"] == "endpoint":
+        reuse_existing_endpoint = config["base_params"]["reuse_existing"]
+        complete_config_endpoint = all(
+            val not in [None, ""]
+            for key, val in config["instance"].items()
+            if key not in InferenceEndpointModelConfig.nullable_keys()
+        )
+        if reuse_existing_endpoint or complete_config_endpoint:
+            return InferenceEndpointModelConfig(
+                name=config["base_params"]["endpoint_name"].replace(".", "-").lower(),
+                repository=config["base_params"]["model"],
+                model_dtype=config["base_params"]["dtype"],
+                revision=config["base_params"]["revision"] or "main",
+                should_reuse_existing=reuse_existing_endpoint,
+                accelerator=config["instance"]["accelerator"],
+                region=config["instance"]["region"],
+                vendor=config["instance"]["vendor"],
+                instance_size=config["instance"]["instance_size"],
+                instance_type=config["instance"]["instance_type"],
+                namespace=config["instance"]["namespace"],
+            )
+        return InferenceModelConfig(model=config["base_params"]["endpoint_name"])
+
+    if config["type"] == "base":
+        # Tests on the multichoice space parameters
+        multichoice_continuations_start_space = config["generation"]["multichoice_continuations_start_space"]
+        no_multichoice_continuations_start_space = config["generation"]["no_multichoice_continuations_start_space"]
+        if not multichoice_continuations_start_space and not no_multichoice_continuations_start_space:
+            multichoice_continuations_start_space = None
+        if multichoice_continuations_start_space and no_multichoice_continuations_start_space:
+            raise ValueError(
+                "You cannot force both the multichoice continuations to start with a space and not to start with a space"
         )
 
-    if config["type"] == "endpoint":
-        reuse_existing_endpoint = config["base_params"]["reuse_existing"]
-        complete_config_endpoint = all(val not in [None, ""] for val in config["instance"].values())
-        if reuse_existing_endpoint or complete_config_endpoint:
-            return InferenceEndpointModelConfig(
-                name=config["base_params"]["endpoint_name"].replace(".", "-").lower(),
-                repository=config["base_params"]["model"],
-                model_dtype=config["base_params"]["dtype"],
-                revision=config["base_params"]["revision"] or "main",
-                should_reuse_existing=reuse_existing_endpoint,
-                accelerator=config["instance"]["accelerator"],
-                region=config["instance"]["region"],
-                vendor=config["instance"]["vendor"],
-                instance_size=config["instance"]["instance_size"],
-                instance_type=config["instance"]["instance_type"],
-            )
-        return InferenceModelConfig(model=config["base_params"]["endpoint_name"])
-
-    if config["type"] == "base":
-        # Tests on the multichoice space parameters
-        multichoice_continuations_start_space = config["generation"]["multichoice_continuations_start_space"]
-        no_multichoice_continuations_start_space = config["generation"]["no_multichoice_continuations_start_space"]
-        if not multichoice_continuations_start_space and not no_multichoice_continuations_start_space:
-            multichoice_continuations_start_space = None
-        if multichoice_continuations_start_space and no_multichoice_continuations_start_space:
-            raise ValueError(
-                "You cannot force both the multichoice continuations to start with a space and not to start with a space"
-            )
-
-        # Creating optional quantization configuration
-        if config["base_params"]["dtype"] == "4bit":
-            quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
-        elif config["base_params"]["dtype"] == "8bit":
-            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-        else:
-            quantization_config = None
-
-        # We extract the model args
-        args_dict = {k.split("=")[0]: k.split("=")[1] for k in config["base_params"]["model_args"].split(",")}
-
-        # We store the relevant other args
-        args_dict["base_model"] = config["merged_weights"]["base_model"]
-        args_dict["dtype"] = config["base_params"]["dtype"]
-        args_dict["accelerator"] = accelerator
-        args_dict["quantization_config"] = quantization_config
-        args_dict["batch_size"] = args.override_batch_size
-        args_dict["multichoice_continuations_start_space"] = multichoice_continuations_start_space
-
-        # Keeping only non null params
-        args_dict = {k: v for k, v in args_dict.items() if v is not None}
-
-        if config["merged_weights"]["delta_weights"]:
-            if config["merged_weights"]["base_model"] is None:
-                raise ValueError("You need to specify a base model when using delta weights")
-            return DeltaModelConfig(**args_dict)
-        if config["merged_weights"]["adapter_weights"]:
-            if config["merged_weights"]["base_model"] is None:
-                raise ValueError("You need to specify a base model when using adapter weights")
-            return AdapterModelConfig(**args_dict)
-        if config["merged_weights"]["base_model"] not in ["", None]:
-            raise ValueError("You can't specify a base model if you are not using delta/adapter weights")
-        return BaseModelConfig(**args_dict)
-
-    raise ValueError(f"Unknown model type in your model config file: {config['type']}")
+        # Creating optional quantization configuration
+        if config["base_params"]["dtype"] == "4bit":
+            quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
+        elif config["base_params"]["dtype"] == "8bit":
+            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        else:
+            quantization_config = None
+
+        # We extract the model args
+        args_dict = {k.split("=")[0]: k.split("=")[1] for k in config["base_params"]["model_args"].split(",")}
+
+        # We store the relevant other args
+        args_dict["base_model"] = config["merged_weights"]["base_model"]
+        args_dict["dtype"] = config["base_params"]["dtype"]
+        args_dict["accelerator"] = accelerator
+        args_dict["quantization_config"] = quantization_config
+        args_dict["batch_size"] = args.override_batch_size
+        args_dict["multichoice_continuations_start_space"] = multichoice_continuations_start_space
+
+        # Keeping only non null params
+        args_dict = {k: v for k, v in args_dict.items() if v is not None}
+
+        if config["merged_weights"]["delta_weights"]:
+            if config["merged_weights"]["base_model"] is None:
+                raise ValueError("You need to specify a base model when using delta weights")
+            return DeltaModelConfig(**args_dict)
+        if config["merged_weights"]["adapter_weights"]:
+            if config["merged_weights"]["base_model"] is None:
+                raise ValueError("You need to specify a base model when using adapter weights")
+            return AdapterModelConfig(**args_dict)
+        if config["merged_weights"]["base_model"] not in ["", None]:
+            raise ValueError("You can't specify a base model if you are not using delta/adapter weights")
+        return BaseModelConfig(**args_dict)
+
+    raise ValueError(f"Unknown model type in your model config file: {config['type']}")