From 8284a6bd00ba31428aae200a8f5784bfce3b5acc Mon Sep 17 00:00:00 2001
From: DavidBao003 <1353406961@qq.com>
Date: Fri, 25 Apr 2025 10:35:30 +0800
Subject: [PATCH] [feature] support for roberta embedding models

---
 python/sglang/srt/layers/pooler.py            |   6 +
 python/sglang/srt/models/roberta.py           | 178 ++++++++++++++++++
 .../models/test_encoder_embedding_models.py   |   4 +-
 3 files changed, 186 insertions(+), 2 deletions(-)
 create mode 100644 python/sglang/srt/models/roberta.py

diff --git a/python/sglang/srt/layers/pooler.py b/python/sglang/srt/layers/pooler.py
index 751f09fdd36..7ee8dbcc202 100644
--- a/python/sglang/srt/layers/pooler.py
+++ b/python/sglang/srt/layers/pooler.py
@@ -12,6 +12,7 @@ class PoolingType(IntEnum):
     LAST = 0
+    CLS = 1
 
 
 @dataclass
@@ -41,6 +42,11 @@ def forward(
         if self.pooling_type == PoolingType.LAST:
             last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
             pooled_data = hidden_states[last_token_indices]
+        elif self.pooling_type == PoolingType.CLS:
+            prompt_lens = forward_batch.extend_seq_lens
+            first_token_flat_indices = torch.zeros_like(prompt_lens)
+            first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
+            pooled_data = hidden_states[first_token_flat_indices]
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
diff --git a/python/sglang/srt/models/roberta.py b/python/sglang/srt/models/roberta.py
new file mode 100644
index 00000000000..d9e8c2c7ae1
--- /dev/null
+++ b/python/sglang/srt/models/roberta.py
@@ -0,0 +1,178 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import itertools
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.bert import BertEncoder
+
+RobertaConfig = None
+
+
+class RobertaEmbedding(nn.Module):
+
+    def __init__(self, config: RobertaConfig):
+        super().__init__()
+        self.size = config.hidden_size
+        self.word_embeddings = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size
+        )
+        self.padding_idx = config.pad_token_id
+        self.position_embeddings = nn.Embedding(
+            config.max_position_embeddings,
+            config.hidden_size,
+            padding_idx=self.padding_idx,
+        )
+
+        self.token_type_embeddings = nn.Embedding(
+            config.type_vocab_size, config.hidden_size
+        )
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        self.position_ids = nn.Parameter(
+            torch.empty((1, config.max_position_embeddings)),
+        )
+
+        self.position_embedding_type = config.position_embedding_type
+        if self.position_embedding_type != "absolute":
+            raise ValueError("Only 'absolute' position_embedding_type is supported")
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        seq_lens: torch.Tensor,
+        position_ids: torch.Tensor,
+        inputs_embeds=None,
+        token_type_ids: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        input_shape = input_ids.size()
+        inputs_embeds = self.word_embeddings(input_ids)
+
+        # Adapted from vllm: https://github.com/vllm-project/vllm/commit/4a18fd14ba4a349291c798a16bf62fa8a9af0b6b/vllm/model_executor/models/roberta.py
+
+        pos_list = []
+        token_list = []
+        offset = 0
+        for seq_len in seq_lens:
+            pos_list.append(position_ids[offset : offset + seq_len])
+            token_list.append(input_ids[offset : offset + seq_len])
+            offset += seq_len
+
+        new_pos_list = []
+        for positions, tokens in zip(pos_list, token_list):
+            # Verify assumption that incoming positions are
+            # always a sequence from 0 to N.
+            expected_pos = torch.arange(
+                positions.size()[0], dtype=torch.long, device=inputs_embeds.device
+            )
+            assert torch.equal(positions, expected_pos)
+            new_pos_list.append(
+                create_position_ids_from_input_ids(tokens, self.padding_idx)
+            )
+        position_ids = torch.cat(new_pos_list)
+
+        # Position embeddings.
+        position_embeddings = self.position_embeddings(position_ids)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(
+                input_shape, dtype=torch.long, device=inputs_embeds.device
+            )
+
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        return embeddings
+
+
+class XLMRobertaModel(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: RobertaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.config = config
+        self.embeddings = RobertaEmbedding(config)
+        self.encoder = BertEncoder(config=config, quant_config=quant_config, prefix="")
+        self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        assert get_embedding, "XLMRobertaModel is only used for embedding"
+
+        hidden_states = self.embeddings(
+            input_ids=input_ids,
+            position_ids=positions,
+            seq_lens=forward_batch.seq_lens,
+        )
+
+        hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
+        pooler_out = self.pooler(hidden_states, forward_batch)
+        return pooler_out
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "query", "q"),
+            ("qkv_proj", "key", "k"),
+            ("qkv_proj", "value", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            name = name.replace("self", "self_attn")
+            if "pooler" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+# Adapted from transformers
+def create_position_ids_from_input_ids(
+    input_ids, padding_idx, past_key_values_length=0
+):
+    mask = input_ids.ne(padding_idx).int()
+    incremental_indices = (
+        torch.cumsum(mask, dim=0).type_as(mask) + past_key_values_length
+    ) * mask
+    return incremental_indices.long() + padding_idx
+
+
+EntryClass = [XLMRobertaModel]
diff --git a/test/srt/models/test_encoder_embedding_models.py b/test/srt/models/test_encoder_embedding_models.py
index 4dad0be1513..5202917c4b1 100644
--- a/test/srt/models/test_encoder_embedding_models.py
+++ b/test/srt/models/test_encoder_embedding_models.py
@@ -25,10 +25,10 @@
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci
 
-MODELS = [("BAAI/bge-small-en", 1, 1e-5)]
+MODELS = [("BAAI/bge-small-en", 1, 1e-5), ("BAAI/bge-m3", 1, 1e-5)]
 ATTENTION_BACKEND = ["torch_native", "triton"]
 
-BATCH_SIZE = [30]
+BATCH_SIZE = [1, 2]
 TORCH_DTYPES = [torch.float32]
 
 sgl_to_st_ratio = []
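
Reviewer note (not part of the patch): a minimal, self-contained sketch of the index arithmetic behind the new PoolingType.CLS branch in pooler.py. The prompt lengths and hidden states below are made-up values, and prompt_lens stands in for the per-request forward_batch.extend_seq_lens tensor.

# Standalone illustration of CLS pooling over a packed (flattened) batch.
import torch

# Hypothetical batch: three prompts of lengths 4, 2 and 3, flattened into
# a single [9, hidden_size] tensor the way extend batches are laid out.
prompt_lens = torch.tensor([4, 2, 3])
hidden_states = torch.arange(9, dtype=torch.float32).unsqueeze(1).repeat(1, 8)

# Same arithmetic as the PoolingType.CLS branch: the first token of
# sequence i sits at the cumulative length of all previous sequences.
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
print(first_token_flat_indices)  # tensor([0, 4, 6])

cls_states = hidden_states[first_token_flat_indices]  # shape [3, 8]

# For comparison, the existing PoolingType.LAST branch picks the last
# token of each sequence instead.
last_token_indices = torch.cumsum(prompt_lens, dim=0) - 1
print(last_token_indices)  # tensor([3, 5, 8])

BERT-style encoders carry the sentence representation on the [CLS] token at position 0 of each sequence, which is why the RoBERTa/BGE embedding models pool the first token rather than the last.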
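
Reviewer note (not part of the patch): RobertaEmbedding recomputes position ids instead of using the 0..N-1 positions passed in because RoBERTa-family checkpoints reserve the low position slots for padding and start real tokens at padding_idx + 1. Below is a sketch of the helper's behaviour with made-up token ids, assuming the usual pad_token_id = 1 of Hugging Face RoBERTa/XLM-R configs.

# Standalone illustration of the RoBERTa position-id convention.
import torch

def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    # Mirrors the helper added at the bottom of roberta.py (1-D, unpadded input).
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (
        torch.cumsum(mask, dim=0).type_as(mask) + past_key_values_length
    ) * mask
    return incremental_indices.long() + padding_idx

token_ids = torch.tensor([0, 3202, 250, 2])  # illustrative ids for "<s> ... </s>"
padding_idx = 1  # pad_token_id in RoBERTa / XLM-R configs
print(create_position_ids_from_input_ids(token_ids, padding_idx))
# tensor([2, 3, 4, 5]) -- positions start at padding_idx + 1, matching HF RobertaEmbeddings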