Skip to content

Commit d2f6090

Browse files
Implement TGI model config from path (#448)
Implement TGI model config from path: ```python TGIModelConfig.from_path(model_config_path) ``` Follow-up to: - #434 Related to: - #439
1 parent 3058a48 commit d2f6090

File tree

5 files changed

+67
-8
lines changed

5 files changed

+67
-8
lines changed

src/lighteval/main_endpoint.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,6 @@ def tgi(
314314
"""
315315
Evaluate models using TGI as backend.
316316
"""
317-
import yaml
318317

319318
from lighteval.logging.evaluation_tracker import EvaluationTracker
320319
from lighteval.models.endpoints.tgi_model import TGIModelConfig
@@ -332,14 +331,8 @@ def tgi(
332331

333332
# TODO (nathan): better handling of model_args
334333
parallelism_manager = ParallelismManager.TGI
335-
with open(model_config_path, "r") as f:
336-
config = yaml.safe_load(f)["model"]
337334

338-
model_config = TGIModelConfig(
339-
inference_server_address=config["instance"]["inference_server_address"],
340-
inference_server_auth=config["instance"]["inference_server_auth"],
341-
model_id=config["instance"]["model_id"],
342-
)
335+
model_config = TGIModelConfig.from_path(model_config_path)
343336

344337
pipeline_params = PipelineParameters(
345338
launcher_type=parallelism_manager,

src/lighteval/models/endpoints/endpoint_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,14 @@ def __post_init__(self):
111111

112112
@classmethod
113113
def from_path(cls, path: str) -> "InferenceEndpointModelConfig":
114+
"""Load configuration for inference endpoint model from YAML file path.
115+
116+
Args:
117+
path (`str`): Path of the model configuration YAML file.
118+
119+
Returns:
120+
[`InferenceEndpointModelConfig`]: Configuration for inference endpoint model.
121+
"""
114122
import yaml
115123

116124
with open(path, "r") as f:

src/lighteval/models/endpoints/tgi_model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,22 @@ class TGIModelConfig:
5151
inference_server_auth: str
5252
model_id: str
5353

54+
@classmethod
55+
def from_path(cls, path: str) -> "TGIModelConfig":
56+
"""Load configuration for TGI endpoint model from YAML file path.
57+
58+
Args:
59+
path (`str`): Path of the model configuration YAML file.
60+
61+
Returns:
62+
[`TGIModelConfig`]: Configuration for TGI endpoint model.
63+
"""
64+
import yaml
65+
66+
with open(path, "r") as f:
67+
config = yaml.safe_load(f)["model"]
68+
return cls(**config["instance"])
69+
5470

5571
# inherit from InferenceEndpointModel instead of LightevalModel since they both use the same interface, and only overwrite
5672
# the client functions, since they use a different client.
File renamed without changes.
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# MIT License
2+
3+
# Copyright (c) 2024 The HuggingFace Team
4+
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
12+
# The above copyright notice and this permission notice shall be included in all
13+
# copies or substantial portions of the Software.
14+
15+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
# SOFTWARE.
22+
23+
from dataclasses import asdict
24+
25+
import pytest
26+
27+
from lighteval.models.endpoints.tgi_model import TGIModelConfig
28+
29+
30+
class TestTGIModelConfig:
31+
@pytest.mark.parametrize(
32+
"config_path, expected_config",
33+
[
34+
(
35+
"examples/model_configs/tgi_model.yaml",
36+
{"inference_server_address": "", "inference_server_auth": None, "model_id": None},
37+
),
38+
],
39+
)
40+
def test_from_path(self, config_path, expected_config):
41+
config = TGIModelConfig.from_path(config_path)
42+
assert asdict(config) == expected_config

0 commit comments

Comments
 (0)