@@ -94,13 +94,17 @@ def update(
             _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
             _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)

-        # reshape for per channel scenario
-        num_heads = key_states.shape[1]
-        head_dim = key_states.shape[-1]
-        # from [batch_size, num_heads, seq_len - residual_length, head_dim]
-        # to [batch_size, seq_len - residual_length, num_heads * head_dim]
-        key_states = key_states.transpose(1, 2).flatten(2)
-        value_states = value_states.transpose(1, 2).flatten(2)
+        # the flatten below collapses key_states to 3-dim, so record the
+        # original rank here; the same flag gates the inverse reshape later
+        is_4d = key_states.dim() == 4
+        if is_4d:
+            # reshape for per channel scenario
+            num_heads = key_states.shape[1]
+            head_dim = key_states.shape[-1]
+            # from [batch_size, num_heads, seq_len - residual_length, head_dim]
+            # to [batch_size, seq_len - residual_length, num_heads * head_dim]
+            key_states = key_states.transpose(1, 2).flatten(2)
+            value_states = value_states.transpose(1, 2).flatten(2)

         q_key_states = self._quantize(
             key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
@@ -114,13 +118,16 @@ def update(
             q_value_states, KVCacheScaleType.VALUE, layer_idx
         )

-        # reshape for per channel scenario
-        # from [batch_size, seq_len - residual_length, num_heads * head_dim]
-        # to [batch_size, num_heads, seq_len - residual_length, head_dim]
-        qdq_key_states = qdq_key_states.view(
-            qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim).transpose(1, 2)
-        qdq_value_states = qdq_value_states.view(
-            qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim).transpose(1, 2)
+        # NOTE: checking key_states.dim() again here would always see the
+        # flattened 3-dim tensor, so reuse the flag captured before the reshape
+        if is_4d:
+            # reshape for per channel scenario
+            # from [batch_size, seq_len - residual_length, num_heads * head_dim]
+            # to [batch_size, num_heads, seq_len - residual_length, head_dim]
+            qdq_key_states = qdq_key_states.view(
+                qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim
+            ).transpose(1, 2).contiguous()
+            qdq_value_states = qdq_value_states.view(
+                qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim
+            ).transpose(1, 2).contiguous()

         keys_to_return, values_to_return = qdq_key_states, qdq_value_states

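For context, here is a minimal standalone sketch of the reshape round trip that the new `is_4d` guard makes conditional. It is illustrative only: the shapes are made up, and the `_quantize`/`_dequantize` step is elided. The point is that flattening moves the data into a layout whose last dimension is the per-channel axis, and that the inverse `view` + `transpose` restores the original attention layout exactly.

```python
import torch

# hypothetical shapes, for illustration only
batch_size, num_heads, seq_len, head_dim = 2, 8, 16, 64
key_states = torch.randn(batch_size, num_heads, seq_len, head_dim)

# [batch_size, num_heads, seq_len, head_dim]
#   -> [batch_size, seq_len, num_heads * head_dim]
# so the last dim becomes the channel axis calibrated per channel
flat = key_states.transpose(1, 2).flatten(2)
assert flat.shape == (batch_size, seq_len, num_heads * head_dim)

# ... fake-quantize `flat` along its last dimension here ...

# [batch_size, seq_len, num_heads * head_dim]
#   -> [batch_size, num_heads, seq_len, head_dim]
restored = flat.view(
    flat.shape[0], flat.shape[1], num_heads, head_dim
).transpose(1, 2).contiguous()
assert torch.equal(restored, key_states)

# a cache that already arrives 3-dim skips both reshapes, which is
# exactly the case the `is_4d` guard in the diff above allows
```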