diff --git a/deepspeed/linear/optimized_linear.py b/deepspeed/linear/optimized_linear.py
index 155fe25d7083..861e1f357656 100644
--- a/deepspeed/linear/optimized_linear.py
+++ b/deepspeed/linear/optimized_linear.py
@@ -86,11 +86,11 @@ def __init__(self,
         self.zero_shards = self.lora_config.base_weight_sharding
         self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards)
         w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype))
-        torch.nn.init.xavier_uniform_(w)
-
         if self.quantization_config is not None:
             assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization"
-            self.base_weight = QuantizedParameter(w, quantization_config=quantization_config)
+            self.base_weight = QuantizedParameter(w,
+                                                  requires_grad=False,  # quantized weights must be frozen
+                                                  quantization_config=quantization_config)
         else:
             self.base_weight = w
 
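
For context, the effect of constructing the base weight with `requires_grad=False` can be illustrated with plain PyTorch: a frozen parameter receives no gradient during backprop and can be skipped by the optimizer, while trainable (e.g. LoRA-adapter-style) parameters still learn. This is only a sketch with made-up shapes and names; it uses `torch.nn.Parameter` directly rather than DeepSpeed's `QuantizedParameter`.

```python
import torch

# Illustrative stand-in for the change above: the (quantized) base weight is
# created frozen, so no gradient flows into it and the optimizer never touches
# it; a small trainable adapter-style parameter is still updated as usual.
base_weight = torch.nn.Parameter(torch.randn(8, 8, dtype=torch.bfloat16), requires_grad=False)
lora_weight = torch.nn.Parameter(torch.zeros(8, 8, dtype=torch.bfloat16))  # trainable

x = torch.randn(4, 8, dtype=torch.bfloat16)
out = x @ (base_weight + lora_weight).t()
out.sum().backward()

assert base_weight.grad is None       # frozen: no gradient was accumulated
assert lora_weight.grad is not None   # adapter still receives gradients

# Only parameters with requires_grad=True need to be handed to the optimizer.
trainable = [p for p in (base_weight, lora_weight) if p.requires_grad]
optimizer = torch.optim.AdamW(trainable)
```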