Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion vllm/model_executor/layers/pooler/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar

import torch

from vllm.pooling_params import PoolingParams

_T = TypeVar("_T", bound=torch.Tensor | list[torch.Tensor])

ProjectorFn = Callable[[torch.Tensor], torch.Tensor]
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
ActivationFn = Callable[[_T], _T]


@dataclass(frozen=True)
Expand All @@ -24,4 +29,4 @@ def apply(self, params: PoolingParams) -> None:
params.requires_token_ids = self.requires_token_ids


__all__ = ["ClassifierFn", "PoolingParamsUpdate"]
__all__ = ["ActivationFn", "ClassifierFn", "ProjectorFn", "PoolingParamsUpdate"]
88 changes: 41 additions & 47 deletions vllm/model_executor/layers/pooler/seqwise/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,7 @@
import torch
import torch.nn as nn

from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.pooler import ClassifierFn
from vllm.model_executor.layers.pooler.activations import (
PoolerActivation,
PoolerNormalize,
resolve_classifier_act_fn,
)
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata

Expand All @@ -38,17 +31,17 @@ def forward(


class EmbeddingPoolerHead(SequencePoolerHead):
def __init__(self) -> None:
def __init__(
self,
projector: ProjectorFn | None = None,
head_dtype: torch.dtype | str | None = None,
activation: ActivationFn | None = None,
) -> None:
super().__init__()

# Load ST projector if available
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config

self.projector = _load_st_projector(model_config)
self.head_dtype = model_config.head_dtype

self.activation = PoolerNormalize()
self.projector = projector
self.head_dtype = head_dtype
self.activation = activation

def get_supported_tasks(self) -> Set[PoolingTask]:
return {"embed"}
Expand All @@ -65,7 +58,8 @@ def forward(
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_dimension]

pooled_data = pooled_data.to(self.head_dtype)
if self.head_dtype is not None:
pooled_data = pooled_data.to(self.head_dtype)

# Apply ST projector
if self.projector is not None:
Expand All @@ -88,15 +82,16 @@ def forward(
]

# for normalize
flags = [p.normalize for p in pooling_params]
if len(set(flags)) == 1:
if flags[0]:
pooled_data = self.activation(pooled_data)
else:
pooled_data = [
self.activation(vecs) if f else vecs
for vecs, f in zip(pooled_data, flags)
]
if self.activation is not None:
flags = [p.normalize for p in pooling_params]
if len(set(flags)) == 1:
if flags[0]:
pooled_data = self.activation(pooled_data)
else:
pooled_data = [
self.activation(vecs) if f else vecs
for vecs, f in zip(pooled_data, flags)
]

# pooled_data shape: [batchsize, embedding_dimension]
return pooled_data
Expand All @@ -106,20 +101,16 @@ class ClassifierPoolerHead(SequencePoolerHead):
def __init__(
self,
classifier: ClassifierFn | None = None,
act_fn: PoolerActivation | str | None = None,
logit_bias: float | None = None,
head_dtype: torch.dtype | str | None = None,
activation: ActivationFn | None = None,
) -> None:
super().__init__()

vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config

self.classifier = classifier
self.logit_bias: float | None = model_config.pooler_config.logit_bias
self.head_dtype = model_config.head_dtype

self.act_fn = resolve_classifier_act_fn(
model_config, static_num_labels=True, act_fn=act_fn
)
self.logit_bias = logit_bias
self.head_dtype = head_dtype
self.activation = activation

def get_supported_tasks(self) -> Set[PoolingTask]:
return {"classify", "score"}
Expand All @@ -136,7 +127,8 @@ def forward(
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_size]

pooled_data = pooled_data.to(self.head_dtype)
if self.head_dtype is not None:
pooled_data = pooled_data.to(self.head_dtype)

if self.classifier is not None:
pooled_data = self.classifier(pooled_data)
Expand All @@ -145,13 +137,15 @@ def forward(
if self.logit_bias is not None:
pooled_data -= self.logit_bias

flags = [p.use_activation for p in pooling_params]
if len(set(flags)) == 1:
scores = self.act_fn(pooled_data) if flags[0] else pooled_data
else:
scores = [
self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)
]
if self.activation is not None:
flags = [p.use_activation for p in pooling_params]
if len(set(flags)) == 1:
pooled_data = self.activation(pooled_data) if flags[0] else pooled_data
else:
pooled_data = [
self.activation(vecs) if f else vecs
for vecs, f in zip(pooled_data, flags)
]

# scores shape: [batchsize, num_labels]
return scores
# pooled_data shape: [batchsize, num_labels]
return pooled_data
29 changes: 25 additions & 4 deletions vllm/model_executor/layers/pooler/seqwise/poolers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@

import torch

from vllm.config import PoolerConfig
from vllm.config import PoolerConfig, get_current_vllm_config
from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
from vllm.model_executor.layers.pooler.abstract import Pooler
from vllm.model_executor.layers.pooler.activations import PoolerActivation
from vllm.model_executor.layers.pooler.activations import (
PoolerActivation,
PoolerNormalize,
resolve_classifier_act_fn,
)
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.tasks import POOLING_TASKS, PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata

Expand Down Expand Up @@ -86,7 +91,14 @@ def forward(

def pooler_for_embed(pooler_config: PoolerConfig):
pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())
head = EmbeddingPoolerHead()

vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
head = EmbeddingPoolerHead(
projector=_load_st_projector(model_config),
head_dtype=model_config.head_dtype,
activation=PoolerNormalize(),
)

return SequencePooler(pooling=pooling, head=head)

Expand All @@ -101,6 +113,15 @@ def pooler_for_classify(
if pooling is None:
pooling = get_seq_pooling_method(pooler_config.get_seq_pooling_type())

head = ClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
head = ClassifierPoolerHead(
classifier=classifier,
logit_bias=model_config.pooler_config.logit_bias,
head_dtype=model_config.head_dtype,
activation=resolve_classifier_act_fn(
model_config, static_num_labels=True, act_fn=act_fn
),
)

return SequencePooler(pooling=pooling, head=head)
55 changes: 23 additions & 32 deletions vllm/model_executor/layers/pooler/tokwise/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,7 @@
import torch
import torch.nn as nn

from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.pooler import ClassifierFn
from vllm.model_executor.layers.pooler.activations import (
PoolerActivation,
PoolerNormalize,
resolve_classifier_act_fn,
)
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.model_executor.layers.pooler import ActivationFn, ClassifierFn, ProjectorFn
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
Expand Down Expand Up @@ -49,17 +42,17 @@ def forward(


class TokenEmbeddingPoolerHead(TokenPoolerHead):
def __init__(self) -> None:
def __init__(
self,
projector: ProjectorFn | None = None,
head_dtype: torch.dtype | str | None = None,
activation: ActivationFn | None = None,
) -> None:
super().__init__()

# Load ST projector if available
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config

self.projector = _load_st_projector(model_config)
self.head_dtype = model_config.head_dtype

self.activation = PoolerNormalize()
self.projector = projector
self.head_dtype = head_dtype
self.activation = activation

def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed"}
Expand All @@ -73,7 +66,8 @@ def forward_chunk(
if pooled_data is None:
return None

pooled_data = pooled_data.to(self.head_dtype)
if self.head_dtype is not None:
pooled_data = pooled_data.to(self.head_dtype)
# pooled_data shape: [n_tokens, hidden_dimension]

# Apply ST projector
Expand All @@ -85,7 +79,7 @@ def forward_chunk(
pooled_data = pooled_data[..., : pooling_param.dimensions]

# for normalize
if pooling_param.normalize:
if self.activation is not None and pooling_param.normalize:
pooled_data = self.activation(pooled_data)

# pooled_data shape: [n_tokens, embedding_dimension]
Expand All @@ -96,20 +90,16 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
def __init__(
self,
classifier: ClassifierFn | None = None,
act_fn: PoolerActivation | str | None = None,
logit_bias: float | None = None,
head_dtype: torch.dtype | str | None = None,
activation: ActivationFn | None = None,
) -> None:
super().__init__()

vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config

self.classifier = classifier
self.logit_bias: float | None = model_config.pooler_config.logit_bias
self.head_dtype = model_config.head_dtype

self.act_fn = resolve_classifier_act_fn(
model_config, static_num_labels=False, act_fn=act_fn
)
self.logit_bias = logit_bias
self.head_dtype = head_dtype
self.activation = activation

def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_classify"}
Expand All @@ -123,7 +113,8 @@ def forward_chunk(
if pooled_data is None:
return None

pooled_data = pooled_data.to(self.head_dtype)
if self.head_dtype is not None:
pooled_data = pooled_data.to(self.head_dtype)
# hidden_states shape: [n_token, hidden_size]

if self.classifier is not None:
Expand All @@ -135,8 +126,8 @@ def forward_chunk(
if self.logit_bias is not None:
scores -= self.logit_bias

if pooling_param.use_activation:
scores = self.act_fn(scores)
if self.activation is not None and pooling_param.use_activation:
scores = self.activation(scores)

# scores shape: [n_token, num_labels]
return scores
29 changes: 25 additions & 4 deletions vllm/model_executor/layers/pooler/tokwise/poolers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@

import torch

from vllm.config import PoolerConfig
from vllm.config import PoolerConfig, get_current_vllm_config
from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
from vllm.model_executor.layers.pooler.abstract import Pooler
from vllm.model_executor.layers.pooler.activations import PoolerActivation
from vllm.model_executor.layers.pooler.activations import (
PoolerActivation,
PoolerNormalize,
resolve_classifier_act_fn,
)
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.tasks import POOLING_TASKS, PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata

Expand Down Expand Up @@ -86,7 +91,14 @@ def forward(

def pooler_for_token_embed(pooler_config: PoolerConfig):
pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
head = TokenEmbeddingPoolerHead()

vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
head = TokenEmbeddingPoolerHead(
projector=_load_st_projector(model_config),
head_dtype=model_config.head_dtype,
activation=PoolerNormalize(),
)

return TokenPooler(pooling=pooling, head=head)

Expand All @@ -101,6 +113,15 @@ def pooler_for_token_classify(
if pooling is None:
pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())

head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
head = TokenClassifierPoolerHead(
classifier=classifier,
logit_bias=model_config.pooler_config.logit_bias,
head_dtype=model_config.head_dtype,
activation=resolve_classifier_act_fn(
model_config, static_num_labels=False, act_fn=act_fn
),
)

return TokenPooler(pooling=pooling, head=head)
Loading