diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py
index 6ea65c6944b0..c4a55c8370e0 100644
--- a/tests/v1/worker/test_gpu_input_batch.py
+++ b/tests/v1/worker/test_gpu_input_batch.py
@@ -378,3 +378,66 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
     ref_input_batch.refresh_metadata()
     _compare_objs(input_batch, ref_input_batch)
+
+
+def _construct_pooling_request(req_id_suffix: int):
+    """Build a minimal CachedRequestState for pooling (classify task)."""
+    from vllm.pooling_params import PoolingParams
+
+    prompt_token_ids = [
+        np.random.randint(0, VOCAB_SIZE)
+        for _ in range(np.random.randint(10, MAX_PROMPT_SIZE))
+    ]
+    return CachedRequestState(
+        req_id=f"pool_req_{req_id_suffix}",
+        prompt_token_ids=prompt_token_ids,
+        sampling_params=None,
+        pooling_params=PoolingParams(task="classify"),
+        mm_features=[],
+        block_ids=([],),
+        generator=None,
+        num_computed_tokens=0,
+        output_token_ids=[],
+    )
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_pooling_prompt_lens_not_aliased(device: str):
+    """Verify that prompt_lens in PoolingMetadata does not share memory
+    with the internal num_prompt_tokens pinned buffer. Guards against possible
+    non-determinism in pooling metadata due to mutations to the internal buffer.
+    """
+    batch_size = 4
+    input_batch = InputBatch(
+        max_num_reqs=batch_size * 2,
+        max_model_len=MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS,
+        max_num_batched_tokens=batch_size * (MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS),
+        device=torch.device(device),
+        pin_memory=is_pin_memory_available(),
+        vocab_size=VOCAB_SIZE,
+        block_sizes=[16],
+        kernel_block_sizes=[16],
+        is_pooling_model=True,
+    )
+
+    reqs = []
+    # Add requests
+    for i in range(batch_size):
+        req = _construct_pooling_request(i)
+        input_batch.add_request(req)
+        reqs.append(req)
+    input_batch.refresh_metadata()
+
+    # prompt_lens must be a snapshot
+    metadata = input_batch.get_pooling_metadata()
+    prompt_lens_snapshot = metadata.prompt_lens.clone()
+
+    # Mutate the internal buffer (simulates next batch adding new requests)
+    input_batch.num_prompt_tokens_cpu_tensor.fill_(999)
+
+    # prompt_lens must be unaffected by the mutation
+    assert torch.equal(metadata.prompt_lens, prompt_lens_snapshot), (
+        "prompt_lens shares memory with internal pinned buffer; "
+        "mutations to num_prompt_tokens_cpu_tensor corrupted prompt_lens. "
+        f"Expected {prompt_lens_snapshot}, got {metadata.prompt_lens}"
+    )
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 13941be88dfd..fb7795e04740 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -892,7 +892,7 @@ def get_pooling_metadata(self) -> PoolingMetadata:
         pooling_states = self.get_pooling_states()
 
         return PoolingMetadata(
-            prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
+            prompt_lens=self.num_prompt_tokens_cpu_tensor[: self.num_reqs].clone(),
             prompt_token_ids=self.sampling_metadata.prompt_token_ids,
             pooling_params=pooling_params,
             pooling_states=pooling_states,