Skip to content
7 changes: 6 additions & 1 deletion python/sglang/srt/mem_cache/memory_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,12 @@ def get_size_per_token(self):
self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
self.layer_num = self.device_pool.layer_num

return (self.kv_lora_rank + self.qk_rope_head_dim) * 1 * self.dtype.itemsize
return (
(self.kv_lora_rank + self.qk_rope_head_dim)
* 1
* self.dtype.itemsize
* self.layer_num
)

def init_kv_buffer(self):
return torch.empty(
Expand Down
Loading