diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index cc33b3319712..e9efd7fbf6b1 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -136,7 +136,7 @@ def compute_num_computed_tokens(self) -> torch.Tensor:
         """Compute num_computed_tokens on device (seq_lens - query_lens)."""
         if self._num_computed_tokens_cache is None:
             query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
-            self._num_computed_tokens_cache = self.seq_lens - query_lens
+            self._num_computed_tokens_cache = self.seq_lens.cpu() - query_lens.cpu()
         return self._num_computed_tokens_cache
 
     # TODO(lucas): remove once we have FULL-CG spec-decode support