
Commit 7d720c1

Added Namespace parameter for InferenceEndpoints, added option for passing model config directly (#147)
1. I added the `namespace` parameter to the endpoint model config, to allow deploying endpoints from an organization rather than your personal user namespace.
   - Added the `namespace` property to the config class.
   - Added a static method `nullable_keys()` so that the check for a complete config when creating the endpoint can ignore those optional keys.
2. I added the option to specify the model config directly from Python code, with a simple check to see whether it was given. (Note: I unindented everything after the `yaml.safe_load` call, since once the YAML was loaded there was no need to remain inside the scope of the open file.)

Co-authored-by: Clémentine Fourrier <[email protected]>
1 parent caf7112 commit 7d720c1
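The second change means a model config no longer has to live in a YAML file. A rough sketch of what this enables (not part of the commit itself; the dict values below are invented placeholders mirroring the layout of `examples/model_configs/endpoint_model.yaml`, and only the `args.model_config` attribute is what the new code actually reads):

```python
# Sketch, not from the commit: building the config in code instead of YAML.
# All field values are invented examples.
from argparse import Namespace

args = Namespace(
    model_config={
        "model": {
            "type": "endpoint",
            "base_params": {
                "endpoint_name": "my-eval-endpoint",  # invented
                "model": "my-org/my-model",           # invented
                "dtype": "float16",
                "revision": "main",
                "reuse_existing": False,
            },
            "instance": {
                "accelerator": "gpu",
                "region": "us-east-1",
                "vendor": "aws",
                "instance_size": "medium",
                "instance_type": "g5.2xlarge",
                "namespace": "my-org",  # new: deploy under an organization
            },
        }
    },
    model_config_path=None,   # ignored when model_config is set
    override_batch_size=None,
)
# create_model_config(args, accelerator) now reads args.model_config directly
# instead of opening args.model_config_path.
```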

3 files changed: 93 additions, 71 deletions


examples/model_configs/endpoint_model.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -14,5 +14,6 @@ model:
     instance_type: "g5.2xlarge"
     framework: "pytorch"
     endpoint_type: "protected"
+    namespace: null # The namespace under which to launch the endpoint. Defaults to the current user's namespace
   generation:
     add_special_tokens: true
```
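Left as `null`, the endpoint is launched under the current user's namespace. To launch it under an organization instead, the key can be set explicitly; a hypothetical snippet (org name invented):

```yaml
# Hypothetical: deploy the endpoint under the "my-org" organization.
model:
  instance:
    instance_type: "g5.2xlarge"
    framework: "pytorch"
    endpoint_type: "protected"
    namespace: "my-org"
```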

src/lighteval/models/endpoint_model.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -66,10 +66,13 @@ def __init__(
         self.reuse_existing = getattr(config, "should_reuse_existing", True)
         if isinstance(config, InferenceEndpointModelConfig):
             if config.should_reuse_existing:
-                self.endpoint = get_inference_endpoint(name=config.name, token=env_config.token)
+                self.endpoint = get_inference_endpoint(
+                    name=config.name, token=env_config.token, namespace=config.namespace
+                )
             else:
                 self.endpoint: InferenceEndpoint = create_inference_endpoint(
                     name=config.name,
+                    namespace=config.namespace,
                     repository=config.repository,
                     revision=config.revision,
                     framework=config.framework,
```
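Both calls forward `config.namespace` to huggingface_hub, which falls back to the authenticated user's namespace when it is `None`. A hedged sketch of the underlying calls (endpoint name, org, and instance values invented; token handling elided):

```python
# Sketch of the huggingface_hub calls this file wraps; values are invented.
from huggingface_hub import create_inference_endpoint, get_inference_endpoint

# Reuse an endpoint that lives under an organization rather than the user.
endpoint = get_inference_endpoint(name="my-eval-endpoint", namespace="my-org")

# Or create one under that organization (namespace=None would use the
# authenticated user's own namespace).
endpoint = create_inference_endpoint(
    "my-eval-endpoint",
    namespace="my-org",
    repository="my-org/my-model",
    framework="pytorch",
    accelerator="gpu",
    instance_size="medium",
    instance_type="g5.2xlarge",
    region="us-east-1",
    vendor="aws",
)
endpoint.wait()  # block until the endpoint is deployed
```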

src/lighteval/models/model_config.py

Lines changed: 88 additions & 70 deletions
```diff
@@ -222,6 +222,7 @@ class InferenceEndpointModelConfig:
     should_reuse_existing: bool = False
     add_special_tokens: bool = True
     revision: str = "main"
+    namespace: str = None  # The namespace under which to launch the endpoint. Defaults to the current user's namespace
 
     def get_dtype_args(self) -> Dict[str, str]:
         model_dtype = self.model_dtype.lower()
```
```diff
@@ -235,6 +236,15 @@ def get_dtype_args(self) -> Dict[str, str]:
             return {"DTYPE": model_dtype}
         return {}
 
+    @staticmethod
+    def nullable_keys() -> list[str]:
+        """
+        Returns the list of optional keys in an endpoint model configuration. By default, the code requires that all the
+        keys be specified in the configuration in order to launch the endpoint. This function returns the list of keys
+        that are not required and can remain None.
+        """
+        return ["namespace"]
+
 
 def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]) -> BaseModelConfig:  # noqa: C901
     """
```
```diff
@@ -259,76 +269,84 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]
 
         return BaseModelConfig(**args_dict)
 
-    with open(args.model_config_path, "r") as f:
-        config = yaml.safe_load(f)["model"]
+    if args.model_config:
+        config = args.model_config["model"]
+    else:
+        with open(args.model_config_path, "r") as f:
+            config = yaml.safe_load(f)["model"]
+
+    if config["type"] == "tgi":
+        return TGIModelConfig(
+            inference_server_address=args["instance"]["inference_server_address"],
+            inference_server_auth=args["instance"]["inference_server_auth"],
+        )
 
-        if config["type"] == "tgi":
-            return TGIModelConfig(
-                inference_server_address=args["instance"]["inference_server_address"],
-                inference_server_auth=args["instance"]["inference_server_auth"],
+    if config["type"] == "endpoint":
+        reuse_existing_endpoint = config["base_params"]["reuse_existing"]
+        complete_config_endpoint = all(
+            val not in [None, ""]
+            for key, val in config["instance"].items()
+            if key not in InferenceEndpointModelConfig.nullable_keys()
+        )
+        if reuse_existing_endpoint or complete_config_endpoint:
+            return InferenceEndpointModelConfig(
+                name=config["base_params"]["endpoint_name"].replace(".", "-").lower(),
+                repository=config["base_params"]["model"],
+                model_dtype=config["base_params"]["dtype"],
+                revision=config["base_params"]["revision"] or "main",
+                should_reuse_existing=reuse_existing_endpoint,
+                accelerator=config["instance"]["accelerator"],
+                region=config["instance"]["region"],
+                vendor=config["instance"]["vendor"],
+                instance_size=config["instance"]["instance_size"],
+                instance_type=config["instance"]["instance_type"],
+                namespace=config["instance"]["namespace"],
+            )
+        return InferenceModelConfig(model=config["base_params"]["endpoint_name"])
+
+    if config["type"] == "base":
+        # Tests on the multichoice space parameters
+        multichoice_continuations_start_space = config["generation"]["multichoice_continuations_start_space"]
+        no_multichoice_continuations_start_space = config["generation"]["no_multichoice_continuations_start_space"]
+        if not multichoice_continuations_start_space and not no_multichoice_continuations_start_space:
+            multichoice_continuations_start_space = None
+        if multichoice_continuations_start_space and no_multichoice_continuations_start_space:
+            raise ValueError(
+                "You cannot force both the multichoice continuations to start with a space and not to start with a space"
             )
 
-        if config["type"] == "endpoint":
-            reuse_existing_endpoint = config["base_params"]["reuse_existing"]
-            complete_config_endpoint = all(val not in [None, ""] for val in config["instance"].values())
-            if reuse_existing_endpoint or complete_config_endpoint:
-                return InferenceEndpointModelConfig(
-                    name=config["base_params"]["endpoint_name"].replace(".", "-").lower(),
-                    repository=config["base_params"]["model"],
-                    model_dtype=config["base_params"]["dtype"],
-                    revision=config["base_params"]["revision"] or "main",
-                    should_reuse_existing=reuse_existing_endpoint,
-                    accelerator=config["instance"]["accelerator"],
-                    region=config["instance"]["region"],
-                    vendor=config["instance"]["vendor"],
-                    instance_size=config["instance"]["instance_size"],
-                    instance_type=config["instance"]["instance_type"],
-                )
-            return InferenceModelConfig(model=config["base_params"]["endpoint_name"])
-
-        if config["type"] == "base":
-            # Tests on the multichoice space parameters
-            multichoice_continuations_start_space = config["generation"]["multichoice_continuations_start_space"]
-            no_multichoice_continuations_start_space = config["generation"]["no_multichoice_continuations_start_space"]
-            if not multichoice_continuations_start_space and not no_multichoice_continuations_start_space:
-                multichoice_continuations_start_space = None
-            if multichoice_continuations_start_space and no_multichoice_continuations_start_space:
-                raise ValueError(
-                    "You cannot force both the multichoice continuations to start with a space and not to start with a space"
-                )
-
-            # Creating optional quantization configuration
-            if config["base_params"]["dtype"] == "4bit":
-                quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
-            elif config["base_params"]["dtype"] == "8bit":
-                quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-            else:
-                quantization_config = None
-
-            # We extract the model args
-            args_dict = {k.split("=")[0]: k.split("=")[1] for k in config["base_params"]["model_args"].split(",")}
-
-            # We store the relevant other args
-            args_dict["base_model"] = config["merged_weights"]["base_model"]
-            args_dict["dtype"] = config["base_params"]["dtype"]
-            args_dict["accelerator"] = accelerator
-            args_dict["quantization_config"] = quantization_config
-            args_dict["batch_size"] = args.override_batch_size
-            args_dict["multichoice_continuations_start_space"] = multichoice_continuations_start_space
-
-            # Keeping only non null params
-            args_dict = {k: v for k, v in args_dict.items() if v is not None}
-
-            if config["merged_weights"]["delta_weights"]:
-                if config["merged_weights"]["base_model"] is None:
-                    raise ValueError("You need to specify a base model when using delta weights")
-                return DeltaModelConfig(**args_dict)
-            if config["merged_weights"]["adapter_weights"]:
-                if config["merged_weights"]["base_model"] is None:
-                    raise ValueError("You need to specify a base model when using adapter weights")
-                return AdapterModelConfig(**args_dict)
-            if config["merged_weights"]["base_model"] not in ["", None]:
-                raise ValueError("You can't specify a base model if you are not using delta/adapter weights")
-            return BaseModelConfig(**args_dict)
-
-        raise ValueError(f"Unknown model type in your model config file: {config['type']}")
+        # Creating optional quantization configuration
+        if config["base_params"]["dtype"] == "4bit":
+            quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
+        elif config["base_params"]["dtype"] == "8bit":
+            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        else:
+            quantization_config = None
+
+        # We extract the model args
+        args_dict = {k.split("=")[0]: k.split("=")[1] for k in config["base_params"]["model_args"].split(",")}
+
+        # We store the relevant other args
+        args_dict["base_model"] = config["merged_weights"]["base_model"]
+        args_dict["dtype"] = config["base_params"]["dtype"]
+        args_dict["accelerator"] = accelerator
+        args_dict["quantization_config"] = quantization_config
+        args_dict["batch_size"] = args.override_batch_size
+        args_dict["multichoice_continuations_start_space"] = multichoice_continuations_start_space
+
+        # Keeping only non null params
+        args_dict = {k: v for k, v in args_dict.items() if v is not None}
+
+        if config["merged_weights"]["delta_weights"]:
+            if config["merged_weights"]["base_model"] is None:
+                raise ValueError("You need to specify a base model when using delta weights")
+            return DeltaModelConfig(**args_dict)
+        if config["merged_weights"]["adapter_weights"]:
+            if config["merged_weights"]["base_model"] is None:
+                raise ValueError("You need to specify a base model when using adapter weights")
+            return AdapterModelConfig(**args_dict)
+        if config["merged_weights"]["base_model"] not in ["", None]:
+            raise ValueError("You can't specify a base model if you are not using delta/adapter weights")
+        return BaseModelConfig(**args_dict)
+
+    raise ValueError(f"Unknown model type in your model config file: {config['type']}")
```
