6 changes: 3 additions & 3 deletions deepspeed/linear/optimized_linear.py
@@ -86,11 +86,11 @@ def __init__(self,
        self.zero_shards = self.lora_config.base_weight_sharding
        self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards)
        w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype))
        torch.nn.init.xavier_uniform_(w)

        if self.quantization_config is not None:
            assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization"
-           self.base_weight = QuantizedParameter(w, quantization_config=quantization_config)
+           self.base_weight = QuantizedParameter(w,
+                                                 requires_grad=False,  # quantized weights must be frozen
+                                                 quantization_config=quantization_config)
        else:
            self.base_weight = w

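For context, here is a minimal plain-PyTorch sketch of the behaviour this change enforces; it is not the DeepSpeed API, and the sizes and adapter names (lora_A, lora_B, lora_r) are illustrative assumptions. An integer-backed quantized weight cannot require gradients in autograd, so the base weight must stay frozen while only the LoRA adapters train:

import torch

# Illustrative sketch (plain PyTorch, not the DeepSpeed API): a frozen base
# weight standing in for QuantizedParameter(w, requires_grad=False, ...),
# alongside trainable LoRA adapters. Names and sizes are assumptions.
output_dim, input_dim, lora_r = 512, 1024, 16
dtype = torch.bfloat16

w = torch.nn.Parameter(torch.empty((output_dim, input_dim), dtype=dtype))
torch.nn.init.xavier_uniform_(w)

base_weight = w.detach()  # frozen: detach() returns a tensor with requires_grad=False

lora_A = torch.nn.Parameter(torch.randn((lora_r, input_dim), dtype=dtype) * 0.01)
lora_B = torch.nn.Parameter(torch.zeros((output_dim, lora_r), dtype=dtype))

x = torch.randn(4, input_dim, dtype=dtype)
y = x @ base_weight.T + (x @ lora_A.T) @ lora_B.T
y.sum().backward()

assert base_weight.grad is None                              # no gradient flows to the frozen base
assert lora_A.grad is not None and lora_B.grad is not None   # adapters stay trainable

Freezing the parameter at construction time, rather than relying on callers to exclude it from the optimizer, presumably keeps autograd from ever touching the quantized storage.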