Skip to content

Commit 470fa92

Browse files
authored
Use bfloat16 as default for vllm models. (#638)
Using `None` raises an error because vllm does not accept `None` as input for dtype
1 parent 2583a0a commit 470fa92

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/lighteval/models/vllm/vllm_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
GenerativeResponse,
3838
LoglikelihoodResponse,
3939
)
40-
from lighteval.models.utils import _get_dtype, _simplify_name
40+
from lighteval.models.utils import _simplify_name
4141
from lighteval.tasks.requests import (
4242
GreedyUntilRequest,
4343
LoglikelihoodRequest,
@@ -78,7 +78,7 @@ class VLLMModelConfig:
7878
pretrained: str
7979
gpu_memory_utilization: float = 0.9 # lower this if you are running out of memory
8080
revision: str = "main" # revision of the model
81-
dtype: str | None = None
81+
dtype: str = "bfloat16"
8282
tensor_parallel_size: int = 1 # how many GPUs to use for tensor parallelism
8383
pipeline_parallel_size: int = 1 # how many GPUs to use for pipeline parallelism
8484
data_parallel_size: int = 1 # how many GPUs to use for data parallelism
@@ -128,7 +128,7 @@ def __init__(
128128

129129
self.model_name = _simplify_name(config.pretrained)
130130
self.model_sha = "" # config.get_model_sha()
131-
self.precision = _get_dtype(config.dtype, config=self._config)
131+
self.precision = config.dtype
132132

133133
self.model_info = ModelInfo(model_name=self.model_name, model_sha=self.model_sha)
134134
self.pairwise_tokenization = config.pairwise_tokenization

0 commit comments

Comments
 (0)