Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,17 @@ def build_attn_metadata(
if num_tokens <= 0:
return {}, torch.empty((0,), dtype=torch.int64, device=self.device), {}

# Convert tensors to Python lists once to avoid per-element .item()
# calls (each .item() on a GPU tensor forces a GPU→CPU sync).
query_lens_list = query_lens_i32[:num_reqs].tolist()
seq_lens_list = seq_lens_i32[:num_reqs].tolist()
block_table_cpu = self._block_table[:num_reqs].cpu().tolist()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid full block-table host copy in hot decode path

In _LocalPredictorKVCache.build_attn_metadata, this line materializes the entire num_reqs × blocks_per_seq block table as Python lists on every call, even though the inner loop only needs one block_idx per generated token; during decode (decode_logits sets query_lens to ones), that turns per-step work from roughly O(num_reqs) into O(num_reqs * blocks_per_seq), where blocks_per_seq is derived from configured max_seq_len. For long-context configs this extra device→host copy plus Python-int conversion can dominate token-step latency and regress throughput versus the previous indexed access.

Useful? React with 👍 / 👎.


# positions: for each request i, emit positions [seq_len-query_len .. seq_len-1]
pos_list: list[torch.Tensor] = []
for i in range(num_reqs):
ql = int(query_lens_i32[i].item())
sl = int(seq_lens_i32[i].item())
ql = query_lens_list[i]
sl = seq_lens_list[i]
start = sl - ql
pos_list.append(torch.arange(start, sl, dtype=torch.int64))
positions_cpu = torch.cat(pos_list, dim=0)
Expand All @@ -134,13 +140,13 @@ def build_attn_metadata(
slot_mapping = torch.empty((num_tokens,), dtype=torch.int64, device="cpu")
cursor = 0
for i in range(num_reqs):
ql = int(query_lens_i32[i].item())
sl = int(seq_lens_i32[i].item())
ql = query_lens_list[i]
sl = seq_lens_list[i]
start = sl - ql
for p in range(start, sl):
block_idx = p // self.block_size
offset = p % self.block_size
block_id = int(self._block_table[i, block_idx].item())
block_id = block_table_cpu[i][block_idx]
slot_mapping[cursor] = block_id * self.block_size + offset
cursor += 1

Expand Down