
Commit 7fcaab3

Updated tgi_model and added parameters for endpoint_model (#208)
* Added image url parameter
* Fixed up tgi model config
* Undid tgi available check
* Adjust tgi parameter names, and checked for attr existence
* Fixed task Id in argparse
* Removed obfuscation from private functions, to allow inheritance to override
* Updated tgi model to inherit from endpoint and just modify client calls
* Added option to specify model id in config for tgi model
* Added option to specify custom env vars
* Updated env vars
* Applied ruff format
* Added docs + readme
* Ruff format
1 parent a3d1eea commit 7fcaab3

File tree

9 files changed (+109 / -111 lines)

README.md

Lines changed: 20 additions & 2 deletions
@@ -139,7 +139,7 @@ accelerate launch --multi_gpu --num_processes=<num_gpus> run_evals_accelerate.py
 --output_dir output_dir
 ```
 
-Examples of possible configuration files are provided in `examples/model_configs`.
+You can find the template of the expected model configuration in [examples/model_configs/base_model.yaml_](./examples/model_configs/base_model.yaml).
 
 ### Evaluating a large model with pipeline parallelism
 
@@ -182,6 +182,25 @@ python run_evals_accelerate.py \
 --output_dir output_dir
 ```
 
+### Evaluate the model on a server/container.
+
+An alternative to launching the evaluation locally is to serve the model on a TGI-compatible server/container and then run the evaluation by sending requests to the server. The command is the same as before, except you specify a path to a yaml config file (detailed below):
+
+```shell
+python run_evals_accelerate.py \
+    --model_config_path="/path/to/config/file"\
+    --tasks <task parameters> \
+    --output_dir output_dir
+```
+
+There are two types of configuration files that can be provided for running on the server:
+
+1. [endpoint_model.yaml](./examples/model_configs/endpoint_model.yaml): This configuration allows you to launch the model using [HuggingFace's Inference Endpoints](https://huggingface.co/inference-endpoints/dedicated). You can specify in the configuration file all the relevant parameters, and then `lighteval` will automatically deploy the endpoint, run the evaluation, and finally delete the endpoint (unless you specify an endpoint that was already launched, in which case the endpoint won't be deleted afterwards).
+
+2. [tgi_model.yaml](./examples/model_configs/tgi_model.yaml): This configuration lets you specify the URL of a model running in a TGI container, such as one deployed on HuggingFace's serverless inference.
+
+Templates for these configurations can be found in [examples/model_configs](./examples/model_configs/).
+
 ### Evaluate a model on extended, community, or custom tasks.
 
 Independently of the default tasks provided in `lighteval` that you will find in the `tasks_table.jsonl` file, you can use `lighteval` to evaluate models on tasks that require special processing (or have been added by the community). These tasks have their own evaluation suites and are defined as follows:
@@ -190,7 +209,6 @@ Independently of the default tasks provided in `lighteval` that you will find in
 * `community`: tasks that have been added by the community. See the [`community_tasks`](./community_tasks) folder for examples.
 * `custom`: tasks that are defined locally and not present in the core library. Use this suite if you want to experiment with designing a special metric or task.
 
-
 For example, to run an extended task like `ifeval`, you can run:
 ```shell
 python run_evals_accelerate.py \
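For orientation, here is a minimal sketch (not part of this diff) of how such a yaml config file is consumed on the lighteval side; it mirrors the `create_model_config` changes to `model_config.py` further down, and the file path is a placeholder:

```python
# Sketch only: mirrors the type-based dispatch added in src/lighteval/models/model_config.py below.
import yaml

from lighteval.models.model_config import TGIModelConfig

with open("examples/model_configs/tgi_model.yaml", "r") as f:  # placeholder path
    config = yaml.safe_load(f)["model"]

if config["type"] == "tgi":
    model_config = TGIModelConfig(
        inference_server_address=config["instance"]["inference_server_address"],
        inference_server_auth=config["instance"]["inference_server_auth"],
        model_id=config["instance"]["model_id"],
    )
# config["type"] == "endpoint" builds an InferenceEndpointModelConfig instead (see model_config.py below).
```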

examples/model_configs/endpoint_model.yaml

Lines changed: 4 additions & 1 deletion
@@ -5,7 +5,7 @@ model:
     model: "meta-llama/Llama-2-7b-hf"
     revision: "main"
     dtype: "float16" # can be any of "awq", "eetq", "gptq", "4bit' or "8bit" (will use bitsandbytes), "bfloat16" or "float16"
-    reuse_existing: false # if true, ignore all params in instance
+    reuse_existing: false # if true, ignore all params in instance, and don't delete the endpoint after evaluation
   instance:
     accelerator: "gpu"
     region: "eu-west-1"
@@ -15,5 +15,8 @@ model:
     framework: "pytorch"
     endpoint_type: "protected"
     namespace: null # The namespace under which to launch the endopint. Defaults to the current user's namespace
+    image_url: null # Optionally specify the docker image to use when launching the endpoint model. E.g., launching models with later releases of the TGI container with support for newer models.
+    env_vars:
+      null # Optional environment variables to include when launching the endpoint. e.g., `MAX_INPUT_LENGTH: 2048`
   generation:
     add_special_tokens: true
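The two new keys above feed into the endpoint's container definition. As a rough illustration (the values are made up; it mirrors `get_custom_env_vars()` in `model_config.py` and the custom image dict built in `endpoint_model.py` below):

```python
# Illustrative only: shows how env_vars / image_url from the yaml end up in the
# endpoint's custom image definition. The variable values here are placeholders.
env_vars = {"MAX_INPUT_LENGTH": 2048}  # as it would be parsed from the yaml
image_url = None  # None falls back to the default TGI image

# env_vars values are stringified before being passed to the container environment
custom_env_vars = {k: str(v) for k, v in env_vars.items()} if env_vars else {}

custom_image = {
    "env": {
        "MAX_TOTAL_TOKENS": "2048",
        "MODEL_ID": "/repository",
        **custom_env_vars,  # user-supplied variables are merged in last
    },
    "url": (image_url or "ghcr.io/huggingface/text-generation-inference:1.1.0"),
}
```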

examples/model_configs/tgi_model.yaml

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@ model:
   instance:
     inference_server_address: ""
     inference_server_auth: null
+    model_id: null # Optional, only required if the TGI container was launched with model_id pointing to a local directory

run_evals_accelerate.py

Lines changed: 3 additions & 2 deletions
@@ -20,10 +20,11 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-""" Example run command:
+"""Example run command:
 accelerate config
 accelerate launch run_evals_accelerate.py --tasks="leaderboard|hellaswag|5|1" --output_dir "/scratch/evals" --model_args "pretrained=gpt2"
 """
+
 import argparse
 
 from lighteval.main_accelerate import CACHE_DIR, main
@@ -70,7 +71,7 @@ def get_parser():
         "--tasks",
         type=str,
         default=None,
-        help="Id of a task, e.g. 'original|mmlu:abstract_algebra|5' or path to a texte file with a list of tasks",
+        help="Id of a task, e.g. 'original|mmlu:abstract_algebra|5|0' or path to a texte file with a list of tasks",
     )
     parser.add_argument("--num_fewshot_seeds", type=int, default=1, help="Number of trials the few shots")
     return parser

src/lighteval/models/endpoint_model.py

Lines changed: 18 additions & 17 deletions
@@ -92,8 +92,9 @@ def __init__(
                     "MAX_TOTAL_TOKENS": "2048",
                     "MODEL_ID": "/repository",
                     **config.get_dtype_args(),
+                    **config.get_custom_env_vars(),
                 },
-                "url": "ghcr.io/huggingface/text-generation-inference:1.1.0",
+                "url": (config.image_url or "ghcr.io/huggingface/text-generation-inference:1.1.0"),
             },
         )
         hlog("Deploying your endpoint. Please wait.")
@@ -149,7 +150,7 @@ def max_length(self):
             self._max_length = 2048
         return self._max_length
 
-    def __async_process_request(
+    def _async_process_request(
         self, context: str, stop_tokens: list[str], max_tokens: int
     ) -> Coroutine[None, list[TextGenerationOutput], str]:
         # Todo: add an option to launch with conversational instead for chat prompts
@@ -165,7 +166,7 @@ def __async_process_request(
 
         return generated_text
 
-    def __process_request(self, context: str, stop_tokens: list[str], max_tokens: int) -> TextGenerationOutput:
+    def _process_request(self, context: str, stop_tokens: list[str], max_tokens: int) -> TextGenerationOutput:
         # Todo: add an option to launch with conversational instead for chat prompts
         # https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational
         generated_text = self.client.text_generation(
@@ -179,13 +180,13 @@ def __process_request(self, context: str, stop_tokens: list[str], max_tokens: in
 
         return generated_text
 
-    async def __async_process_batch_generate(
+    async def _async_process_batch_generate(
         self,
         requests: list[GreedyUntilRequest],
     ) -> list[TextGenerationOutput]:
         return await asyncio.gather(
             *[
-                self.__async_process_request(
+                self._async_process_request(
                     context=request.context,
                     stop_tokens=as_list(request.stop_sequence),
                     max_tokens=request.generation_size,
@@ -194,25 +195,25 @@ async def __async_process_batch_generate(
             ]
         )
 
-    def __process_batch_generate(
+    def _process_batch_generate(
         self,
         requests: list[GreedyUntilRequest],
     ) -> list[TextGenerationOutput]:
         return [
-            self.__process_request(
+            self._process_request(
                 context=request.context,
                 stop_tokens=as_list(request.stop_sequence),
                 max_tokens=request.generation_size,
             )
             for request in requests
         ]
 
-    async def __async_process_batch_logprob(
+    async def _async_process_batch_logprob(
         self, requests: list[LoglikelihoodRequest], rolling: bool = False
     ) -> list[TextGenerationOutput]:
         return await asyncio.gather(
             *[
-                self.__async_process_request(
+                self._async_process_request(
                     context=request.context if rolling else request.context + request.choice,
                     stop_tokens=[],
                     max_tokens=1,
@@ -221,11 +222,11 @@ async def __async_process_batch_logprob(
             ]
         )
 
-    def __process_batch_logprob(
+    def _process_batch_logprob(
         self, requests: list[LoglikelihoodRequest], rolling: bool = False
     ) -> list[TextGenerationOutput]:
         return [
-            self.__process_request(
+            self._process_request(
                 context=request.context if rolling else request.context + request.choice,
                 stop_tokens=[],
                 max_tokens=1,
@@ -267,9 +268,9 @@ def greedy_until(
             )
 
             if self.use_async:
-                responses = asyncio.run(self.__async_process_batch_generate(batch))
+                responses = asyncio.run(self._async_process_batch_generate(batch))
             else:
-                responses = self.__process_batch_generate(batch)
+                responses = self._process_batch_generate(batch)
             for response in responses:
                 results.append(
                     GenerateReturn(
@@ -303,9 +304,9 @@ def loglikelihood(
 
         for batch in tqdm(dataloader, desc="Loglikelihoods", position=1, leave=False, disable=self.disable_tqdm):
             if self.use_async:
-                responses = asyncio.run(self.__async_process_batch_logprob(batch))
+                responses = asyncio.run(self._async_process_batch_logprob(batch))
             else:
-                responses = self.__process_batch_logprob(batch)
+                responses = self._process_batch_logprob(batch)
             for cur_request, response in zip(batch, responses):
                 cont_toks = torch.tensor(cur_request.tokenized_continuation)
                 len_choice = len(cont_toks)
@@ -351,9 +352,9 @@ def loglikelihood_rolling(
             dataloader, desc="Loglikelihoods, rolling", position=1, leave=False, disable=self.disable_tqdm
         ):
             if self.use_async:
-                responses = asyncio.run(self.__async_process_batch_logprob(batch, rolling=True))
+                responses = asyncio.run(self._async_process_batch_logprob(batch, rolling=True))
             else:
-                responses = self.__process_batch_logprob(batch, rolling=True)
+                responses = self._process_batch_logprob(batch, rolling=True)
             for response in responses:
                 logits = [t.logprob for t in response.details.tokens[:-1]]
 
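The renames above (double leading underscore to single underscore) correspond to the commit message's "Removed obfuscation from private functions, to allow inheritance to override". A small, generic illustration of why this matters, since Python name-mangles double-underscore attributes per class (this is standard Python behaviour, not code from this repository):

```python
class Parent:
    def __helper(self):        # mangled to _Parent__helper
        return "parent"

    def _helper_plain(self):   # single underscore: ordinary, overridable method
        return "parent"

    def run(self):
        # self.__helper is resolved to _Parent__helper at compile time,
        # so a subclass redefining __helper is never reached from here.
        return self.__helper(), self._helper_plain()


class Child(Parent):
    def __helper(self):        # becomes _Child__helper; does NOT override the parent
        return "child"

    def _helper_plain(self):   # overrides as expected
        return "child"


print(Child().run())  # ('parent', 'child')
```

With single underscores, the TGI model can inherit from the endpoint model and override only the client-call helpers, as the commit message describes.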
src/lighteval/models/model_config.py

Lines changed: 13 additions & 4 deletions
@@ -200,6 +200,7 @@ def init_configs(self, env_config: EnvConfig):
 class TGIModelConfig:
     inference_server_address: str
     inference_server_auth: str
+    model_id: str
 
 
 @dataclass
@@ -224,6 +225,8 @@ class InferenceEndpointModelConfig:
     add_special_tokens: bool = True
     revision: str = "main"
     namespace: str = None # The namespace under which to launch the endopint. Defaults to the current user's namespace
+    image_url: str = None
+    env_vars: dict = None
 
     def get_dtype_args(self) -> Dict[str, str]:
         model_dtype = self.model_dtype.lower()
@@ -237,14 +240,17 @@ def get_dtype_args(self) -> Dict[str, str]:
             return {"DTYPE": model_dtype}
         return {}
 
+    def get_custom_env_vars(self) -> Dict[str, str]:
+        return {k: str(v) for k, v in self.env_vars.items()} if self.env_vars else {}
+
     @staticmethod
     def nullable_keys() -> list[str]:
         """
         Returns the list of optional keys in an endpoint model configuration. By default, the code requires that all the
         keys be specified in the configuration in order to launch the endpoint. This function returns the list of keys
         that are not required and can remain None.
         """
-        return ["namespace"]
+        return ["namespace", "env_vars", "image_url"]
 
 
 def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]) -> BaseModelConfig: # noqa: C901
@@ -271,16 +277,17 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]
 
         return BaseModelConfig(**args_dict)
 
-    if args.model_config:
+    if hasattr(args, "model_config") and args.model_config:
         config = args.model_config["model"]
     else:
         with open(args.model_config_path, "r") as f:
             config = yaml.safe_load(f)["model"]
 
     if config["type"] == "tgi":
         return TGIModelConfig(
-            inference_server_address=args["instance"]["inference_server_address"],
-            inference_server_auth=args["instance"]["inference_server_auth"],
+            inference_server_address=config["instance"]["inference_server_address"],
+            inference_server_auth=config["instance"]["inference_server_auth"],
+            model_id=config["instance"]["model_id"],
         )
 
     if config["type"] == "endpoint":
@@ -303,6 +310,8 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]
             instance_size=config["instance"]["instance_size"],
             instance_type=config["instance"]["instance_type"],
             namespace=config["instance"]["namespace"],
+            image_url=config["instance"].get("image_url", None),
+            env_vars=config["instance"].get("env_vars", None),
         )
     return InferenceModelConfig(model=config["base_params"]["endpoint_name"])
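One behavioural note on the `hasattr` guard added above: the previous `if args.model_config:` raised `AttributeError` whenever the `Namespace` passed in did not define a `model_config` attribute at all (for instance when the calling script's parser never registers that option). A minimal sketch of the difference, using a hypothetical Namespace:

```python
from argparse import Namespace

# Hypothetical args: only a config path is present, no `model_config` attribute at all.
args = Namespace(model_config_path="examples/model_configs/tgi_model.yaml")

# Old check: `if args.model_config:` raises AttributeError here, because the attribute is missing.
# New check: falls through safely to reading the yaml file from disk.
if hasattr(args, "model_config") and args.model_config:
    config = args.model_config["model"]
else:
    print(f"would load {args.model_config_path} with yaml.safe_load")
```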

src/lighteval/models/model_loader.py

Lines changed: 4 additions & 2 deletions
@@ -88,10 +88,12 @@ def load_model_with_tgi(config: TGIModelConfig):
         raise ImportError(NO_TGI_ERROR_MSG)
 
     hlog(f"Load model from inference server: {config.inference_server_address}")
-    model = ModelClient(address=config.inference_server_address, auth_token=config.inference_server_auth)
+    model = ModelClient(
+        address=config.inference_server_address, auth_token=config.inference_server_auth, model_id=config.model_id
+    )
     model_name = str(model.model_info["model_id"])
     model_sha = model.model_info["model_sha"]
-    model_precision = model.model_info["dtype"]
+    model_precision = model.model_info["model_dtype"]
     model_size = -1
     model_info = ModelInfo(
         model_name=model_name,
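A hedged usage sketch of the updated loader path (not part of the diff): the server address and model id below are placeholders, and it assumes a TGI container is already serving a model.

```python
# Sketch under the assumption that a TGI container is already running and reachable.
from lighteval.models.model_config import TGIModelConfig
from lighteval.models.model_loader import load_model_with_tgi

config = TGIModelConfig(
    inference_server_address="http://localhost:8080",  # placeholder address
    inference_server_auth=None,
    model_id="meta-llama/Llama-2-7b-hf",  # only needed if the server was launched on a local directory
)

load_model_with_tgi(config)  # return value handling omitted; see model_loader.py above
```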
