diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py
index ef102f3abb43..1eaa622219f2 100644
--- a/python/sglang/srt/mem_cache/allocator.py
+++ b/python/sglang/srt/mem_cache/allocator.py
@@ -307,10 +307,23 @@ def free(self, free_index: torch.Tensor):
             self.free_swa(free_index)
         else:
             self.free_group.append(free_index)
+        # Elastic (kvcached-backed) allocators manage capacity dynamically,
+        # so the available_size() <= size invariant below does not apply to
+        # them; only enforce it for fixed-size allocators.
+        if not self._is_elastic(self.full_attn_allocator):
+            assert (
+                self.full_attn_allocator.available_size()
+                <= self.full_attn_allocator.size
+            )
+        if not self._is_elastic(self.swa_attn_allocator):
+            assert (
+                self.swa_attn_allocator.available_size()
+                <= self.swa_attn_allocator.size
+            )
-        assert (
-            self.full_attn_allocator.available_size() <= self.full_attn_allocator.size
-        )
-        assert self.swa_attn_allocator.available_size() <= self.swa_attn_allocator.size
+
+    def _is_elastic(self, alloc):
+        """Return True if *alloc* is backed by an elastic kvcached allocator."""
+        return hasattr(alloc, "kvcached_allocator")
 
     def free_swa(self, free_index: torch.Tensor):
         swa_indices = self.full_to_swa_index_mapping[free_index]