Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,13 +943,14 @@ def _load_state_dict_into_meta_model(
old_param = model
splits = param_name.split(".")
for split in splits:
old_param = getattr(old_param, split)
# Not all the attributes of a module are Parameters/Tensor
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None
# We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys.
old_param = getattr(old_param, split, None)
Comment on lines +946 to +947
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😅 I knew it!
It was a bit too specific

if old_param is None:
break

if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None

if old_param is not None:
if dtype is None:
param = param.to(old_param.dtype)
Expand Down
20 changes: 20 additions & 0 deletions tests/quantization/torchao_integration/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,26 @@ def test_int4wo_offload(self):

self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)

def test_int8_dynamic_activation_int8_weight_quant(self):
"""
Simple LLM model testing int8_dynamic_activation_int8_weight
"""
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")

# Note: we quantize the bfloat16 model on the fly to int4
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=torch_device,
quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained(self.model_name)

input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)

output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)


if __name__ == "__main__":
unittest.main()