diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
index 322be6503b9..fc666fd40a1 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
@@ -136,7 +136,8 @@ def __init__(
         self._num_pages = max(
             self.max_batch_size,
             (self.max_num_tokens) // self.page_size  # floored number of pages
-            + (self.max_num_tokens % self.page_size > 0) * self.max_batch_size,  # +1 per sequence
+            + (self.max_num_tokens / self.max_batch_size % self.page_size > 0)  # check for overflow
+            * self.max_batch_size,  # +1 page per sequence if overflow is required
         )
         # sanity check
         assert self.num_pages >= self.max_batch_size, "num_pages can't be less than max_batch_size"
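
Illustration (not part of the patch): a minimal standalone sketch of how the old and new page-count formulas can differ, using hypothetical values for max_num_tokens, max_batch_size, and page_size chosen so that the total token count is page-aligned while the per-sequence token count is not.

# Hypothetical example values, not taken from the repository's defaults.
max_num_tokens, max_batch_size, page_size = 8000, 16, 64

# Old formula: only adds headroom when the *total* token count
# does not divide evenly into pages.
old_num_pages = max(
    max_batch_size,
    max_num_tokens // page_size
    + (max_num_tokens % page_size > 0) * max_batch_size,
)

# New formula: adds one extra page per sequence when the
# *per-sequence* token count overflows a page boundary.
new_num_pages = max(
    max_batch_size,
    max_num_tokens // page_size
    + (max_num_tokens / max_batch_size % page_size > 0) * max_batch_size,
)

print(old_num_pages, new_num_pages)  # 125 141

Here 8000 tokens split across 16 sequences is 500 tokens per sequence, which does not fill a whole number of 64-token pages, so each sequence may need a partial page; the old check (8000 % 64 == 0) misses that case, while the new check reserves one extra page per sequence.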