2929from transformers import AutoConfig
3030
3131
32- def _get_dtype (dtype : Union [str , torch .dtype ], config : Optional [AutoConfig ] = None ) -> torch .dtype :
32+ def _get_dtype (dtype : Union [str , torch .dtype , None ], config : Optional [AutoConfig ] = None ) -> Optional [ torch .dtype ] :
3333 """
3434 Get the torch dtype based on the input arguments.
3535
@@ -41,17 +41,21 @@ def _get_dtype(dtype: Union[str, torch.dtype], config: Optional[AutoConfig] = No
4141 torch.dtype: The torch dtype based on the input arguments.
4242 """
4343
44- if config is not None : # For quantized models
45- if hasattr (config , "quantization_config" ):
46- _torch_dtype = None # must be inferred
47- else :
48- _torch_dtype = config .torch_dtype
49- elif isinstance (dtype , str ) and dtype not in ["auto" , "4bit" , "8bit" ]:
50- # Convert `str` args torch dtype: `float16` -> `torch.float16`
51- _torch_dtype = getattr (torch , dtype )
52- else :
53- _torch_dtype = dtype
54- return _torch_dtype
44+ if config is not None and hasattr (config , "quantization_config" ):
45+ # must be infered
46+ return None
47+
48+ if dtype is not None :
49+ if isinstance (dtype , str ) and dtype not in ["auto" , "4bit" , "8bit" ]:
50+ # Convert `str` args torch dtype: `float16` -> `torch.float16`
51+ return getattr (torch , dtype )
52+ elif isinstance (dtype , torch .dtype ):
53+ return dtype
54+
55+ if config is not None :
56+ return config .torch_dtype
57+
58+ return None
5559
5660
5761def _simplify_name (name_or_path : str ) -> str :
0 commit comments