diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py index 6ff20557336..c9c8bd0b1b5 100644 --- a/vllm_ascend/attention/context_parallel/mla_cp.py +++ b/vllm_ascend/attention/context_parallel/mla_cp.py @@ -69,6 +69,9 @@ def __init__( self.decode_threshold, dtype=torch.uint8, device=device) + self.block_size = (self.block_size * + self.cp_virtual_block_size) // np.gcd( + self.block_size, self.cp_virtual_block_size) def build( self,