diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py index 4820aa2a71f..d78e16b9c5a 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py @@ -120,11 +120,17 @@ def build_attn_metadata( if num_tokens <= 0: return {}, torch.empty((0,), dtype=torch.int64, device=self.device), {} + # Convert tensors to Python lists once to avoid per-element .item() + # calls (each .item() on a GPU tensor forces a GPU→CPU sync). + query_lens_list = query_lens_i32[:num_reqs].tolist() + seq_lens_list = seq_lens_i32[:num_reqs].tolist() + block_table_cpu = self._block_table[:num_reqs].cpu().tolist() + # positions: for each request i, emit positions [seq_len-query_len .. seq_len-1] pos_list: list[torch.Tensor] = [] for i in range(num_reqs): - ql = int(query_lens_i32[i].item()) - sl = int(seq_lens_i32[i].item()) + ql = query_lens_list[i] + sl = seq_lens_list[i] start = sl - ql pos_list.append(torch.arange(start, sl, dtype=torch.int64)) positions_cpu = torch.cat(pos_list, dim=0) @@ -134,13 +140,13 @@ def build_attn_metadata( slot_mapping = torch.empty((num_tokens,), dtype=torch.int64, device="cpu") cursor = 0 for i in range(num_reqs): - ql = int(query_lens_i32[i].item()) - sl = int(seq_lens_i32[i].item()) + ql = query_lens_list[i] + sl = seq_lens_list[i] start = sl - ql for p in range(start, sl): block_idx = p // self.block_size offset = p % self.block_size - block_id = int(self._block_table[i, block_idx].item()) + block_id = block_table_cpu[i][block_idx] slot_mapping[cursor] = block_id * self.block_size + offset cursor += 1