3 changes: 1 addition & 2 deletions examples/model_configs/endpoint_model.yaml
@@ -1,7 +1,6 @@
model:
base_params:
endpoint_name: "llama-2-7B-lighteval" # needs to be lower case without special characters
model: "meta-llama/Llama-2-7b-hf"
model_name: "meta-llama/Llama-2-7b-hf" # the model name or the endpoint name if reuse_existing is true
revision: "main"
dtype: "float16" # can be any of "awq", "eetq", "gptq", "4bit' or "8bit" (will use bitsandbytes), "bfloat16" or "float16"
reuse_existing: false # if true, ignore all params in instance, and don't delete the endpoint after evaluation
3 changes: 3 additions & 0 deletions examples/model_configs/endpoint_model_lite.yaml
@@ -0,0 +1,3 @@
model:
base_params:
model_name: "meta-llama/Llama-3.1-8B-Instruct" #Qwen/Qwen2.5-14B" #Qwen/Qwen2.5-7B"
51 changes: 23 additions & 28 deletions src/lighteval/main_endpoint.py
@@ -190,7 +190,7 @@ def inference_endpoint(
] = None,
override_batch_size: Annotated[
int, Option(help="Override batch size for evaluation.", rich_help_panel=HELP_PANNEL_NAME_3)
] = -1,
] = None,
job_id: Annotated[
int, Option(help="Optional job id for future refenrence.", rich_help_panel=HELP_PANNEL_NAME_3)
] = 0,
@@ -203,7 +203,6 @@ def inference_endpoint(
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.model_config import (
InferenceEndpointModelConfig,
InferenceModelConfig,
)
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

@@ -219,38 +218,34 @@ def inference_endpoint(

# TODO (nathan): better handling of model_args

parallelism_manager = ParallelismManager.TGI
parallelism_manager = ParallelismManager.NONE # since we're running on remote inference endpoints

with open(model_config_path, "r") as f:
config = yaml.safe_load(f)["model"]

reuse_existing_endpoint = config["base_params"].get("reuse_existing", None)

complete_config_endpoint = all(
val not in [None, ""]
for key, val in config.get("instance", {}).items()
if key not in InferenceEndpointModelConfig.nullable_keys()
# Find a way to add this back
# if config["base_params"].get("endpoint_name", None):
# return InferenceModelConfig(model=config["base_params"]["endpoint_name"])
all_params = {
"model_name": config["base_params"].get("model_name", None),
"endpoint_name": config["base_params"].get("endpoint_name", None),
"model_dtype": config["base_params"].get("dtype", None),
"revision": config["base_params"].get("revision", None) or "main",
"should_reuse_existing": config["base_params"].get("should_reuse_existing"),
"accelerator": config.get("instance", {}).get("accelerator", None),
"region": config.get("instance", {}).get("region", None),
"vendor": config.get("instance", {}).get("vendor", None),
"instance_size": config.get("instance", {}).get("instance_size", None),
"instance_type": config.get("instance", {}).get("instance_type", None),
"namespace": config.get("instance", {}).get("namespace", None),
"image_url": config.get("instance", {}).get("image_url", None),
"env_vars": config.get("instance", {}).get("env_vars", None),
}
model_config = InferenceEndpointModelConfig(
# We only initialize params which have a non-default value
**{k: v for k, v in all_params.items() if v is not None},
)

if reuse_existing_endpoint or complete_config_endpoint:
model_config = InferenceEndpointModelConfig(
name=config["base_params"]["endpoint_name"].replace(".", "-").lower(),
repository=config["base_params"]["model"],
model_dtype=config["base_params"]["dtype"],
revision=config["base_params"]["revision"] or "main",
should_reuse_existing=reuse_existing_endpoint,
accelerator=config["instance"]["accelerator"],
region=config["instance"]["region"],
vendor=config["instance"]["vendor"],
instance_size=config["instance"]["instance_size"],
instance_type=config["instance"]["instance_type"],
namespace=config["instance"]["namespace"],
image_url=config["instance"].get("image_url", None),
env_vars=config["instance"].get("env_vars", None),
)
else:
model_config = InferenceModelConfig(model=config["base_params"]["endpoint_name"])

pipeline_params = PipelineParameters(
launcher_type=parallelism_manager,
env_config=env_config,
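
As an illustration (not part of the diff), with the lite config above the non-None filtering leaves only two keyword arguments; the remaining fields are assumed to have defaults on the config class:

# all_params for the lite config before filtering (sketch):
# {"model_name": "meta-llama/Llama-3.1-8B-Instruct", "endpoint_name": None,
#  "model_dtype": None, "revision": "main", "should_reuse_existing": None, ...}
model_config = InferenceEndpointModelConfig(
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    revision="main",
)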
227 changes: 182 additions & 45 deletions src/lighteval/models/endpoint_model.py
@@ -21,19 +21,25 @@
# SOFTWARE.

import asyncio
import re
import time
from typing import Coroutine, List, Optional, Union

import requests
import torch
from huggingface_hub import (
AsyncInferenceClient,
InferenceClient,
InferenceEndpoint,
InferenceEndpointError,
InferenceEndpointTimeoutError,
TextGenerationInputGrammarType,
TextGenerationOutput,
create_inference_endpoint,
get_inference_endpoint,
)
from huggingface_hub.utils import HfHubHTTPError
from requests import ConnectionError
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer
@@ -53,67 +59,155 @@


BATCH_SIZE = 50
MAX_TIME_FOR_SPINUP = 3600

SORTED_INSTANCE_SIZES = [ # sorted by incremental overall RAM (to load models)
# type, size
("nvidia-a10g", "x1"),
("nvidia-t4", "x4"),
("nvidia-a100", "x1"),
("nvidia-a10g", "x4"),
("nvidia-a100", "x2"),
("nvidia-a100", "x4"),
]


class InferenceEndpointModel(LightevalModel):
"""InferenceEndpointModels can be used both with the free inference client, or with inference
endpoints, which will use text-generation-inference to deploy your model for the duration of the evaluation.
"""

def __init__(
def __init__( # noqa: C901
self, config: Union[InferenceEndpointModelConfig, InferenceModelConfig], env_config: EnvConfig
) -> None:
self.reuse_existing = getattr(config, "should_reuse_existing", True)
self._max_length = None
self.endpoint = None
self.model_name = None
if isinstance(config, InferenceEndpointModelConfig):
if config.should_reuse_existing:
self.endpoint = get_inference_endpoint(
name=config.name, token=env_config.token, namespace=config.namespace
if config.instance_type and config.instance_size and config.vendor and config.region:
vendor, region, instance_type, instance_size = (
config.vendor,
config.region,
config.instance_type,
config.instance_size,
)
else:
self.endpoint: InferenceEndpoint = create_inference_endpoint(
name=config.name,
namespace=config.namespace,
repository=config.repository,
revision=config.revision,
framework=config.framework,
task="text-generation",
accelerator=config.accelerator,
vendor=config.vendor,
region=config.region,
type=config.endpoint_type,
instance_size=config.instance_size,
instance_type=config.instance_type,
token=env_config.token,
custom_image={
"health_route": "/health",
"env": {
# Documentaiton: https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/launcher
"MAX_BATCH_PREFILL_TOKENS": "2048",
"MAX_INPUT_LENGTH": "2047",
"MAX_TOTAL_TOKENS": "2048",
"MODEL_ID": "/repository",
"HF_MODEL_TRUST_REMOTE_CODE": "true",
**config.get_dtype_args(),
**config.get_custom_env_vars(),
},
"url": (config.image_url or "ghcr.io/huggingface/text-generation-inference:latest"),
},
)
hlog("Deploying your endpoint. Please wait.")
try:
self.endpoint.wait(timeout=600) # Waits for the endpoint to be deployed
except InferenceEndpointTimeoutError as e:
hlog_err("Endpoint did not start within 10 minutes, there was a timeout.")
raise e
try:
vendor, region, instance_type, instance_size = InferenceEndpointModel.get_suggested_model_config(
config.model_name
)
except Exception:
vendor, region, instance_type, instance_size = (
"aws",
"us-east-1",
*InferenceEndpointModel.get_larger_hardware_suggestion(),
)

must_scaleup_endpoint = False
timer_start = time.time()
# Endpoint names do not allow special characters
endpoint_name = config.endpoint_name or re.sub(
"[^a-zA-Z0-9-]", "-", config.model_name.lower() + "-lighteval"
)
# If no endpoint or endpoint not running, and we're below an hour
while (self.endpoint is None or self.endpoint.status != "running") and (
time.time() - timer_start < MAX_TIME_FOR_SPINUP
):
try:
if self.endpoint is None: # Endpoint does not exist yet locally
if not config.should_reuse_existing: # New endpoint
hlog("Creating endpoint.")
self.endpoint: InferenceEndpoint = create_inference_endpoint(
name=endpoint_name,
namespace=config.namespace,
repository=config.model_name,
revision=config.revision,
framework=config.framework,
task="text-generation",
accelerator=config.accelerator,
type=config.endpoint_type,
vendor=vendor,
region=region,
instance_size=instance_size,
instance_type=instance_type,
token=env_config.token,
custom_image={
"health_route": "/health",
"env": {
# Documentation: https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/launcher
"MAX_BATCH_PREFILL_TOKENS": "2048",
"MAX_INPUT_LENGTH": "2047",
"MAX_TOTAL_TOKENS": "2048",
"MODEL_ID": "/repository",
"HF_MODEL_TRUST_REMOTE_CODE": "true",
**config.get_dtype_args(),
**config.get_custom_env_vars(),
},
"url": (
config.image_url or "ghcr.io/huggingface/text-generation-inference:latest"
),
},
)
else: # Endpoint exists
hlog("Reusing existing endpoint.")
self.endpoint = get_inference_endpoint(
name=endpoint_name, token=env_config.token, namespace=config.namespace
)

else:
# Endpoint exists locally but either failed (and most likely it must be scaled up)
if must_scaleup_endpoint:
hlog("Rescaling existing endpoint.")
self.endpoint.update(instance_size=instance_size, instance_type=instance_type)
must_scaleup_endpoint = False
# or we got a connection error, in which case we do nothing and just wait at the next step

# Waits for the endpoint to be deployed - we could also check for the status being in 'updating', 'pending', 'initializing'
hlog("Trying to deploy your endpoint. Please wait for 10 min.")
self.endpoint.wait(timeout=600, refresh_every=60) # We wait for 10 min
except InferenceEndpointError as e:
instance_type, instance_size = InferenceEndpointModel.get_larger_hardware_suggestion(
instance_type, instance_size
)
must_scaleup_endpoint = True

hlog(
f"Endpoint failed to start on current hardware with error {e}. Trying to autoscale to ({instance_type}, {instance_size})."
)
except InferenceEndpointTimeoutError as e:
hlog_err("Endpoint did not start within 30 minutes, there was a timeout. Please inspect the logs.")
raise e
except HfHubHTTPError as e:
# The endpoint actually already exists, we'll spin it up instead of trying to create a new one
if "409 Client Error: Conflict for url:" in str(e):
config.endpoint_name = endpoint_name
config.should_reuse_existing = True
# Requested resources are not available
elif "Bad Request: Compute instance not available yet" in str(e):
hlog_err(
f"The hardware combination you are requesting does not seem to be available: ({instance_type}, {instance_size}, {config.region})."
)
raise e
# User account does not have access to requested resources
elif "Conflict: Quota exceeded" in str(e):
raise e
except ConnectionError as e:
hlog_err(f"Connection failed with error {e}. Retrying")

if not self.endpoint.status == "running":
raise Exception("Did not manage to start endpoint within the elapsed time and on suggested hardware.")

hlog("Endpoint successfully deployed!")
self.name = config.repository
self.endpoint_name = config.endpoint_name
self.name = self.endpoint.repository
self.revision = self.endpoint.revision
self.async_client: AsyncInferenceClient = self.endpoint.async_client
self.client: InferenceClient = self.endpoint.client

else: # Free inference client
self.endpoint = None
self.endpoint_name = None
self.name = config.model
self.revision = "default"
self.async_client = AsyncInferenceClient(model=config.model, token=env_config.token)
@@ -131,6 +225,43 @@ def __init__(
model_size=-1,
)

@staticmethod
def get_larger_hardware_suggestion(cur_instance_type: str = None, cur_instance_size: str = None):
cur_instance_ix = -1
try:
if cur_instance_type and cur_instance_size:
cur_instance_ix = SORTED_INSTANCE_SIZES.index((cur_instance_type, cur_instance_size))
new_instance_type = SORTED_INSTANCE_SIZES[cur_instance_ix + 1][0]
new_instance_size = SORTED_INSTANCE_SIZES[cur_instance_ix + 1][1]
return new_instance_type, new_instance_size
except ValueError:
raise Exception(
f"Problem when scaling endpoint: the current instance combination ({cur_instance_type}, {cur_instance_size}) is unknown. Can't scale it up."
)
except IndexError:
raise Exception(
"To avoid accidental costs, we do not upgrade the current endpoint above 4 a100 automatically, please request it explicitely."
)
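
Illustrative behaviour of this helper, derived from SORTED_INSTANCE_SIZES above (not part of the diff):

# No current hardware -> start with the smallest entry
InferenceEndpointModel.get_larger_hardware_suggestion()                     # ("nvidia-a10g", "x1")
# Scale one step up from a known combination
InferenceEndpointModel.get_larger_hardware_suggestion("nvidia-a10g", "x1")  # ("nvidia-t4", "x4")
# ("nvidia-a100", "x4") is the last entry, so asking to scale past it raises instead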

@staticmethod
def get_suggested_model_config(model_repo):
# Code from https://huggingface.co/spaces/huggingface/dedicated-endpoint-snooper/blob/main/app.py
# Example of the suggestedCompute value: 'aws-us-east-1-nvidia-l4-x1'
# -> aws us-east-1 nvidia-l4 x1
url = f"https://ui.endpoints.huggingface.co/api/configuration?model_id={model_repo}"
response = requests.get(url)
config = response.json()

suggested_compute = config["suggestedCompute"]
suggested_vendor = suggested_compute.split("-")[0]
if suggested_vendor == "azure":
suggested_region = suggested_compute.split("-")[1]
else:
suggested_region = "-".join(suggested_compute.split("-")[1:4])
suggested_instance = "-".join(suggested_compute.split("-")[-3:-1])
suggested_size = suggested_compute.split("-")[-1]
return suggested_vendor, suggested_region, suggested_instance, suggested_size
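
Worked example of the parsing above, using the value quoted in the code comment (not part of the diff):

suggested_compute = "aws-us-east-1-nvidia-l4-x1"
parts = suggested_compute.split("-")    # ["aws", "us", "east", "1", "nvidia", "l4", "x1"]
# vendor   = parts[0]               -> "aws"
# region   = "-".join(parts[1:4])   -> "us-east-1"
# instance = "-".join(parts[-3:-1]) -> "nvidia-l4"
# size     = parts[-1]              -> "x1"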

@property
def tokenizer(self):
return self._tokenizer
@@ -144,11 +275,17 @@ def disable_tqdm(self) -> bool:
False # no accelerator = this is the main process

def cleanup(self):
if self.endpoint is not None and not self.reuse_existing:
self.endpoint.delete()
hlog_warn(
"You deleted your endpoint after using it. You'll need to create it again if you need to reuse it."
)
if self.endpoint is not None:
if self.reuse_existing:
self.endpoint.pause()
hlog_warn(
"Since your endpoint already existed, we paused it instead of deleting it. You might want to delete it if you're done using it."
)
else:
self.endpoint.delete()
hlog_warn(
"We deleted the endpoint we spun up after using it. You'll need to create it again if you want to reuse it."
)

@property
def max_length(self):
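
Putting the pieces together, an assumed end-to-end use of the class added in this file (the EnvConfig fields and the token value are placeholders, not taken from the PR):

from lighteval.models.endpoint_model import InferenceEndpointModel
from lighteval.models.model_config import InferenceEndpointModelConfig
from lighteval.pipeline import EnvConfig

config = InferenceEndpointModelConfig(model_name="meta-llama/Llama-3.1-8B-Instruct")
model = InferenceEndpointModel(config=config, env_config=EnvConfig(token="hf_xxx", cache_dir="~/.cache"))
# ... run the evaluation pipeline with this model ...
model.cleanup()  # pauses a reused endpoint, deletes one that was spun up for this run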