def num_required_blocks(self, obj: SchedulerSequence, prealloc_size: int = 0):
    """Count the additional blocks needed for ``obj``.

    While the sequence history is shorter than ``self.window_size`` the
    computation is delegated to the parent class. Otherwise the history
    contributes ``window_size`` tokens to the total, and the blocks already
    present in ``obj.logical_blocks`` are subtracted from the result.

    Args:
        obj: The sequence whose block demand is computed.
        prealloc_size: Extra tokens to reserve on top of the input tokens.

    Returns:
        The number of blocks still missing; never negative.
    """
    window_not_full = obj.num_history_ids < self.window_size
    if window_not_full:
        # History is still shorter than the window: use the parent's count.
        return super().num_required_blocks(obj, prealloc_size)

    # Only window_size tokens of history are kept, so the history term is
    # the window size rather than the real history length.
    incoming_tokens = obj.num_token_ids + prealloc_size
    total_blocks = _div_up(self.window_size + incoming_tokens, obj.block_size)
    missing_blocks = total_blocks - len(obj.logical_blocks)
    return missing_blocks if missing_blocks > 0 else 0