Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions tests/v1/worker/test_gpu_input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,65 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
ref_input_batch.refresh_metadata()

_compare_objs(input_batch, ref_input_batch)


def _construct_pooling_request(req_id_suffix: int):
    """Build a minimal CachedRequestState for a pooling (classify) request.

    The prompt is a random-length sequence of random token ids; sampling
    params are absent because pooling requests do not sample.
    """
    from vllm.pooling_params import PoolingParams

    # Random prompt length in [10, MAX_PROMPT_SIZE), then one RNG draw per token.
    prompt_len = np.random.randint(10, MAX_PROMPT_SIZE)
    token_ids = [np.random.randint(0, VOCAB_SIZE) for _ in range(prompt_len)]

    return CachedRequestState(
        req_id=f"pool_req_{req_id_suffix}",
        prompt_token_ids=token_ids,
        sampling_params=None,
        pooling_params=PoolingParams(task="classify"),
        mm_features=[],
        block_ids=([],),
        generator=None,
        num_computed_tokens=0,
        output_token_ids=[],
    )


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_pooling_prompt_lens_not_aliased(device: str):
    """Verify that prompt_lens in PoolingMetadata does not share memory
    with the internal num_prompt_tokens pinned buffer. Guards against possible
    non-determinism in pooling metadata due to mutations to the internal buffer.
    """
    batch_size = 4
    input_batch = InputBatch(
        max_num_reqs=batch_size * 2,
        max_model_len=MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS,
        max_num_batched_tokens=batch_size * (MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS),
        device=torch.device(device),
        pin_memory=is_pin_memory_available(),
        vocab_size=VOCAB_SIZE,
        block_sizes=[16],
        kernel_block_sizes=[16],
        is_pooling_model=True,
    )

    # Populate the batch with pooling requests. (The request objects are not
    # needed afterwards, so they are not retained.)
    for i in range(batch_size):
        input_batch.add_request(_construct_pooling_request(i))
    input_batch.refresh_metadata()

    # prompt_lens must be a snapshot: clone it so we can detect aliasing.
    metadata = input_batch.get_pooling_metadata()
    prompt_lens_snapshot = metadata.prompt_lens.clone()

    # Mutate the internal buffer (simulates next batch adding new requests)
    input_batch.num_prompt_tokens_cpu_tensor.fill_(999)

    # prompt_lens must be unaffected by the mutation
    assert torch.equal(metadata.prompt_lens, prompt_lens_snapshot), (
        "prompt_lens shares memory with internal pinned buffer; "
        "mutations to num_prompt_tokens_cpu_tensor corrupted prompt_lens. "
        f"Expected {prompt_lens_snapshot}, got {metadata.prompt_lens}"
    )
2 changes: 1 addition & 1 deletion vllm/v1/worker/gpu_input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,7 @@ def get_pooling_metadata(self) -> PoolingMetadata:
pooling_states = self.get_pooling_states()

return PoolingMetadata(
prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
prompt_lens=self.num_prompt_tokens_cpu_tensor[: self.num_reqs].clone(),
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params,
pooling_states=pooling_states,
Expand Down
Loading