diff --git a/src/transformers/integrations/bitsandbytes.py b/src/transformers/integrations/bitsandbytes.py index e038768b97f6..f340c1db8237 100644 --- a/src/transformers/integrations/bitsandbytes.py +++ b/src/transformers/integrations/bitsandbytes.py @@ -156,7 +156,10 @@ def _replace_with_bnb_linear( if (isinstance(module, nn.Linear) or isinstance(module, Conv1D)) and name not in modules_to_not_convert: # Check if the current key is not in the `modules_to_not_convert` - if not any(key in ".".join(current_key_name) for key in modules_to_not_convert): + current_key_name_str = ".".join(current_key_name) + if not any( + (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert + ): with init_empty_weights(): if isinstance(module, Conv1D): in_features, out_features = module.weight.shape