diff --git a/tests/models/language/pooling/test_tokwise_pooler_batching.py b/tests/models/language/pooling/test_tokwise_pooler_batching.py
new file mode 100644
index 000000000000..055c1e5628eb
--- /dev/null
+++ b/tests/models/language/pooling/test_tokwise_pooler_batching.py
@@ -0,0 +1,213 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from types import SimpleNamespace
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+import vllm.model_executor.layers.pooler.tokwise.methods as tokwise_methods
+from vllm.model_executor.layers.pooler.tokwise.heads import (
+    TokenEmbeddingPoolerHead,
+)
+from vllm.model_executor.layers.pooler.tokwise.methods import AllPool
+from vllm.model_executor.layers.pooler.tokwise.poolers import TokenPooler
+from vllm.pooling_params import PoolingParams
+from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
+
+
+class CountingLinear(nn.Module):
+    def __init__(self, in_features: int, out_features: int):
+        super().__init__()
+        self.linear = nn.Linear(in_features, out_features, bias=False)
+        self.call_count = 0
+        self.input_shapes: list[tuple[int, ...]] = []
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        self.call_count += 1
+        self.input_shapes.append(tuple(x.shape))
+        return self.linear(x)
+
+
+def _patch_chunked_prefill(monkeypatch, enabled: bool) -> None:
+    monkeypatch.setattr(
+        tokwise_methods,
+        "get_current_vllm_config",
+        lambda: SimpleNamespace(
+            scheduler_config=SimpleNamespace(enable_chunked_prefill=enabled)
+        ),
+    )
+
+
+def _build_pooling_metadata(
+    *,
+    prompt_lens: list[int],
+    pooling_params: list[PoolingParams],
+    seq_lens: list[int] | None = None,
+    scheduled_lens: list[int] | None = None,
+) -> PoolingMetadata:
+    prompt_lens_cpu = torch.tensor(prompt_lens, dtype=torch.int64)
+    if seq_lens is None:
+        seq_lens = prompt_lens
+    if scheduled_lens is None:
+        scheduled_lens = seq_lens
+
+    seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
+    query_start_loc_cpu = torch.tensor(
+        [0, *np.cumsum(scheduled_lens, dtype=np.int64)],
+        dtype=torch.int64,
+    )
+
+    metadata = PoolingMetadata(
+        prompt_lens=prompt_lens_cpu,
+        prompt_token_ids=None,
+        prompt_token_ids_cpu=None,
+        pooling_params=pooling_params,
+        pooling_states=[PoolingStates() for _ in pooling_params],
+    )
+    metadata.build_pooling_cursor(
+        num_scheduled_tokens_np=np.asarray(scheduled_lens, dtype=np.int64),
+        seq_lens_cpu=seq_lens_cpu,
+        device=torch.device("cpu"),
+        query_start_loc_gpu=query_start_loc_cpu,
+    )
+    return metadata
+
+
+def test_token_embed_pooler_projects_flat_batch_once(monkeypatch):
+    _patch_chunked_prefill(monkeypatch, enabled=False)
+
+    hidden_size = 4
+    lengths = [2, 3, 1]
+    hidden_states = torch.randn(sum(lengths), hidden_size)
+    projector = CountingLinear(hidden_size, 5)
+
+    pooling_params = [
+        PoolingParams(task="token_embed", dimensions=5, use_activation=False),
+        PoolingParams(task="token_embed", dimensions=3, use_activation=True),
+        PoolingParams(task="token_embed", dimensions=4, use_activation=False),
+    ]
+    pooling_metadata = _build_pooling_metadata(
+        prompt_lens=lengths,
+        pooling_params=pooling_params,
+    )
+    pooler = TokenPooler(
+        pooling=AllPool(),
+        head=TokenEmbeddingPoolerHead(
+            projector=projector,
+            activation=torch.tanh,
+        ),
+    )
+
+    outputs = pooler(hidden_states, pooling_metadata)
+
+    assert projector.call_count == 1
+    assert projector.input_shapes == [(sum(lengths), hidden_size)]
+
+    expected_outputs = []
+    offset = 0
+    for length, pooling_param in zip(lengths, pooling_params):
+        chunk = hidden_states[offset : offset + length]
+        embeddings = projector.linear(chunk)
+        embeddings = embeddings[..., : pooling_param.dimensions]
+        if pooling_param.use_activation:
+            embeddings = torch.tanh(embeddings)
+        expected_outputs.append(embeddings)
+        offset += length
+
+    assert len(outputs) == len(expected_outputs)
+    for output, expected in zip(outputs, expected_outputs):
+        assert output is not None
+        torch.testing.assert_close(output, expected)
+
+
+@torch.inference_mode()
+def test_token_embed_pooler_projects_uniform_postprocess_once(monkeypatch):
+    _patch_chunked_prefill(monkeypatch, enabled=False)
+
+    hidden_size = 4
+    lengths = [2, 2]
+    hidden_states = torch.randn(sum(lengths), hidden_size)
+    projector = CountingLinear(hidden_size, 6)
+
+    pooling_params = [
+        PoolingParams(task="token_embed", dimensions=4, use_activation=True),
+        PoolingParams(task="token_embed", dimensions=4, use_activation=True),
+    ]
+    pooling_metadata = _build_pooling_metadata(
+        prompt_lens=lengths,
+        pooling_params=pooling_params,
+    )
+    pooler = TokenPooler(
+        pooling=AllPool(),
+        head=TokenEmbeddingPoolerHead(
+            projector=projector,
+            activation=torch.tanh,
+        ),
+    )
+
+    outputs = pooler(hidden_states, pooling_metadata)
+
+    assert projector.call_count == 1
+    assert projector.input_shapes == [(sum(lengths), hidden_size)]
+
+    projected = torch.tanh(projector.linear(hidden_states)[..., :4])
+    expected_outputs = [projected[:2], projected[2:]]
+    assert len(outputs) == len(expected_outputs)
+    for output, expected in zip(outputs, expected_outputs):
+        assert output is not None
+        torch.testing.assert_close(output, expected)
+
+
+@torch.inference_mode()
+def test_token_embed_pooler_batches_finished_chunked_outputs_once(monkeypatch):
+    _patch_chunked_prefill(monkeypatch, enabled=True)
+
+    hidden_size = 4
+    current_chunk_lens = [2, 2, 3]
+    prompt_lens = [4, 5, 3]
+    seq_lens = [4, 3, 3]
+    hidden_states = torch.randn(sum(current_chunk_lens), hidden_size)
+    projector = CountingLinear(hidden_size, 6)
+
+    pooling_params = [
+        PoolingParams(task="token_embed", dimensions=4, use_activation=True),
+        PoolingParams(task="token_embed", dimensions=3, use_activation=False),
+        PoolingParams(task="token_embed", dimensions=4, use_activation=True),
+    ]
+    pooling_metadata = _build_pooling_metadata(
+        prompt_lens=prompt_lens,
+        pooling_params=pooling_params,
+        seq_lens=seq_lens,
+        scheduled_lens=current_chunk_lens,
+    )
+
+    prev_req0 = torch.randn(2, hidden_size)
+    prev_req1 = torch.randn(1, hidden_size)
+    pooling_metadata.pooling_states[0].hidden_states_cache.append(prev_req0)
+    pooling_metadata.pooling_states[1].hidden_states_cache.append(prev_req1)
+
+    pooler = TokenPooler(
+        pooling=AllPool(),
+        head=TokenEmbeddingPoolerHead(
+            projector=projector,
+            activation=torch.tanh,
+        ),
+    )
+
+    outputs = pooler(hidden_states, pooling_metadata)
+
+    req0 = torch.concat([prev_req0, hidden_states[:2]], dim=0)
+    req2 = hidden_states[4:]
+
+    assert projector.call_count == 1
+    assert projector.input_shapes == [(req0.shape[0] + req2.shape[0], hidden_size)]
+
+    expected0 = torch.tanh(projector.linear(req0)[..., :4])
+    expected2 = torch.tanh(projector.linear(req2)[..., :4])
+
+    assert outputs[0] is not None
+    torch.testing.assert_close(outputs[0], expected0)
+    assert outputs[1] is None
+    assert outputs[2] is not None
+    torch.testing.assert_close(outputs[2], expected2)
diff --git a/vllm/model_executor/layers/pooler/tokwise/heads.py b/vllm/model_executor/layers/pooler/tokwise/heads.py
index 0377a86755ae..720e9c8ccbbf 100644
--- a/vllm/model_executor/layers/pooler/tokwise/heads.py
+++ b/vllm/model_executor/layers/pooler/tokwise/heads.py
@@ -12,7 +12,7 @@
 from vllm.tasks import PoolingTask
 from vllm.v1.pool.metadata import PoolingMetadata
 
-from .methods import TokenPoolingMethodOutputItem
+from .methods import RaggedTokenBatch, TokenPoolingMethodOutputItem
 
 TokenPoolerHeadOutputItem: TypeAlias = torch.Tensor | None
 
@@ -66,17 +66,56 @@ def forward_chunk(
         if pooled_data is None:
             return None
 
+        embeddings = self._project_batch(pooled_data)
+        return self._postprocess_embeddings(embeddings, pooling_param)
+
+    def forward_ragged(
+        self,
+        pooled_data: RaggedTokenBatch,
+        pooling_params: list[PoolingParams],
+    ) -> list[TokenPoolerHeadOutputItem]:
+        if pooled_data.num_items != len(pooling_params):
+            raise ValueError(
+                "pooled_data and pooling_params must have the same length: "
+                f"{pooled_data.num_items} != {len(pooling_params)}."
+            )
+
+        # Project all tokens in the batch with a single projector call.
+        embeddings = self._project_batch(pooled_data.values)
+        active_pooling_params = self._get_present_pooling_params(
+            pooled_data, pooling_params
+        )
+        if self._has_uniform_postprocess(active_pooling_params):
+            if active_pooling_params:
+                embeddings = self._postprocess_embeddings(
+                    embeddings, active_pooling_params[0]
+                )
+            return pooled_data.with_values(embeddings).split()
+
+        # The postprocess differs per request, so apply it separately.
+        pooled_outputs = pooled_data.with_values(embeddings).split()
+        return [
+            None
+            if output is None
+            else self._postprocess_embeddings(output, pooling_param)
+            for output, pooling_param in zip(pooled_outputs, pooling_params)
+        ]
+
+    def _project_batch(self, pooled_data: torch.Tensor) -> torch.Tensor:
         if self.head_dtype is not None:
             pooled_data = pooled_data.to(self.head_dtype)
         # pooled_data shape: [n_tokens, hidden_size]
 
         # Apply ST projector
         if self.projector is not None:
-            embeddings = self.projector(pooled_data)
-        else:
-            embeddings = pooled_data
-        # embeddings shape: [n_tokens, embedding_size]
+            return self.projector(pooled_data)
+        return pooled_data
 
+    def _postprocess_embeddings(
+        self,
+        embeddings: torch.Tensor,
+        pooling_param: PoolingParams,
+    ) -> torch.Tensor:
         # for matryoshka representation
         embeddings = embeddings[..., : pooling_param.dimensions]
 
@@ -87,6 +126,33 @@ def forward_chunk(
         # embeddings shape: [n_tokens, embedding_size]
         return embeddings
 
+    def _has_uniform_postprocess(self, pooling_params: list[PoolingParams]) -> bool:
+        """Return whether all pooling params share the same postprocess."""
+        if not pooling_params:
+            return True
+
+        first_param = pooling_params[0]
+        first_dimensions = first_param.dimensions
+        first_use_activation = bool(first_param.use_activation)
+        return all(
+            param.dimensions == first_dimensions
+            and bool(param.use_activation) == first_use_activation
+            for param in pooling_params[1:]
+        )
+
+    def _get_present_pooling_params(
+        self,
+        pooled_data: RaggedTokenBatch,
+        pooling_params: list[PoolingParams],
+    ) -> list[PoolingParams]:
+        if pooled_data.is_none_cpu is None:
+            return pooling_params
+        return [
+            pooling_param
+            for pooling_param, is_none in zip(pooling_params, pooled_data.is_none_cpu)
+            if not bool(is_none)
+        ]
+
 
 class TokenClassifierPoolerHead(TokenPoolerHead):
     def __init__(
diff --git a/vllm/model_executor/layers/pooler/tokwise/methods.py b/vllm/model_executor/layers/pooler/tokwise/methods.py
index 9ee6e8527c9a..f6a0e19d5df9 100644
--- a/vllm/model_executor/layers/pooler/tokwise/methods.py
+++ b/vllm/model_executor/layers/pooler/tokwise/methods.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import ABC, abstractmethod
 from collections.abc import Set
+from dataclasses import dataclass
 from typing import TypeAlias
 
 import torch
@@ -16,6 +17,84 @@
 TokenPoolingMethodOutputItem: TypeAlias = torch.Tensor | None
 
 
+@dataclass
+class RaggedTokenBatch:
+    values: torch.Tensor
+    cu_lengths_cpu: torch.Tensor
+    is_none_cpu: torch.Tensor | None = None
+
+    @classmethod
+    def from_lengths(
+        cls,
+        values: torch.Tensor,
+        lengths_cpu: torch.Tensor,
+        is_none_cpu: torch.Tensor | None = None,
+    ) -> "RaggedTokenBatch":
+        if is_none_cpu is not None:
+            assert is_none_cpu.shape == lengths_cpu.shape, (
+                "is_none_cpu must match lengths_cpu shape: "
+                f"{tuple(is_none_cpu.shape)} != {tuple(lengths_cpu.shape)}."
+            )
+        return cls(
+            values=values,
+            cu_lengths_cpu=_make_cu_lengths_cpu(lengths_cpu),
+            is_none_cpu=is_none_cpu,
+        )
+
+    @property
+    def num_items(self) -> int:
+        return self.cu_lengths_cpu.shape[0] - 1
+
+    def with_values(self, values: torch.Tensor) -> "RaggedTokenBatch":
+        expected_num_values = int(self.cu_lengths_cpu[-1])
+        if values.ndim == 0 or values.shape[0] != expected_num_values:
+            raise ValueError(
+                "values must preserve the flattened token dimension: "
+                f"{values.shape[0] if values.ndim > 0 else 0} "
+                f"!= {expected_num_values}."
+            )
+        return RaggedTokenBatch(
+            values=values,
+            cu_lengths_cpu=self.cu_lengths_cpu,
+            is_none_cpu=self.is_none_cpu,
+        )
+
+    def split(self) -> list[TokenPoolingMethodOutputItem]:
+        outputs = list[TokenPoolingMethodOutputItem]()
+        cu_lengths_cpu = self.cu_lengths_cpu
+        is_none_cpu = self.is_none_cpu
+
+        for i in range(self.num_items):
+            start = int(cu_lengths_cpu[i])
+            end = int(cu_lengths_cpu[i + 1])
+            if is_none_cpu is not None and bool(is_none_cpu[i]):
+                if start != end:
+                    raise ValueError(
+                        "Items materialized as None must have zero length: "
+                        f"{start} != {end}."
+                    )
+                outputs.append(None)
+                continue
+            outputs.append(self.values[start:end])
+
+        return outputs
+
+
+def _make_cu_lengths_cpu(lengths_cpu: torch.Tensor) -> torch.Tensor:
+    # [1, 2, 3, 4] -> [0, 1, 3, 6, 10]
+    lengths_cpu = lengths_cpu.to(device="cpu", dtype=torch.int64)
+    cu_lengths_cpu = torch.zeros(
+        lengths_cpu.shape[0] + 1, dtype=torch.int64, device="cpu"
+    )
+    torch.cumsum(lengths_cpu, dim=0, out=cu_lengths_cpu[1:])
+    return cu_lengths_cpu
+
+
+TokenPoolingMethodOutput: TypeAlias = (
+    RaggedTokenBatch | list[TokenPoolingMethodOutputItem]
+)
+
+
 class TokenPoolingMethod(nn.Module, ABC):
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"token_embed", "token_classify"}
@@ -28,7 +107,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
-    ) -> list[TokenPoolingMethodOutputItem]:
+    ) -> TokenPoolingMethodOutput:
         raise NotImplementedError
 
 
@@ -45,40 +124,58 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
-    ) -> list[TokenPoolingMethodOutputItem]:
+    ) -> TokenPoolingMethodOutput:
         pooling_cursor = pooling_metadata.get_pooling_cursor()
 
-        hidden_states_lst = [
-            hidden_states[first : last + 1]
-            for first, last in zip(
-                pooling_cursor.first_token_indices_gpu.tolist(),
-                pooling_cursor.last_token_indices_gpu.tolist(),
+        if self.enable_chunked_prefill:
+            hidden_states_lst = RaggedTokenBatch.from_lengths(
+                values=hidden_states,
+                lengths_cpu=pooling_cursor.num_scheduled_tokens_cpu,
+            ).split()
+
+            pooling_states = pooling_metadata.pooling_states
+
+            # If chunked_prefill is enabled
+            # 1. first store the chunked hidden_states in
+            #    pooling_states.hidden_states_cache
+            for p, hs_chunk in zip(pooling_states, hidden_states_lst):
+                p.hidden_states_cache.append(hs_chunk)
+
+            # 2. Once prefill is finished, flatten the finished requests into
+            #    a ragged batch, preserving unfinished slots as None.
+            lengths_cpu = torch.zeros(
+                len(pooling_states), dtype=torch.int64, device="cpu"
             )
-        ]
-
-        if not self.enable_chunked_prefill:
-            return hidden_states_lst
-
-        pooling_states = pooling_metadata.pooling_states
-
-        # If chunked_prefill is enabled
-        # 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
-        for p, hs_chunk in zip(pooling_states, hidden_states_lst):
-            p.hidden_states_cache.append(hs_chunk)
+            is_none_cpu = torch.ones(
+                len(pooling_states), dtype=torch.bool, device="cpu"
+            )
+            finished_values = list[torch.Tensor]()
+            for i, (p, finished) in enumerate(
+                zip(pooling_states, pooling_cursor.is_finished())
+            ):
+                if not finished:
+                    continue
 
-        # 2. Once prefill is finished, send hidden_states_cache to PoolerHead
-        output_list = list[TokenPoolingMethodOutputItem]()
-        for p, finished in zip(pooling_states, pooling_cursor.is_finished()):
-            if finished:
                 hidden_states_cache = p.hidden_states_cache
-                if len(hidden_states_cache) == 1:
-                    output_list.append(hidden_states_cache[0])
-                else:
-                    output_list.append(torch.concat(hidden_states_cache, dim=0))
+                lengths_cpu[i] = sum(chunk.shape[0] for chunk in hidden_states_cache)
+                is_none_cpu[i] = False
+                finished_values.extend(hidden_states_cache)
                 p.clean()
-            else:
-                output_list.append(None)
-        return output_list
+            values = (
+                torch.concat(finished_values, dim=0)
+                if finished_values
+                else hidden_states[:0]
+            )
+        else:
+            values = hidden_states
+            lengths_cpu = pooling_cursor.num_scheduled_tokens_cpu
+            is_none_cpu = None
+
+        return RaggedTokenBatch.from_lengths(
+            values=values,
+            lengths_cpu=lengths_cpu,
+            is_none_cpu=is_none_cpu,
+        )
 
 
 class StepPool(AllPool):
@@ -90,7 +187,12 @@ def forward(
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> list[TokenPoolingMethodOutputItem]:
-        pooled_data_lst = super().forward(hidden_states, pooling_metadata)
+        pooled_data = super().forward(hidden_states, pooling_metadata)
+        pooled_data_lst = (
+            pooled_data.split()
+            if isinstance(pooled_data, RaggedTokenBatch)
+            else pooled_data
+        )
 
         prompt_token_ids = pooling_metadata.get_prompt_token_ids()
         pooling_params = pooling_metadata.pooling_params
diff --git a/vllm/model_executor/layers/pooler/tokwise/poolers.py b/vllm/model_executor/layers/pooler/tokwise/poolers.py
index 6462a5056c55..eaff8ee801a0 100644
--- a/vllm/model_executor/layers/pooler/tokwise/poolers.py
+++ b/vllm/model_executor/layers/pooler/tokwise/poolers.py
@@ -28,14 +28,16 @@
     TokenPoolerHeadOutputItem,
 )
 from .methods import (
+    RaggedTokenBatch,
     TokenPoolingMethod,
+    TokenPoolingMethodOutput,
     TokenPoolingMethodOutputItem,
     get_tok_pooling_method,
 )
 
 TokenPoolingFn: TypeAlias = Callable[
     [torch.Tensor, PoolingMetadata],
-    list[TokenPoolingMethodOutputItem],
+    TokenPoolingMethodOutput,
 ]
 TokenPoolingHeadFn: TypeAlias = Callable[
     [list[TokenPoolingMethodOutputItem], PoolingMetadata],
@@ -89,6 +91,12 @@ def forward(
         pooling_metadata: PoolingMetadata,
     ) -> TokenPoolerOutput:
         pooled_data = self.pooling(hidden_states, pooling_metadata)
+        if isinstance(pooled_data, RaggedTokenBatch):
+            if isinstance(self.head, TokenEmbeddingPoolerHead):
+                return self.head.forward_ragged(
+                    pooled_data, pooling_metadata.pooling_params
+                )
+            pooled_data = pooled_data.split()
         if self.head is not None:
             pooled_data = self.head(pooled_data, pooling_metadata)
         return pooled_data
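
Note on the batching pattern these tests assert: `projector.call_count == 1` holds because `AllPool` now returns one flattened `RaggedTokenBatch` and `TokenEmbeddingPoolerHead.forward_ragged` runs the projector once over all tokens before splitting per request. A minimal torch-only sketch of that flatten-project-split idea (illustrative only; `project_ragged` is a hypothetical helper, not part of this diff):

import torch

def project_ragged(
    chunks: list[torch.Tensor], projector: torch.nn.Module
) -> list[torch.Tensor]:
    # Flatten per-request chunks into one [total_tokens, hidden_size] tensor.
    flat = torch.cat(chunks, dim=0)
    # One projector call for the whole batch instead of len(chunks) calls.
    projected = projector(flat)
    # Split back along per-request lengths, mirroring RaggedTokenBatch.split().
    lengths = [chunk.shape[0] for chunk in chunks]
    return list(projected.split(lengths, dim=0))

chunks = [torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)]
outputs = project_ragged(chunks, torch.nn.Linear(4, 5))
assert [tuple(o.shape) for o in outputs] == [(2, 5), (3, 5), (1, 5)]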
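
The chunked-prefill test's `outputs[1] is None` follows from the zero-length convention in `RaggedTokenBatch`: an unfinished request contributes no tokens to `values` and is flagged in `is_none_cpu`, so `split()` yields `None` for its slot. A small self-contained sketch of that bookkeeping, with made-up lengths:

import torch

lengths = torch.tensor([4, 0, 3])  # request 1 is unfinished -> zero length
is_none = torch.tensor([False, True, False])
cu = torch.zeros(len(lengths) + 1, dtype=torch.int64)
torch.cumsum(lengths, dim=0, out=cu[1:])  # cu == [0, 4, 4, 7]
values = torch.randn(int(cu[-1]), 8)  # only finished requests' tokens

outputs = [
    None if bool(is_none[i]) else values[int(cu[i]) : int(cu[i + 1])]
    for i in range(len(lengths))
]
assert outputs[1] is None
assert outputs[0].shape[0] == 4 and outputs[2].shape[0] == 3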