diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py
index 1b5f8b5704..cc7bd11922 100644
--- a/src/python/py/models/builder.py
+++ b/src/python/py/models/builder.py
@@ -2004,9 +2004,15 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
 
         # Unpack attention weights if needed
         self.make_attention_unpacked(layer_id, attention, root_input, **kwargs)
+
+        # Get dtype used for MatMul ops
+        q_dtype = getattr(attention.q_proj, "weight", getattr(attention.q_proj, "bits", None))
+        k_dtype = getattr(attention.k_proj, "weight", getattr(attention.k_proj, "bits", None))
+        v_dtype = getattr(attention.v_proj, "weight", getattr(attention.v_proj, "bits", None))
+        qkv_dtype_equal = getattr(q_dtype, "dtype", q_dtype) == getattr(k_dtype, "dtype", k_dtype) == getattr(v_dtype, "dtype", v_dtype)
 
         # Make MatMul nodes
-        if self.attention_attrs["use_packed_matmul"]:
+        if self.attention_attrs["use_packed_matmul"] and qkv_dtype_equal:
             # Combine 3 MatMuls into 1 packed MatMul
             qkv_matmul_basename = f"/model/layers.{layer_id}/attn/qkv_proj/MatMul"
             qkv_matmul_name = self.make_packed_matmul(attention.q_proj, attention.k_proj, attention.v_proj, qkv_matmul_basename, root_input)
@@ -2028,7 +2034,7 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
         v_bias_exists = attention.v_proj.bias is not None and torch.count_nonzero(attention.v_proj.bias) > 0
         any_bias_exists = q_bias_exists or k_bias_exists or v_bias_exists
 
-        if self.attention_attrs["use_packed_matmul"] and any_bias_exists:
+        if self.attention_attrs["use_packed_matmul"] and qkv_dtype_equal and any_bias_exists:
             # Combine 3 Adds into 1 packed Add
             qkv_add_name = f"/model/layers.{layer_id}/attn/qkv_proj/Add"
             self.make_packed_add(attention.q_proj.bias, attention.k_proj.bias, attention.v_proj.bias, qkv_add_name, root_input=self.attention_attrs["q_path"])
diff --git a/src/python/py/models/quantized_model.py b/src/python/py/models/quantized_model.py
index 535348b233..eb9f5585f8 100644
--- a/src/python/py/models/quantized_model.py
+++ b/src/python/py/models/quantized_model.py
@@ -863,6 +863,22 @@ def __init__(self, module):
             self.pack_qzeros(temp_module)
             module.qzeros = temp_module.qzeros
 
+    def _load_quant_config(self, quant_attrs):
+        super()._load_quant_config(quant_attrs)
+        self.overrides = quant_attrs["config"].get("dynamic", {})
+
+    def get_overrides(self, layer_name):
+        for pattern, overrides in self.overrides.items():
+            if re.match(pattern.removeprefix("+:"), layer_name):
+                return overrides
+        return {}
+
+    def get_layer_bits(self, layer_name):
+        return self.get_overrides(layer_name).get("bits", self.global_bits)
+
+    def get_layer_group_size(self, layer_name):
+        return self.get_overrides(layer_name).get("group_size", self.global_group_size)
+
 class QuarkModel(QuantizedModel):
     def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, intermediate_size, num_layers):
         super().__init__(quant_type, input_path, quant_attrs, q_size, kv_size, intermediate_size, num_layers)
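
For reference, below is an illustrative, standalone sketch (not part of the diff) of how the new dynamic-override lookup is expected to resolve per-layer settings. It assumes a GPTQ-style quantization config whose "dynamic" section maps regex patterns, optionally prefixed with "+:", to per-layer overrides; the config values and layer names here are hypothetical. When Q/K/V resolve to different bit widths under such a config, the new qkv_dtype_equal guard in builder.py keeps the three MatMuls separate instead of packing them.

# Illustrative sketch only: mirrors the override-resolution logic added to
# quantized_model.py. The "dynamic" entries and layer names are hypothetical.
import re

GLOBAL_BITS = 4
GLOBAL_GROUP_SIZE = 128

# Hypothetical "dynamic" section: regex pattern (optionally prefixed with "+:")
# mapped to per-layer overrides.
dynamic_overrides = {
    "+:.*v_proj": {"bits": 8, "group_size": 64},
}

def get_overrides(layer_name):
    # First matching pattern wins; no match means no per-layer override.
    for pattern, overrides in dynamic_overrides.items():
        if re.match(pattern.removeprefix("+:"), layer_name):
            return overrides
    return {}

def get_layer_bits(layer_name):
    return get_overrides(layer_name).get("bits", GLOBAL_BITS)

def get_layer_group_size(layer_name):
    return get_overrides(layer_name).get("group_size", GLOBAL_GROUP_SIZE)

print(get_layer_bits("model.layers.0.self_attn.v_proj"))        # 8 (override applies)
print(get_layer_bits("model.layers.0.self_attn.q_proj"))        # 4 (global default)
print(get_layer_group_size("model.layers.0.self_attn.v_proj"))  # 64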