Skip to content

Commit e9b5a0a

Browse files
authored
change priority when choosing model dtype (#263)
1 parent 5b01da2 commit e9b5a0a

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

src/lighteval/models/utils.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from transformers import AutoConfig
3030

3131

32-
def _get_dtype(dtype: Union[str, torch.dtype], config: Optional[AutoConfig] = None) -> torch.dtype:
32+
def _get_dtype(dtype: Union[str, torch.dtype, None], config: Optional[AutoConfig] = None) -> Optional[torch.dtype]:
3333
"""
3434
Get the torch dtype based on the input arguments.
3535
@@ -41,17 +41,21 @@ def _get_dtype(dtype: Union[str, torch.dtype], config: Optional[AutoConfig] = No
4141
torch.dtype: The torch dtype based on the input arguments.
4242
"""
4343

44-
if config is not None: # For quantized models
45-
if hasattr(config, "quantization_config"):
46-
_torch_dtype = None # must be inferred
47-
else:
48-
_torch_dtype = config.torch_dtype
49-
elif isinstance(dtype, str) and dtype not in ["auto", "4bit", "8bit"]:
50-
# Convert `str` args torch dtype: `float16` -> `torch.float16`
51-
_torch_dtype = getattr(torch, dtype)
52-
else:
53-
_torch_dtype = dtype
54-
return _torch_dtype
44+
if config is not None and hasattr(config, "quantization_config"):
45+
# must be infered
46+
return None
47+
48+
if dtype is not None:
49+
if isinstance(dtype, str) and dtype not in ["auto", "4bit", "8bit"]:
50+
# Convert `str` args torch dtype: `float16` -> `torch.float16`
51+
return getattr(torch, dtype)
52+
elif isinstance(dtype, torch.dtype):
53+
return dtype
54+
55+
if config is not None:
56+
return config.torch_dtype
57+
58+
return None
5559

5660

5761
def _simplify_name(name_or_path: str) -> str:

0 commit comments

Comments
 (0)