2 changes: 1 addition & 1 deletion tests/entrypoints/openai/test_rerank.py
@@ -124,4 +124,4 @@ def test_invocations(server: RemoteOpenAIServer):
                                                  invocation_output["results"]):
         assert rerank_result.keys() == invocations_result.keys()
         assert rerank_result["relevance_score"] == pytest.approx(
-            invocations_result["relevance_score"], rel=0.01)
+            invocations_result["relevance_score"], rel=0.05)
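
For context on the loosened tolerance: pytest.approx(expected, rel=r) accepts an actual value within r * |expected| of the expected value, so rel=0.05 allows up to 5% relative deviation. A self-contained illustration with made-up scores:

import pytest

expected = 0.80
assert 0.78 == pytest.approx(expected, rel=0.05)        # 2.5% off -> accepted
assert not (0.70 == pytest.approx(expected, rel=0.05))  # 12.5% off -> rejected
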
14 changes: 3 additions & 11 deletions tests/models/language/pooling/test_embedding.py
@@ -39,17 +39,9 @@ def v1(run_with_both_engines):
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]),
# [Encoder-only]
pytest.param(
"BAAI/bge-base-en-v1.5",
marks=[
# CPU only supports V1
pytest.mark.core_model,
pytest.mark.skip_v1
]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2",
marks=[pytest.mark.skip_v1]),
pytest.param("intfloat/multilingual-e5-small",
marks=[pytest.mark.skip_v1]),
pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
pytest.param("intfloat/multilingual-e5-small"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
marks=[pytest.mark.skip_v1]),
# [Cross-Encoder]
Expand Down
8 changes: 8 additions & 0 deletions tests/models/language/pooling/test_jina.py
@@ -23,6 +23,14 @@
 ]
+
+
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
 
 
 @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
 def test_embed_models_mteb(hf_runner, vllm_runner,
                            model_info: EmbedModelInfo) -> None:
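
For readers unfamiliar with the wrapper: run_with_both_engines is defined in the test suite's conftest.py. A hedged sketch of how such a fixture can be built (the parameter names, env var handling, and marker logic below are assumptions for illustration, not the actual vLLM implementation):

import pytest

@pytest.fixture(params=[True, False], ids=["v1", "v0"])
def run_with_both_engines(request, monkeypatch):
    # Run the requesting test twice, once per engine version, by toggling
    # the environment variable vLLM consults when constructing an engine.
    use_v1 = request.param
    # Honor per-test opt-outs such as @pytest.mark.skip_v1 / skip_v0.
    if use_v1 and request.node.get_closest_marker("skip_v1"):
        pytest.skip("not supported on engine V1")
    if not use_v1 and request.node.get_closest_marker("skip_v0"):
        pytest.skip("not supported on engine V0")
    monkeypatch.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
    yield
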
3 changes: 2 additions & 1 deletion vllm/engine/arg_utils.py
@@ -1671,7 +1671,8 @@ def _set_default_args_v1(self, usage_context: UsageContext,

         if (self.max_num_seqs is None
                 and usage_context in default_max_num_seqs):
-            self.max_num_seqs = default_max_num_seqs[usage_context]
+            self.max_num_seqs = min(default_max_num_seqs[usage_context],
+                                    self.max_num_batched_tokens or sys.maxsize)
Collaborator commented:

    Why do we need sys.maxsize?

Contributor (author) replied:

    It's just because self.max_num_batched_tokens can be unset; in that case the min takes the value default_max_num_seqs[usage_context]. It's just to avoid writing an if.
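
A standalone sketch of the pattern under discussion (the function name and values here are hypothetical, not from the PR):

import sys

def cap_max_num_seqs(default_cap, max_num_batched_tokens):
    # When max_num_batched_tokens is None (unset), `or` substitutes
    # sys.maxsize, so min() simply returns default_cap.
    # Caveat: `or` would also treat 0 as unset.
    return min(default_cap, max_num_batched_tokens or sys.maxsize)

assert cap_max_num_seqs(1024, None) == 1024  # unset -> default wins
assert cap_max_num_seqs(1024, 512) == 512    # tighter token budget wins
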


logger.debug("Setting max_num_seqs to %d for %s usage context.",
self.max_num_seqs, use_context_value)
Expand Down
18 changes: 7 additions & 11 deletions vllm/model_executor/models/bert.py
@@ -12,7 +12,6 @@
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
@@ -60,7 +59,6 @@ def __init__(self, config: BertConfig):
     def forward(
         self,
         input_ids: torch.Tensor,
-        seq_lens: torch.Tensor,
         position_ids: torch.Tensor,
         token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -119,7 +117,6 @@ def forward(
         return pooled_output
 
 
-@support_torch_compile
 class BertEncoder(nn.Module):
 
     def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
@@ -337,6 +334,7 @@ def forward(self, hidden_states: torch.Tensor,
         return hidden_states
 
 
+@support_torch_compile
 class BertModel(nn.Module, SupportsQuant):
 
     is_pooling_model = True
@@ -368,13 +366,9 @@ def forward(
         if inputs_embeds is not None:
             hidden_states = inputs_embeds
         else:
-            attn_metadata = get_forward_context().attn_metadata
-            assert hasattr(attn_metadata, "seq_lens_tensor")
-            hidden_states = self.embeddings(
-                input_ids=input_ids,
-                seq_lens=attn_metadata.seq_lens_tensor,
-                position_ids=position_ids,
-                token_type_ids=token_type_ids)
+            hidden_states = self.embeddings(input_ids=input_ids,
+                                            position_ids=position_ids,
+                                            token_type_ids=token_type_ids)
         return self.encoder(hidden_states)

     def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -447,7 +441,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
+class BertEmbeddingModel(nn.Module, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
 
     This class encapsulates the BertModel and provides an interface for
@@ -474,11 +468,13 @@ def forward(
         self,
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return self.model(input_ids=input_ids,
                           position_ids=positions,
+                          token_type_ids=token_type_ids,
                           inputs_embeds=inputs_embeds,
                           intermediate_tensors=intermediate_tensors)

85 changes: 64 additions & 21 deletions vllm/model_executor/models/roberta.py
@@ -9,6 +9,7 @@
 from transformers import RobertaConfig
 
 from vllm.config import VllmConfig
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
                                                DispatchPooler, Pooler)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -51,33 +52,12 @@ def __init__(self, config: RobertaConfig):
     def forward(
         self,
         input_ids: torch.Tensor,
-        seq_lens: torch.Tensor,
         position_ids: torch.Tensor,
         token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)
 
-        # Replace position ids because in RoBERTa models
-        # they have to start at padding_idx + 1 and ignore
-        # existing padding tokens
-        # References:
-        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
-        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
-        seq_lens_list = seq_lens.tolist()
-        new_pos_list = []
-        for positions, tokens in zip(position_ids.split(seq_lens_list),
-                                     input_ids.split(seq_lens_list)):
-            # Verify assumption that incoming position are
-            # always a sequence from 0 to N.
-            expected_pos = torch.arange(positions.size()[0],
-                                        dtype=torch.long,
-                                        device=inputs_embeds.device)
-            assert torch.equal(positions, expected_pos)
-            new_pos_list.append(
-                create_position_ids_from_input_ids(tokens, self.padding_idx))
-        position_ids = torch.cat(new_pos_list)
-
         # Position embeddings.
         position_embeddings = self.position_embeddings(position_ids)
         if token_type_ids is None:
@@ -119,6 +99,32 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
     _pooler: An instance of Pooler used for pooling operations.
     """
 
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        # Fix RoBERTa positions here, outside of the CUDA graph.
+        # Because we need to extract the sequences from input_ids,
+        # the control flow is data dependent.
+        replace_roberta_positions(input_ids=input_ids,
+                                  position_ids=positions,
+                                  padding_idx=self.padding_idx)
+
+        return self.model(input_ids=input_ids,
+                          position_ids=positions,
+                          token_type_ids=token_type_ids,
+                          inputs_embeds=inputs_embeds,
+                          intermediate_tensors=intermediate_tensors)
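
A note on the comment above: splitting a flat token buffer by per-request sequence lengths is control flow driven by tensor values, which a captured CUDA graph cannot replay for a different batch. A minimal standalone illustration (the tensor values are made up):

import torch

# Two requests packed into one flat buffer, as vLLM does for pooling runs.
flat_input_ids = torch.tensor([0, 11, 12, 2, 0, 21, 2])
seq_lens = torch.tensor([4, 3])  # differs from batch to batch

# .tolist() syncs device -> host, and the resulting Python ints decide how
# many splits happen; this data-dependent step is why the position fix runs
# in eager mode, before the compiled/captured model region.
per_request = torch.split(flat_input_ids, seq_lens.tolist())
print([t.shape[0] for t in per_request])  # [4, 3]
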

     def _build_model(self,
                      vllm_config: VllmConfig,
                      prefix: str = "") -> Union[BertModel, BertWithRope]:
@@ -175,6 +181,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
+        self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
 
         self.num_labels = config.num_labels
         self.roberta = BertModel(vllm_config=vllm_config,
@@ -216,6 +223,9 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        replace_roberta_positions(input_ids=input_ids,
+                                  position_ids=positions,
+                                  padding_idx=self.padding_idx)
         return self.roberta(input_ids=input_ids,
                             position_ids=positions,
                             inputs_embeds=inputs_embeds,
@@ -245,3 +255,36 @@ def create_position_ids_from_input_ids(input_ids,
                                        past_key_values_length) * mask
 
     return incremental_indices.long() + padding_idx
+
+
+def replace_roberta_positions(input_ids: torch.Tensor,
+                              position_ids: torch.Tensor,
+                              padding_idx: int) -> None:
+
+    seq_lens: Optional[torch.Tensor] = None
+    attn_metadata = get_forward_context().attn_metadata
+    if attn_metadata is not None:  # can be None during warmup
+        if isinstance(attn_metadata, dict):
+            attn_metadata = next(iter(attn_metadata.values()))
+        # TODO: remove "seq_lens_tensor" after V0 is removed
+        seq_lens = getattr(attn_metadata, "seq_lens_tensor",
+                           getattr(attn_metadata, "seq_lens", None))
+
+    if seq_lens is not None:
+        assert isinstance(seq_lens, torch.Tensor)
+
+        # Replace position ids because in RoBERTa models
+        # they have to start at padding_idx + 1 and ignore
+        # existing padding tokens
+        # References:
+        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
+        # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
+        token_list = torch.split(input_ids[:torch.sum(seq_lens)],
+                                 seq_lens.tolist())
+
+        offset = 0
+        for tokens in token_list:
+            length = tokens.shape[0]
+            position_ids[offset:offset+length] = \
+                create_position_ids_from_input_ids(tokens, padding_idx)
+            offset = offset + length
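
For reference, here is what the position rewrite produces on a toy input. The helper below re-implements create_position_ids_from_input_ids following the HF reference linked above (cumsum over dim=0, since vLLM flattens the batch; treat the exact body as an assumption, as only its closing lines appear in this diff):

import torch

def create_position_ids_from_input_ids(input_ids, padding_idx,
                                       past_key_values_length=0):
    # Non-pad tokens get consecutive positions starting at padding_idx + 1;
    # pad tokens keep position padding_idx.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
                           past_key_values_length) * mask
    return incremental_indices.long() + padding_idx

# RoBERTa vocabulary: <s>=0, </s>=2, <pad>=1
tokens = torch.tensor([0, 42, 7, 2, 1, 1])
print(create_position_ids_from_input_ids(tokens, padding_idx=1))
# tensor([2, 3, 4, 5, 1, 1])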