diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 2f168888e787..0e2395b3138c 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -278,6 +278,9 @@ def __init__( else: raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype") + if kwargs: + logger.warning(f"Unused kwargs: {list(kwargs.keys())}. These kwargs are not used in {self.__class__}.") + self.post_init() @property @@ -286,6 +289,9 @@ def load_in_4bit(self): @load_in_4bit.setter def load_in_4bit(self, value: bool): + if not isinstance(value, bool): + raise ValueError("load_in_4bit must be a boolean") + if self.load_in_8bit and value: raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time") self._load_in_4bit = value @@ -296,6 +302,9 @@ def load_in_8bit(self): @load_in_8bit.setter def load_in_8bit(self, value: bool): + if not isinstance(value, bool): + raise ValueError("load_in_8bit must be a boolean") + if self.load_in_4bit and value: raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time") self._load_in_8bit = value @@ -304,6 +313,12 @@ def post_init(self): r""" Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. """ + if not isinstance(self.load_in_4bit, bool): + raise ValueError("load_in_4bit must be a boolean") + + if not isinstance(self.load_in_8bit, bool): + raise ValueError("load_in_8bit must be a boolean") + if not isinstance(self.llm_int8_threshold, float): raise ValueError("llm_int8_threshold must be a float")