Commit 57bee27

support attn_head
Signed-off-by: Kyle Sayers <[email protected]>
1 parent: 241c310

File tree

1 file changed (+10 −0 lines)


src/llmcompressor/observers/helpers.py

Lines changed: 10 additions & 0 deletions
@@ -85,6 +85,9 @@ def _flatten_weight(
         .unsqueeze(0)
     )

+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("Attention head quantization cannot be applied to weights")
+
     assert False, f"Unknown strategy {args.strategy}"

@@ -111,6 +114,9 @@ def _flatten_activation(value: torch.Tensor, args: QuantizationArgs):
     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to activations")

+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("Attention head quantization cannot be applied to activations")
+
     assert False, f"Unknown strategy {args.strategy}"

@@ -133,4 +139,8 @@ def _flatten_attention(value: torch.Tensor, args: QuantizationArgs):
     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to attention")

+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        # (batch_size * seq_len, num_heads, 1, 1, head_dim)
+        return value.transpose(1, 2).flatten(0, 1).unsqueeze(-2).unsqueeze(-2)
+
     assert False, f"Unknown strategy {args.strategy}"
