Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/transformers/utils/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ def __init__(
else:
raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype")

if kwargs:
logger.warning(f"Unused kwargs: {list(kwargs.keys())}. These kwargs are not used in {self.__class__}.")

self.post_init()

@property
Expand All @@ -286,6 +289,9 @@ def load_in_4bit(self):

@load_in_4bit.setter
def load_in_4bit(self, value: bool):
if not isinstance(value, bool):
raise ValueError("load_in_4bit must be a boolean")

if self.load_in_8bit and value:
raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
self._load_in_4bit = value
Expand All @@ -296,6 +302,9 @@ def load_in_8bit(self):

@load_in_8bit.setter
def load_in_8bit(self, value: bool):
if not isinstance(value, bool):
raise ValueError("load_in_4bit must be a boolean")

if self.load_in_4bit and value:
raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
self._load_in_8bit = value
Expand All @@ -304,6 +313,12 @@ def post_init(self):
r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
"""
if not isinstance(self.load_in_4bit, bool):
raise ValueError("load_in_4bit must be a boolean")

if not isinstance(self.load_in_8bit, bool):
raise ValueError("load_in_8bit must be a boolean")

if not isinstance(self.llm_int8_threshold, float):
raise ValueError("llm_int8_threshold must be a float")

Expand Down