
Commit b68d5bc

Autoscaling inference endpoints (#412)
* adding better management for restarts and resizes
* upgraded autoscale
* added pause option
* fix to parallelism manager - no need for endpoint
1 parent 3929825 commit b68d5bc

File tree

5 files changed, +230 -93 lines changed

examples/model_configs/endpoint_model.yaml

Lines changed: 1 addition & 2 deletions
@@ -1,7 +1,6 @@
 model:
   base_params:
-    endpoint_name: "llama-2-7B-lighteval" # needs to be lower case without special characters
-    model: "meta-llama/Llama-2-7b-hf"
+    model_name: "meta-llama/Llama-2-7b-hf" # the model name or the endpoint name if reuse_existing is true
     revision: "main"
     dtype: "float16" # can be any of "awq", "eetq", "gptq", "4bit' or "8bit" (will use bitsandbytes), "bfloat16" or "float16"
     reuse_existing: false # if true, ignore all params in instance, and don't delete the endpoint after evaluation
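Note: the explicit endpoint_name key (and its "lower case without special characters" requirement) can be dropped because the updated src/lighteval/models/endpoint_model.py below derives a valid name from model_name when none is given. A minimal sketch of that sanitization, using a hypothetical helper name:

import re

def derive_endpoint_name(model_name: str) -> str:
    # Replace anything that is not alphanumeric or a dash (e.g. the "/" in an org/model id).
    return re.sub("[^a-zA-Z0-9-]", "-", model_name.lower() + "-lighteval")

print(derive_endpoint_name("meta-llama/Llama-2-7b-hf"))  # meta-llama-llama-2-7b-hf-lighteval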
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+model:
+  base_params:
+    model_name: "meta-llama/Llama-3.1-8B-Instruct" #Qwen/Qwen2.5-14B" #Qwen/Qwen2.5-7B"

src/lighteval/main_endpoint.py

Lines changed: 23 additions & 28 deletions
@@ -190,7 +190,7 @@ def inference_endpoint(
     ] = None,
     override_batch_size: Annotated[
         int, Option(help="Override batch size for evaluation.", rich_help_panel=HELP_PANNEL_NAME_3)
-    ] = -1,
+    ] = None,
     job_id: Annotated[
         int, Option(help="Optional job id for future refenrence.", rich_help_panel=HELP_PANNEL_NAME_3)
     ] = 0,
@@ -203,7 +203,6 @@ def inference_endpoint(
     from lighteval.logging.evaluation_tracker import EvaluationTracker
     from lighteval.models.model_config import (
         InferenceEndpointModelConfig,
-        InferenceModelConfig,
     )
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

@@ -219,38 +218,34 @@ def inference_endpoint(
 
     # TODO (nathan): better handling of model_args
 
-    parallelism_manager = ParallelismManager.TGI
+    parallelism_manager = ParallelismManager.NONE  # since we're using inference endpoints in remote
 
     with open(model_config_path, "r") as f:
         config = yaml.safe_load(f)["model"]
 
-    reuse_existing_endpoint = config["base_params"].get("reuse_existing", None)
-
-    complete_config_endpoint = all(
-        val not in [None, ""]
-        for key, val in config.get("instance", {}).items()
-        if key not in InferenceEndpointModelConfig.nullable_keys()
+    # Find a way to add this back
+    # if config["base_params"].get("endpoint_name", None):
+    #     return InferenceModelConfig(model=config["base_params"]["endpoint_name"])
+    all_params = {
+        "model_name": config["base_params"].get("model_name", None),
+        "endpoint_name": config["base_params"].get("endpoint_name", None),
+        "model_dtype": config["base_params"].get("dtype", None),
+        "revision": config["base_params"].get("revision", None) or "main",
+        "should_reuse_existing": config["base_params"].get("should_reuse_existing"),
+        "accelerator": config.get("instance", {}).get("accelerator", None),
+        "region": config.get("instance", {}).get("region", None),
+        "vendor": config.get("instance", {}).get("vendor", None),
+        "instance_size": config.get("instance", {}).get("instance_size", None),
+        "instance_type": config.get("instance", {}).get("instance_type", None),
+        "namespace": config.get("instance", {}).get("namespace", None),
+        "image_url": config.get("instance", {}).get("image_url", None),
+        "env_vars": config.get("instance", {}).get("env_vars", None),
+    }
+    model_config = InferenceEndpointModelConfig(
+        # We only initialize params which have a non default value
+        **{k: v for k, v in all_params.items() if v is not None},
     )
 
-    if reuse_existing_endpoint or complete_config_endpoint:
-        model_config = InferenceEndpointModelConfig(
-            name=config["base_params"]["endpoint_name"].replace(".", "-").lower(),
-            repository=config["base_params"]["model"],
-            model_dtype=config["base_params"]["dtype"],
-            revision=config["base_params"]["revision"] or "main",
-            should_reuse_existing=reuse_existing_endpoint,
-            accelerator=config["instance"]["accelerator"],
-            region=config["instance"]["region"],
-            vendor=config["instance"]["vendor"],
-            instance_size=config["instance"]["instance_size"],
-            instance_type=config["instance"]["instance_type"],
-            namespace=config["instance"]["namespace"],
-            image_url=config["instance"].get("image_url", None),
-            env_vars=config["instance"].get("env_vars", None),
-        )
-    else:
-        model_config = InferenceModelConfig(model=config["base_params"]["endpoint_name"])
-
     pipeline_params = PipelineParameters(
         launcher_type=parallelism_manager,
         env_config=env_config,
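The launcher now builds the model config from a flat dict and only forwards values that were explicitly set, so InferenceEndpointModelConfig falls back to its own defaults (and the hardware auto-suggestion below) for everything else. A small sketch of that filtering, showing only a subset of the keys and assuming the minimal YAML from the new example config above:

import yaml

raw = yaml.safe_load("""
model:
  base_params:
    model_name: "meta-llama/Llama-3.1-8B-Instruct"
""")["model"]

all_params = {
    "model_name": raw["base_params"].get("model_name", None),
    "endpoint_name": raw["base_params"].get("endpoint_name", None),
    "model_dtype": raw["base_params"].get("dtype", None),
    "revision": raw["base_params"].get("revision", None) or "main",
    "instance_type": raw.get("instance", {}).get("instance_type", None),
}
print({k: v for k, v in all_params.items() if v is not None})
# -> {'model_name': 'meta-llama/Llama-3.1-8B-Instruct', 'revision': 'main'}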

src/lighteval/models/endpoint_model.py

Lines changed: 182 additions & 45 deletions
@@ -21,19 +21,25 @@
 # SOFTWARE.
 
 import asyncio
+import re
+import time
 from typing import Coroutine, List, Optional, Union
 
+import requests
 import torch
 from huggingface_hub import (
     AsyncInferenceClient,
     InferenceClient,
     InferenceEndpoint,
+    InferenceEndpointError,
     InferenceEndpointTimeoutError,
     TextGenerationInputGrammarType,
     TextGenerationOutput,
     create_inference_endpoint,
     get_inference_endpoint,
 )
+from huggingface_hub.utils import HfHubHTTPError
+from requests import ConnectionError
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -53,67 +59,155 @@
 
 
 BATCH_SIZE = 50
+MAX_TIME_FOR_SPINUP = 3600
+
+SORTED_INSTANCE_SIZES = [  # sorted by incremental overall RAM (to load models)
+    # type, size
+    ("nvidia-a10g", "x1"),
+    ("nvidia-t4", "x4"),
+    ("nvidia-a100", "x1"),
+    ("nvidia-a10g", "x4"),
+    ("nvidia-a100", "x2"),
+    ("nvidia-a100", "x4"),
+]
 
 
 class InferenceEndpointModel(LightevalModel):
     """InferenceEndpointModels can be used both with the free inference client, or with inference
     endpoints, which will use text-generation-inference to deploy your model for the duration of the evaluation.
     """
 
-    def __init__(
+    def __init__(  # noqa: C901
         self, config: Union[InferenceEndpointModelConfig, InferenceModelConfig], env_config: EnvConfig
     ) -> None:
         self.reuse_existing = getattr(config, "should_reuse_existing", True)
         self._max_length = None
+        self.endpoint = None
+        self.model_name = None
         if isinstance(config, InferenceEndpointModelConfig):
-            if config.should_reuse_existing:
-                self.endpoint = get_inference_endpoint(
-                    name=config.name, token=env_config.token, namespace=config.namespace
+            if config.instance_type and config.instance_size and config.vendor and config.region:
+                vendor, region, instance_type, instance_size = (
+                    config.vendor,
+                    config.region,
+                    config.instance_type,
+                    config.instance_size,
                 )
             else:
-                self.endpoint: InferenceEndpoint = create_inference_endpoint(
-                    name=config.name,
-                    namespace=config.namespace,
-                    repository=config.repository,
-                    revision=config.revision,
-                    framework=config.framework,
-                    task="text-generation",
-                    accelerator=config.accelerator,
-                    vendor=config.vendor,
-                    region=config.region,
-                    type=config.endpoint_type,
-                    instance_size=config.instance_size,
-                    instance_type=config.instance_type,
-                    token=env_config.token,
-                    custom_image={
-                        "health_route": "/health",
-                        "env": {
-                            # Documentaiton: https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/launcher
-                            "MAX_BATCH_PREFILL_TOKENS": "2048",
-                            "MAX_INPUT_LENGTH": "2047",
-                            "MAX_TOTAL_TOKENS": "2048",
-                            "MODEL_ID": "/repository",
-                            "HF_MODEL_TRUST_REMOTE_CODE": "true",
-                            **config.get_dtype_args(),
-                            **config.get_custom_env_vars(),
-                        },
-                        "url": (config.image_url or "ghcr.io/huggingface/text-generation-inference:latest"),
-                    },
-                )
-                hlog("Deploying your endpoint. Please wait.")
-                try:
-                    self.endpoint.wait(timeout=600)  # Waits for the endpoint to be deployed
-                except InferenceEndpointTimeoutError as e:
-                    hlog_err("Endpoint did not start within 10 minutes, there was a timeout.")
-                    raise e
+                try:
+                    vendor, region, instance_type, instance_size = InferenceEndpointModel.get_suggested_model_config(
+                        config.model_name
+                    )
+                except Exception:
+                    vendor, region, instance_type, instance_size = (
+                        "aws",
+                        "us-east-1",
+                        *InferenceEndpointModel.get_larger_hardware_suggestion(),
+                    )
+
+            must_scaleup_endpoint = False
+            timer_start = time.time()
+            # Endpoint names do not allow special characters
+            endpoint_name = config.endpoint_name or re.sub(
+                "[^a-zA-Z0-9-]", "-", config.model_name.lower() + "-lighteval"
+            )
+            # If no endpoint or endpoint not running, and we're below an hour
+            while (self.endpoint is None or self.endpoint.status != "running") and (
+                time.time() - timer_start < MAX_TIME_FOR_SPINUP
+            ):
+                try:
+                    if self.endpoint is None:  # Endpoint does not exist yet locally
+                        if not config.should_reuse_existing:  # New endpoint
+                            hlog("Creating endpoint.")
+                            self.endpoint: InferenceEndpoint = create_inference_endpoint(
+                                name=endpoint_name,
+                                namespace=config.namespace,
+                                repository=config.model_name,
+                                revision=config.revision,
+                                framework=config.framework,
+                                task="text-generation",
+                                accelerator=config.accelerator,
+                                type=config.endpoint_type,
+                                vendor=vendor,
+                                region=region,
+                                instance_size=instance_size,
+                                instance_type=instance_type,
+                                token=env_config.token,
+                                custom_image={
+                                    "health_route": "/health",
+                                    "env": {
+                                        # Documentation: https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/launcher
+                                        "MAX_BATCH_PREFILL_TOKENS": "2048",
+                                        "MAX_INPUT_LENGTH": "2047",
+                                        "MAX_TOTAL_TOKENS": "2048",
+                                        "MODEL_ID": "/repository",
+                                        "HF_MODEL_TRUST_REMOTE_CODE": "true",
+                                        **config.get_dtype_args(),
+                                        **config.get_custom_env_vars(),
+                                    },
+                                    "url": (
+                                        config.image_url or "ghcr.io/huggingface/text-generation-inference:latest"
+                                    ),
+                                },
+                            )
+                        else:  # Endpoint exists
+                            hlog("Reusing existing endpoint.")
+                            self.endpoint = get_inference_endpoint(
+                                name=endpoint_name, token=env_config.token, namespace=config.namespace
+                            )
+
+                    else:
+                        # Endpoint exists locally but either failed (and most likely it must be scaled up)
+                        if must_scaleup_endpoint:
+                            hlog("Rescaling existing endpoint.")
+                            self.endpoint.update(instance_size=instance_size, instance_type=instance_type)
+                            must_scaleup_endpoint = False
+                        # or we got a connection error, in which case we do nothing and just wait at the next step
+
+                    # Waits for the endpoint to be deployed - we could also check for the status in updating', 'pending', 'initializing'
+                    hlog("Trying to deploy your endpoint. Please wait for 10 min.")
+                    self.endpoint.wait(timeout=600, refresh_every=60)  # We wait for 10 min
+                except InferenceEndpointError as e:
+                    instance_type, instance_size = InferenceEndpointModel.get_larger_hardware_suggestion(
+                        instance_type, instance_size
+                    )
+                    must_scaleup_endpoint = True
+
+                    hlog(
+                        f"Endpoint failed to start on current hardware with error {e}. Trying to autoscale to ({instance_type}, {instance_size})."
+                    )
+                except InferenceEndpointTimeoutError as e:
+                    hlog_err("Endpoint did not start within 30 minutes, there was a timeout. Please inspect the logs.")
+                    raise e
+                except HfHubHTTPError as e:
+                    # The endpoint actually already exists, we'll spin it up instead of trying to create a new one
+                    if "409 Client Error: Conflict for url:" in str(e):
+                        config.endpoint_name = endpoint_name
+                        config.should_reuse_existing = True
+                    # Requested resources are not available
+                    elif "Bad Request: Compute instance not available yet" in str(e):
+                        hlog_err(
+                            "The hardware combination you are requesting does not seem to be available: ({instance_type}, {instance_size}, {config.region})."
+                        )
+                        raise e
+                    # User account does not have access to requested resources
+                    elif "Conflict: Quota exceeded" in str(e):
+                        raise e
+                except ConnectionError as e:
+                    hlog_err(f"Connection failed with error {e}. Retrying")
+
+            if not self.endpoint.status == "running":
+                raise Exception("Did not manage to start endpoint within the elapsed time and on suggested hardware.")
+
             hlog("Endpoint successfully deployed!")
-            self.name = config.repository
+            self.endpoint_name = config.endpoint_name
+            self.name = self.endpoint.repository
             self.revision = self.endpoint.revision
             self.async_client: AsyncInferenceClient = self.endpoint.async_client
             self.client: InferenceClient = self.endpoint.client
 
         else:  # Free inference client
             self.endpoint = None
+            self.endpoint_name = None
             self.name = config.model
             self.revision = "default"
             self.async_client = AsyncInferenceClient(model=config.model, token=env_config.token)
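The core of this change is the spin-up loop above: keep trying to reach the "running" state for up to MAX_TIME_FOR_SPINUP seconds, and on an InferenceEndpointError resize to the next entry of SORTED_INSTANCE_SIZES before waiting again. A distilled, hypothetical sketch of that pattern (create_or_reuse, wait_until_ready and HardwareTooSmallError are stand-ins, not lighteval or huggingface_hub APIs):

import time

MAX_TIME_FOR_SPINUP = 3600
SORTED_INSTANCE_SIZES = [
    ("nvidia-a10g", "x1"), ("nvidia-t4", "x4"), ("nvidia-a100", "x1"),
    ("nvidia-a10g", "x4"), ("nvidia-a100", "x2"), ("nvidia-a100", "x4"),
]

class HardwareTooSmallError(Exception):
    pass

def spin_up(create_or_reuse, wait_until_ready, instance_type, instance_size):
    endpoint, start = None, time.time()
    while (endpoint is None or endpoint.status != "running") and time.time() - start < MAX_TIME_FOR_SPINUP:
        try:
            if endpoint is None:
                endpoint = create_or_reuse(instance_type, instance_size)
            wait_until_ready(endpoint, timeout=600)  # raises HardwareTooSmallError if the model does not fit
        except HardwareTooSmallError:
            # Move to the next, larger hardware combination and resize the endpoint in place.
            ix = SORTED_INSTANCE_SIZES.index((instance_type, instance_size))
            instance_type, instance_size = SORTED_INSTANCE_SIZES[ix + 1]
            if endpoint is not None:
                endpoint.update(instance_size=instance_size, instance_type=instance_type)
    if endpoint is None or endpoint.status != "running":
        raise RuntimeError("Endpoint did not start within the time budget on the suggested hardware.")
    return endpoint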
@@ -131,6 +225,43 @@ def __init__(
             model_size=-1,
         )
 
+    @staticmethod
+    def get_larger_hardware_suggestion(cur_instance_type: str = None, cur_instance_size: str = None):
+        cur_instance_ix = -1
+        try:
+            if cur_instance_type and cur_instance_size:
+                cur_instance_ix = SORTED_INSTANCE_SIZES.index((cur_instance_type, cur_instance_size))
+            new_instance_type = SORTED_INSTANCE_SIZES[cur_instance_ix + 1][0]
+            new_instance_size = SORTED_INSTANCE_SIZES[cur_instance_ix + 1][1]
+            return new_instance_type, new_instance_size
+        except ValueError:
+            raise Exception(
+                f"Problem when scaling endpoint: the current instance combination ({cur_instance_type}, {cur_instance_size}) is unknown. Can't scale it up."
+            )
+        except IndexError:
+            raise Exception(
+                "To avoid accidental costs, we do not upgrade the current endpoint above 4 a100 automatically, please request it explicitely."
+            )
+
+    @staticmethod
+    def get_suggested_model_config(model_repo):
+        # Code from https://huggingface.co/spaces/huggingface/dedicated-endpoint-snooper/blob/main/app.py
+        # Example of the suggestedCompute value: 'aws-us-east-1-nvidia-l4-x1'
+        # -> aws us-east-1 nvidia-l4 x1
+        url = f"https://ui.endpoints.huggingface.co/api/configuration?model_id={model_repo}"
+        response = requests.get(url)
+        config = response.json()
+
+        suggested_compute = config["suggestedCompute"]
+        suggested_vendor = suggested_compute.split("-")[0]
+        if suggested_vendor == "azure":
+            suggested_region = suggested_compute.split("-")[1]
+        else:
+            suggested_region = "-".join(suggested_compute.split("-")[1:4])
+        suggested_instance = "-".join(suggested_compute.split("-")[-3:-1])
+        suggested_size = suggested_compute.split("-")[-1]
+        return suggested_vendor, suggested_region, suggested_instance, suggested_size
+
     @property
     def tokenizer(self):
         return self._tokenizer
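As a worked example of get_suggested_model_config's parsing, using the sample value quoted in its comment ('aws-us-east-1-nvidia-l4-x1'):

suggested_compute = "aws-us-east-1-nvidia-l4-x1"
parts = suggested_compute.split("-")   # ['aws', 'us', 'east', '1', 'nvidia', 'l4', 'x1']
vendor = parts[0]                      # 'aws'
region = "-".join(parts[1:4])          # 'us-east-1' (azure regions are a single token, hence the special case above)
instance = "-".join(parts[-3:-1])      # 'nvidia-l4'
size = parts[-1]                       # 'x1'
print(vendor, region, instance, size)  # aws us-east-1 nvidia-l4 x1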
@@ -144,11 +275,17 @@ def disable_tqdm(self) -> bool:
             False  # no accelerator = this is the main process
 
     def cleanup(self):
-        if self.endpoint is not None and not self.reuse_existing:
-            self.endpoint.delete()
-            hlog_warn(
-                "You deleted your endpoint after using it. You'll need to create it again if you need to reuse it."
-            )
+        if self.endpoint is not None:
+            if self.reuse_existing:
+                self.endpoint.pause()
+                hlog_warn(
+                    "Since your endpoint was existing before, we did not delete it, but paused it instead. You might want to delete it if you're done using it."
+                )
+            else:
+                self.endpoint.delete()
+                hlog_warn(
+                    "We deleted the spinned up endpoint after using it. You'll need to create it again if you need to reuse it."
+                )
 
     @property
     def max_length(self):
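Since cleanup() now pauses (rather than deletes) endpoints that existed before the run, a paused endpoint stays attached to your account. A hedged sketch of resuming or removing it afterwards with huggingface_hub (the endpoint name and namespace below are placeholders):

from huggingface_hub import get_inference_endpoint

endpoint = get_inference_endpoint("my-endpoint", namespace="my-org")
endpoint.resume()   # bring the paused endpoint back up for another evaluation run
# ...or, once you are done with it entirely:
# endpoint.delete()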
