10 changes: 8 additions & 2 deletions src/python/py/models/builder.py
@@ -2004,9 +2004,15 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):

         # Unpack attention weights if needed
         self.make_attention_unpacked(layer_id, attention, root_input, **kwargs)
 
+        # Get dtype used for MatMul ops
+        q_dtype = getattr(attention.q_proj, "weight", getattr(attention.q_proj, "bits", None))
+        k_dtype = getattr(attention.k_proj, "weight", getattr(attention.k_proj, "bits", None))
+        v_dtype = getattr(attention.v_proj, "weight", getattr(attention.v_proj, "bits", None))
+        qkv_dtype_equal = getattr(q_dtype, "dtype", q_dtype) == getattr(k_dtype, "dtype", k_dtype) == getattr(v_dtype, "dtype", v_dtype)
+
         # Make MatMul nodes
-        if self.attention_attrs["use_packed_matmul"]:
+        if self.attention_attrs["use_packed_matmul"] and qkv_dtype_equal:
             # Combine 3 MatMuls into 1 packed MatMul
             qkv_matmul_basename = f"/model/layers.{layer_id}/attn/qkv_proj/MatMul"
             qkv_matmul_name = self.make_packed_matmul(attention.q_proj, attention.k_proj, attention.v_proj, qkv_matmul_basename, root_input)
@@ -2028,7 +2034,7 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
         v_bias_exists = attention.v_proj.bias is not None and torch.count_nonzero(attention.v_proj.bias) > 0
         any_bias_exists = q_bias_exists or k_bias_exists or v_bias_exists
 
-        if self.attention_attrs["use_packed_matmul"] and any_bias_exists:
+        if self.attention_attrs["use_packed_matmul"] and qkv_dtype_equal and any_bias_exists:
             # Combine 3 Adds into 1 packed Add
             qkv_add_name = f"/model/layers.{layer_id}/attn/qkv_proj/Add"
             self.make_packed_add(attention.q_proj.bias, attention.k_proj.bias, attention.v_proj.bias, qkv_add_name, root_input=self.attention_attrs["q_path"])
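For context on the change above: the new qkv_dtype_equal guard resolves each projection to its weight tensor's dtype when a weight exists, or to its quantization bit width (bits) otherwise, and the packed-MatMul path is taken only when all three agree. Below is a minimal sketch of that resolution logic, using hypothetical FloatProj/QuantProj stand-ins and a resolve_dtype helper in place of the builder's real module types:

import torch

class FloatProj:
    # Stand-in for a float projection module: carries a weight tensor
    def __init__(self, dtype):
        self.weight = torch.zeros(4, 4, dtype=dtype)

class QuantProj:
    # Stand-in for a quantized projection module: carries a bit width, no weight
    def __init__(self, bits):
        self.bits = bits

def resolve_dtype(proj):
    # Mirrors the diff: prefer the weight tensor, fall back to bits,
    # then compare tensor dtypes (or raw bit widths) for equality
    marker = getattr(proj, "weight", getattr(proj, "bits", None))
    return getattr(marker, "dtype", marker)

q_proj, k_proj = FloatProj(torch.float16), FloatProj(torch.float16)
v_proj = QuantProj(4)
qkv_dtype_equal = resolve_dtype(q_proj) == resolve_dtype(k_proj) == resolve_dtype(v_proj)
print(qkv_dtype_equal)  # False: mixed float16 / 4-bit projections, so Q/K/V stay as separate MatMuls

When all three projections share a dtype (or bit width), the guard stays true and the existing packed path is unchanged; the diff only disables packing for mixed-precision Q/K/V.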
16 changes: 16 additions & 0 deletions src/python/py/models/quantized_model.py
@@ -863,6 +863,22 @@ def __init__(self, module):
         self.pack_qzeros(temp_module)
         module.qzeros = temp_module.qzeros
 
+    def _load_quant_config(self, quant_attrs):
+        super()._load_quant_config(quant_attrs)
+        self.overrides = quant_attrs["config"].get("dynamic", {})
+
+    def get_overrides(self, layer_name):
+        for pattern, overrides in self.overrides.items():
+            if re.match(pattern.removeprefix("+:"), layer_name):
+                return overrides
+        return {}
+
+    def get_layer_bits(self, layer_name):
+        return self.get_overrides(layer_name).get("bits", self.global_bits)
+
+    def get_layer_group_size(self, layer_name):
+        return self.get_overrides(layer_name).get("group_size", self.global_group_size)
+
 class QuarkModel(QuantizedModel):
     def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, intermediate_size, num_layers):
         super().__init__(quant_type, input_path, quant_attrs, q_size, kv_size, intermediate_size, num_layers)
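For context on the new methods: get_overrides scans the "dynamic" section of the quantization config, and the first regex that matches a layer name supplies per-layer bits/group_size overrides; unmatched layers fall back to the global values. A minimal sketch of that lookup, where the "+:" prefix handling and first-match-wins order follow the diff, but the config contents, layer names, and global defaults are hypothetical:

import re

# Hypothetical "dynamic" overrides as they might appear in quant_attrs["config"]:
# layers 0-3 use 8 bits with group size 32; everything else uses the globals
dynamic = {
    r"+:model\.layers\.([0-3])\..*": {"bits": 8, "group_size": 32},
}
GLOBAL_BITS, GLOBAL_GROUP_SIZE = 4, 128

def get_overrides(layer_name):
    for pattern, overrides in dynamic.items():
        # "+:" marks a positive rule; strip it before matching, as in the diff
        if re.match(pattern.removeprefix("+:"), layer_name):
            return overrides
    return {}

def get_layer_bits(layer_name):
    return get_overrides(layer_name).get("bits", GLOBAL_BITS)

def get_layer_group_size(layer_name):
    return get_overrides(layer_name).get("group_size", GLOBAL_GROUP_SIZE)

print(get_layer_bits("model.layers.0.self_attn.q_proj"))        # 8   (override pattern matched)
print(get_layer_group_size("model.layers.12.self_attn.q_proj")) # 128 (no match, global value)

Since patterns are checked in dictionary order and the first match wins, more specific rules should be listed before broader ones when several could match the same layer.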