Commit 21934d5

add vllm backend (#274)
What this PR does:
- adds vLLM as a backend for faster inference.

How to use:

```
lighteval accelerate --model_args="pretrained=meta-llama/Meta-Llama-3.1-8B-Instruct,dtype=bfloat16,vllm,data_parallel_size=2" use_chat_template --tasks "leaderboard|arc:challenge|0|0,extended|ifeval|0|0,lighteval|gsm8k|5|1" output_dir="./evals/"
```

---------

Co-authored-by: Clémentine Fourrier <[email protected]>
1 parent 8c787df commit 21934d5

File tree

8 files changed (+436, -6 lines)

pyproject.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -87,6 +87,7 @@ nanotron = [
   "tensorboardX"
 ]
 tensorboardX = ["tensorboardX"]
+vllm = ["vllm", "ray", "more_itertools"]
 quality = ["ruff==v0.2.2","pre-commit"]
 tests = ["pytest==7.4.0"]
 dev = ["lighteval[accelerate,quality,tests]"]
```

src/lighteval/data.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -161,6 +161,20 @@ def __len__(self) -> int:
         """
         return self.split_end - self.split_start
 
+    def __iter__(self) -> Iterator[Request]:
+        """
+        Iterator that yields the items of the dataset depending on the split we
+        are currently in. For instance, if we are in split 0, we will get the
+        items from index 0 to self.split_size, if we are in split 1, we will get
+        the items from index self.split_size to 2 * self.split_size, etc. Used
+        for dynamic batching.
+
+        Yields:
+            Any: The items of the dataset.
+        """
+        for i in range(self.split_start, self.split_end):
+            yield self.sorted_data[i]
+
     def _sorting_criteria(self, request) -> int:
         raise NotImplementedError()
```
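To see what this `__iter__` enables, here is a minimal, self-contained sketch (the `SortedSplitDataset` class below is illustrative, not lighteval's actual class): iterating yields only the slice of `sorted_data` between `split_start` and `split_end`, so each split can be batched on its own.

```python
from typing import Any, Iterator, List


class SortedSplitDataset:
    """Toy stand-in for the dataset: requests sorted by length, windowed by split."""

    def __init__(self, sorted_data: List[Any], split_start: int, split_end: int):
        self.sorted_data = sorted_data    # e.g. requests sorted by context length
        self.split_start = split_start    # first index of the current split
        self.split_end = split_end        # one past the last index of the split

    def __len__(self) -> int:
        return self.split_end - self.split_start

    def __iter__(self) -> Iterator[Any]:
        # Same pattern as the added __iter__: yield only the current window,
        # so a DataLoader batches similarly sized requests together.
        for i in range(self.split_start, self.split_end):
            yield self.sorted_data[i]


# Second split of a 10-item dataset with split_size = 5:
split = SortedSplitDataset(sorted_data=list(range(10)), split_start=5, split_end=10)
assert list(split) == [5, 6, 7, 8, 9]
assert len(split) == 5
```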

src/lighteval/models/base_model.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -93,6 +93,7 @@ def __init__(
             hlog(f"Using Data Parallelism, putting model on device {self._device}")
             self.model = self.model.to(self._device)
         if config.compile:
+            hlog("Compiling the model")
             self.model.model.compile()
 
         self.model_name = _simplify_name(config.pretrained)
@@ -549,9 +550,9 @@ def greedy_until(
             tokenized = self.tokenizer(
                 context,
                 truncation="longest_first",  # we truncate to the model max length if needed
-                padding="longest",  # we pad to the longest sequence
+                padding="max_length",  # we pad to the longest sequence
                 return_tensors="pt",
-                max_length=self.max_length - 1,  # we always allow minimum one token of generation
+                max_length=max_context_continuation_size_allowed,  # we always allow minimum one token of generation
                 add_special_tokens=self.add_special_tokens,
             ).to(self.device)
 
@@ -573,7 +574,10 @@ def greedy_until(
             if max_new_tokens is None:  # If generation size is not set, we go all the way
                 max_new_tokens = self.max_length - context_size
             else:
+                print(self.max_length, context_size, max_new_tokens)
                 max_new_tokens = min(self.max_length - context_size, max_new_tokens)
+                if max_new_tokens < 1:
+                    max_new_tokens = 1
 
             prepared_batch = Batch(
                 input_ids=tokenized["input_ids"],
```
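The new `max_new_tokens` handling can be summarised by the small standalone function below (the function itself is illustrative, not part of the library): the generation length is capped by the space left in the model's window and, when a size was requested, floored at one token.

```python
from typing import Optional


def clamp_max_new_tokens(max_length: int, context_size: int, max_new_tokens: Optional[int]) -> int:
    """Mirror of the logic added above, written as a standalone helper."""
    if max_new_tokens is None:
        # No generation size set: use everything left after the prompt.
        return max_length - context_size
    # Requested size, capped by the remaining budget...
    max_new_tokens = min(max_length - context_size, max_new_tokens)
    # ...but always allow at least one generated token.
    if max_new_tokens < 1:
        max_new_tokens = 1
    return max_new_tokens


assert clamp_max_new_tokens(max_length=4096, context_size=4000, max_new_tokens=256) == 96
assert clamp_max_new_tokens(max_length=4096, context_size=4100, max_new_tokens=256) == 1
```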

src/lighteval/models/model_config.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -204,6 +204,25 @@ def init_configs(self, env_config: EnvConfig):
         return self._init_configs(self.base_model, env_config)
 
 
+@dataclass
+class VLLMModelConfig:
+    pretrained: str
+    gpu_memory_utilisation: float = 0.8
+    batch_size: int = -1
+    revision: str = "main"
+    dtype: str | None = None
+    tensor_parallel_size: int = 1
+    data_parallel_size: int = 1
+    max_model_length: int = 1024
+    swap_space: int = 4  # CPU swap space size (GiB) per GPU.
+    seed: int = 1234
+    trust_remote_code: bool = False
+    use_chat_template: bool = False
+    add_special_tokens: bool = True
+    multichoice_continuations_start_space: bool = True
+    subfolder: Optional[str] = None
+
+
 @dataclass
 class TGIModelConfig:
     inference_server_address: str
@@ -279,6 +298,7 @@ def create_model_config( # noqa: C901
     TGIModelConfig,
     InferenceEndpointModelConfig,
     DummyModelConfig,
+    VLLMModelConfig,
 ]:
     """
     Create a model configuration based on the provided arguments.
@@ -313,6 +333,9 @@ def create_model_config( # noqa: C901
     if model_args.pop("dummy", False):
         return DummyModelConfig(**model_args)
 
+    if model_args.pop("vllm", False):
+        return VLLMModelConfig(**model_args)
+
     model_args["accelerator"] = accelerator
     model_args["use_chat_template"] = use_chat_template
     model_args["compile"] = bool(model_args["compile"]) if "compile" in model_args else False
```

src/lighteval/models/model_loader.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -33,13 +33,15 @@
     BaseModelConfig,
     DeltaModelConfig,
     DummyModelConfig,
-    EnvConfig,
     InferenceEndpointModelConfig,
     InferenceModelConfig,
     TGIModelConfig,
+    VLLMModelConfig,
 )
 from lighteval.models.tgi_model import ModelClient
-from lighteval.utils.imports import NO_TGI_ERROR_MSG, is_tgi_available
+from lighteval.models.vllm_model import VLLMModel
+from lighteval.utils.imports import NO_TGI_ERROR_MSG, NO_VLLM_ERROR_MSG, is_tgi_available, is_vllm_available
+from lighteval.utils.utils import EnvConfig
 
 
 def load_model( # noqa: C901
@@ -50,6 +52,7 @@ def load_model( # noqa: C901
         TGIModelConfig,
         InferenceEndpointModelConfig,
         DummyModelConfig,
+        VLLMModelConfig,
     ],
     env_config: EnvConfig,
 ) -> Union[BaseModel, AdapterModel, DeltaModel, ModelClient, DummyModel]:
@@ -81,6 +84,9 @@ def load_model( # noqa: C901
     if isinstance(config, DummyModelConfig):
         return load_dummy_model(config=config, env_config=env_config)
 
+    if isinstance(config, VLLMModelConfig):
+        return load_model_with_accelerate_or_default(config=config, env_config=env_config)
+
 
 
 def load_model_with_tgi(config: TGIModelConfig):
@@ -106,6 +112,11 @@ def load_model_with_accelerate_or_default(
         model = AdapterModel(config=config, env_config=env_config)
     elif isinstance(config, DeltaModelConfig):
         model = DeltaModel(config=config, env_config=env_config)
+    elif isinstance(config, VLLMModelConfig):
+        if not is_vllm_available():
+            raise ImportError(NO_VLLM_ERROR_MSG)
+        model = VLLMModel(config=config, env_config=env_config)
+        return model
     else:
         model = BaseModel(config=config, env_config=env_config)
 
```
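Putting the pieces together, loading a model through the new path could look roughly like the sketch below. This is a sketch only: it assumes lighteval is installed with the `vllm` extra, and the `EnvConfig()` construction is a placeholder whose fields depend on your setup.

```python
from lighteval.models.model_config import VLLMModelConfig
from lighteval.models.model_loader import load_model
from lighteval.utils.utils import EnvConfig

# Config mirroring the CLI example from the commit message.
config = VLLMModelConfig(
    pretrained="meta-llama/Meta-Llama-3.1-8B-Instruct",
    dtype="bfloat16",
    data_parallel_size=2,
)

# load_model dispatches on the config type: for a VLLMModelConfig it checks
# is_vllm_available() and returns a VLLMModel, raising ImportError otherwise.
env_config = EnvConfig()  # placeholder construction; set its fields for your environment
model = load_model(config=config, env_config=env_config)
```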
