@@ -94,6 +94,14 @@ def update(
9494            _pad_and_append_at_idx_ (self .k_observers , layer_idx , k_observer )
9595            _pad_and_append_at_idx_ (self .v_observers , layer_idx , v_observer )
9696
97+         # reshape for the per-channel scenario
98+         num_heads  =  key_states .shape [1 ]
99+         head_dim  =  key_states .shape [- 1 ]
100+         # from [batch_size, num_heads, seq_len - residual_length, head_dim] 
101+         # to [batch_size, seq_len - residual_length, num_heads * head_dim] 
102+         key_states  =  key_states .transpose (1 , 2 ).flatten (2 )
103+         value_states  =  value_states .transpose (1 , 2 ).flatten (2 )
104+ 
97105        q_key_states  =  self ._quantize (
98106            key_states .contiguous (), KVCacheScaleType .KEY , layer_idx 
99107        )
@@ -106,6 +114,14 @@ def update(
106114            q_value_states , KVCacheScaleType .VALUE , layer_idx 
107115        )
108116
117+         # reshape back from the per-channel layout
118+         # from [batch_size, seq_len - residual_length, num_heads * head_dim] 
119+         # to [batch_size, num_heads, seq_len - residual_length, head_dim] 
120+         qdq_key_states  =  qdq_key_states .view (
121+             qdq_key_states .shape [0 ], qdq_key_states .shape [1 ], num_heads , head_dim ).transpose (1 , 2 )
122+         qdq_value_states  =  qdq_value_states .view (
123+             qdq_value_states .shape [0 ], qdq_value_states .shape [1 ], num_heads , head_dim ).transpose (1 , 2 )
124+ 
109125        keys_to_return , values_to_return  =  qdq_key_states , qdq_value_states 
110126
111127        return  keys_to_return , values_to_return 
@@ -155,8 +171,8 @@ def _quantize(self, tensor, kv_type, layer_idx):
155171            zps  =  self .v_zps 
156172
157173        scale , zp  =  observer (tensor )
158-         _pad_and_append_at_idx_ (scales , layer_idx , scale )
159-         _pad_and_append_at_idx_ (zps , layer_idx , zp )
174+         _pad_and_append_at_idx_ (scales , layer_idx , scale . squeeze () )
175+         _pad_and_append_at_idx_ (zps , layer_idx , zp . squeeze () )
160176
161177        q_tensor  =  quantize (
162178            x = tensor ,
0 commit comments