Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions src/llmcompressor/modifiers/quantization/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,12 @@ def update(
v_observer_name, quantization_args=self.quantization_args
)

self.k_observers.append(k_observer)
self.v_observers.append(v_observer)
# NOTE: User may ignore some layers in configuration,
# meaning len(self.k_observers) <= layer_idx-1
# Must account for that case by padding list so that
# index of lists corresponds to layer_idx
_pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
_pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)

q_key_states = self._quantize(
key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
Expand Down Expand Up @@ -151,12 +155,8 @@ def _quantize(self, tensor, kv_type, layer_idx):
zps = self.v_zps

scale, zp = observer(tensor)
if len(scales) <= layer_idx:
scales.append(scale)
zps.append(zp)
else:
scales[layer_idx] = scale
zps[layer_idx] = scale
_pad_and_append_at_idx_(scales, layer_idx, scale)
_pad_and_append_at_idx_(zps, layer_idx, zp)

q_tensor = quantize(
x=tensor,
Expand Down Expand Up @@ -185,3 +185,24 @@ def _dequantize(self, qtensor, kv_type, layer_idx):
args=self.quantization_args,
)
return qdq_tensor


# NOTE: The trailing _ in the name denotes that lst is modified in place
def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list:
"""
Append value val to list lst at index idx, right padding if necessary
Needed because user may ignore some layers in configuration, meaning
len(lst) <= idx-1

>>> _pad_and_append_at_idx_([0,1,2], 5, 5)
[0, 1, 2, None, None, 5]
>>> _pad_and_append_at_idx_([0,1,2], 3, 8)
[0, 1, 2, 8]
>>> _pad_and_append_at_idx_([0,1,2], 1, 5)
[0, 5, 2]
"""
num_to_pad = idx - len(lst) + 1
if num_to_pad > 0:
lst += [None] * num_to_pad
lst[idx] = val
return lst