Merged
1 change: 1 addition & 0 deletions python/sglang/lang/chat_template.py
@@ -193,6 +193,7 @@ def match_chat_ml(model_path: str):
if "qwen" in model_path and "chat" in model_path:
return get_chat_template("chatml")


@register_chat_template_matching_function
def match_chat_yi(model_path: str):
model_path = model_path.lower()
10 changes: 8 additions & 2 deletions python/sglang/srt/layers/logits_processor.py
@@ -64,13 +64,19 @@ def forward(self, input_ids, hidden_states, weight, input_metadata):
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)
logprobs_cumsum = torch.cumsum(
prefill_logprobs, dim=0, dtype=torch.float32
)

start = input_metadata.extend_start_loc.clone()
end = start + input_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
sum_logp = (
logprobs_cumsum[end]
- logprobs_cumsum[start]
+ prefill_logprobs[start]
)
normalized_logprobs = sum_logp / (
(input_metadata.extend_seq_lens - 1).clamp(min=1)
)
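
Note: the hunk above only re-wraps existing logic; the underlying trick is a prefix sum, so each request's total prompt logprob in the flattened batch costs two indexed reads. A minimal, self-contained sketch with made-up lengths (seq_lens / start_loc stand in for the real extend_seq_lens / extend_start_loc metadata):

import torch

# Three requests with 3, 2, and 4 prefill tokens, packed into one flat buffer.
seq_lens = torch.tensor([3, 2, 4])
start_loc = torch.tensor([0, 3, 5])               # per-request offsets into the flat buffer
logprobs = torch.randn(int(seq_lens.sum()))       # one logprob per prefill token

# Prefix sums turn "sum of a contiguous slice" into two lookups.
cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
start = start_loc.clone()
end = start + seq_lens - 2
start.clamp_(min=0, max=logprobs.shape[0] - 1)
end.clamp_(min=0, max=logprobs.shape[0] - 1)
sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
normalized = sum_logp / (seq_lens - 1).clamp(min=1)   # mean logprob per request
print(normalized)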
9 changes: 1 addition & 8 deletions python/sglang/srt/layers/radix_attention.py
@@ -13,14 +13,7 @@


class RadixAttention(nn.Module):
def __init__(
self,
num_heads,
head_dim,
scaling,
num_kv_heads,
layer_id
):
def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
super().__init__()
self.tp_q_head_num = num_heads
self.tp_k_head_num = num_kv_heads
1 change: 1 addition & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -100,6 +100,7 @@ class BatchStrOut:
class FlushCacheReq:
pass


@dataclass
class DetokenizeReqInput:
input_ids: List[int]
35 changes: 22 additions & 13 deletions python/sglang/srt/managers/router/model_rpc.py
@@ -11,8 +11,8 @@
import torch
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
BatchTokenIDOut,
@@ -391,8 +391,12 @@ def forward_fill_batch(self, batch: Batch):
logprobs = None
if batch.extend_num_tokens != 0:
# Forward
logits, (prefill_logprobs, normalized_logprobs, last_logprobs) = (
self.model_runner.forward(batch, ForwardMode.EXTEND, batch.return_logprob)
logits, (
prefill_logprobs,
normalized_logprobs,
last_logprobs,
) = self.model_runner.forward(
batch, ForwardMode.EXTEND, batch.return_logprob
)
if prefill_logprobs is not None:
logprobs = prefill_logprobs.cpu().tolist()
@@ -407,7 +411,9 @@ def forward_fill_batch(self, batch: Batch):
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
reqs = batch.reqs
if last_logprobs is not None:
last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
last_logprobs = (
last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
)

# Check finish condition
pt = 0
@@ -482,7 +488,9 @@ def forward_decode_batch(self, batch: Batch):
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
reqs = batch.reqs
if last_logprobs is not None:
last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].tolist()
last_logprobs = last_logprobs[
torch.arange(len(reqs)), next_token_ids
].tolist()

# Check finish condition
for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
@@ -620,15 +628,16 @@ async def _func(*args, **kwargs):
self.step = async_wrap("step")


def start_model_process(port):
def _init_service(port):
t = ThreadedServer(
ModelRpcServer(),
port=port,
protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
)
t.start()
def _init_service(port):
t = ThreadedServer(
ModelRpcServer(),
port=port,
protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
)
t.start()


def start_model_process(port):
proc = multiprocessing.Process(target=_init_service, args=(port,))
proc.start()
time.sleep(1)
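
Note: the hunk above hoists _init_service from a closure inside start_model_process to module level. The diff does not state why; a plausible reason (an assumption, not confirmed by the PR) is that multiprocessing can only pickle module-level callables, which matters under the spawn start method. A minimal sketch of that difference:

import multiprocessing

def _module_level_target(port):
    # Module-level target: picklable, so it works under both "fork" and "spawn".
    print(f"child would serve on port {port}")

def start_ok(port):
    multiprocessing.Process(target=_module_level_target, args=(port,)).start()

def start_broken(port):
    def _nested_target(p):
        print(f"child would serve on port {p}")
    # Closure target: not picklable, so Process.start() fails under "spawn".
    multiprocessing.Process(target=_nested_target, args=(port,)).start()

if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)
    start_ok(5000)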
12 changes: 8 additions & 4 deletions python/sglang/srt/managers/router/model_runner.py
@@ -17,8 +17,8 @@
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel

import sglang
QUANTIONCONFIG_MAPPING = {'awq': AWQConfig,
'gptq': GPTQConfig}

QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}

logger = logging.getLogger("model_runner")

@@ -283,9 +283,13 @@ def load_model(self):
self.model_config.hf_config, "quantization_config", None
)
if hf_quant_config is not None:
quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_config['quant_method'])
quant_config_class = QUANTIONCONFIG_MAPPING.get(
hf_quant_config["quant_method"]
)
if quant_config_class is None:
raise ValueError(f"Unsupported quantization method: {hf_quant_config['quant_method']}")
raise ValueError(
f"Unsupported quantization method: {hf_quant_config['quant_method']}"
)
quant_config = quant_config_class.from_config(hf_quant_config)
logger.info(f"quant_config: {quant_config}")
linear_method = quant_config.get_linear_method()
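
Note: the reformatted hunk above dispatches on the quant_method field of the model's HF quantization_config. A standalone sketch of that dispatch pattern, with stub classes standing in for the real vLLM AWQConfig / GPTQConfig:

class AWQConfig:                                   # stub for the real vLLM config class
    @classmethod
    def from_config(cls, cfg):
        return cls()

class GPTQConfig:                                  # stub for the real vLLM config class
    @classmethod
    def from_config(cls, cfg):
        return cls()

QUANT_CONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}

def resolve_quant_config(hf_quant_config):
    if hf_quant_config is None:
        return None
    cls = QUANT_CONFIG_MAPPING.get(hf_quant_config["quant_method"])
    if cls is None:
        raise ValueError(f"Unsupported quantization method: {hf_quant_config['quant_method']}")
    return cls.from_config(hf_quant_config)

print(resolve_quant_config({"quant_method": "awq"}))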
31 changes: 19 additions & 12 deletions python/sglang/srt/models/qwen.py
@@ -42,14 +42,14 @@ def __init__(
2 * [intermediate_size],
bias=False,
gather_output=False,
linear_method=linear_method
linear_method=linear_method,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
linear_method=linear_method
linear_method=linear_method,
)
if hidden_act != "silu":
raise ValueError(
@@ -74,7 +74,7 @@ def __init__(
layer_id: int = 0,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
linear_method: Optional[LinearMethodBase] = None
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.hidden_size = hidden_size
@@ -86,18 +86,18 @@

# pylint: disable=invalid-name
self.c_attn = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
bias=True,
linear_method=linear_method
hidden_size,
self.head_dim,
self.total_num_heads,
bias=True,
linear_method=linear_method,
)
self.c_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
linear_method=linear_method
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
@@ -143,12 +143,16 @@ def __init__(self, config: QWenConfig, layer_id, linear_method=None):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
layer_id=layer_id,
linear_method=linear_method
linear_method=linear_method,
)

self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2, linear_method=linear_method)
self.mlp = QWenMLP(
config.hidden_size,
config.intermediate_size // 2,
linear_method=linear_method,
)

def forward(
self,
@@ -186,7 +190,10 @@ def __init__(self, config: QWenConfig, linear_method=None):
config.hidden_size,
)
self.h = nn.ModuleList(
[QWenBlock(config, i, linear_method=linear_method) for i in range(config.num_hidden_layers)]
[
QWenBlock(config, i, linear_method=linear_method)
for i in range(config.num_hidden_layers)
]
)
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
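
Note: the qwen.py changes are purely Black-style reformatting (trailing commas, argument wrapping). For context, the QWenMLP above is the usual merged gate/up (SwiGLU-style) MLP; a single-GPU sketch of that pattern, assuming the conventional split where the first half gates the second (the real code shards it across GPUs with MergedColumnParallelLinear / RowParallelLinear):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        # One matmul produces both the gate and the up projection (2 * intermediate_size).
        self.gate_up = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        gate, up = self.gate_up(x).chunk(2, dim=-1)
        return self.down(F.silu(gate) * up)

print(GatedMLP(16, 64)(torch.randn(2, 16)).shape)   # torch.Size([2, 16])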

27 changes: 20 additions & 7 deletions python/sglang/srt/models/yivl.py
@@ -4,22 +4,27 @@

import torch
import torch.nn as nn
from sglang.srt.models.llava import (
LlavaLlamaForCausalLM,
clip_vision_embed_forward,
monkey_path_clip_vision_embed_forward,
)
from transformers import CLIPVisionModel, LlavaConfig
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)

from sglang.srt.models.llava import LlavaLlamaForCausalLM, clip_vision_embed_forward, monkey_path_clip_vision_embed_forward


class YiVLForCausalLM(LlavaLlamaForCausalLM):
def __init__(self, *args, **kwargs):
self.config = kwargs["config"]
super().__init__(self.config)

self.multi_modal_projector = YiVLMultiModalProjector(self.config)
self.vision_tower_subfolder = self.config.mm_vision_tower.replace("./", "") # Everything after "./"
self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
"./", ""
) # Everything after "./"

def load_weights(
self,
@@ -30,7 +35,9 @@ def load_weights(
):
# We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
self.vision_tower = CLIPVisionModel.from_pretrained(
model_name_or_path, torch_dtype=torch.float16, subfolder=self.vision_tower_subfolder
model_name_or_path,
torch_dtype=torch.float16,
subfolder=self.vision_tower_subfolder,
).cuda()

self.vision_tower.eval()
@@ -80,14 +87,19 @@ def load_weights(

monkey_path_clip_vision_embed_forward()


class YiVLMultiModalProjector(nn.Module):
def __init__(self, config: LlavaConfig):
super().__init__()

self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
self.linear_1 = nn.Linear(
config.vision_config.hidden_size, config.text_config.hidden_size
)
self.ln_1 = nn.LayerNorm(config.text_config.hidden_size)
self.act = nn.GELU()
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size
)
self.ln_2 = nn.LayerNorm(config.text_config.hidden_size)

def forward(self, image_features):
@@ -98,4 +110,5 @@ def forward(self, image_features):
hidden_states = self.ln_2(hidden_states)
return hidden_states

EntryClass = YiVLForCausalLM

EntryClass = YiVLForCausalLM
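
Note: YiVLMultiModalProjector above maps CLIP vision features into the language model's hidden size, with a LayerNorm after each linear. A standalone sketch with made-up sizes, assuming the forward pass follows the layer definitions in order (linear_1 -> ln_1 -> GELU -> linear_2 -> ln_2):

import torch
import torch.nn as nn

vision_hidden, text_hidden = 1024, 4096              # made-up CLIP / LLM hidden sizes
projector = nn.Sequential(
    nn.Linear(vision_hidden, text_hidden),
    nn.LayerNorm(text_hidden),
    nn.GELU(),
    nn.Linear(text_hidden, text_hidden),
    nn.LayerNorm(text_hidden),
)
image_features = torch.randn(1, 576, vision_hidden)  # e.g. 576 patch tokens from CLIP
print(projector(image_features).shape)               # torch.Size([1, 576, 4096])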
7 changes: 5 additions & 2 deletions python/sglang/srt/server.py
@@ -63,6 +63,7 @@
# FIXME: Remove this once we drop support for pydantic 1.x
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1


def jsonify_pydantic_model(obj: BaseModel):
if IS_PYDANTIC_1:
return obj.json(ensure_ascii=False)
@@ -165,7 +166,7 @@ async def gnerate_stream_resp():
prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"]

if not stream_buffer: # The first chunk
if not stream_buffer: # The first chunk
if request.echo:
# Prepend prompt in response text.
text = request.prompt + text
@@ -219,7 +220,9 @@ async def gnerate_stream_resp():
token_logprob_pos = prompt_tokens

logprobs = (
await make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
await make_openai_style_logprobs(
ret["meta_info"]["token_logprob"][token_logprob_pos:]
)
if request.logprobs is not None
else None
)
2 changes: 1 addition & 1 deletion python/sglang/srt/server_args.py
@@ -114,7 +114,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
"--max-prefill-num-token",
type=int,
default=ServerArgs.max_prefill_num_token,
help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length."
help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
)
parser.add_argument(
"--tp-size",
2 changes: 1 addition & 1 deletion python/sglang/srt/utils.py
@@ -259,4 +259,4 @@ def load_image(image_file):
else:
image = Image.open(BytesIO(base64.b64decode(image_file)))

return image
return image
2 changes: 2 additions & 0 deletions test/srt/test_httpserver_decode.py
@@ -12,6 +12,7 @@

import requests


def test_decode(url, return_logprob):
response = requests.post(
url + "/generate",
@@ -27,6 +28,7 @@ def test_decode(url, return_logprob):
)
print(response.json())


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
4 changes: 3 additions & 1 deletion test/srt/test_httpserver_decode_stream.py
@@ -12,6 +12,7 @@

import requests


def test_decode_stream(url, return_logprob):
response = requests.post(
url + "/generate",
@@ -39,7 +40,7 @@ def test_decode_stream(url, return_logprob):
assert data["meta_info"]["prompt_logprob"] is not None
assert data["meta_info"]["token_logprob"] is not None
assert data["meta_info"]["normalized_prompt_logprob"] is not None
if prev == 0: # Skip prompt logprobs
if prev == 0: # Skip prompt logprobs
prev = data["meta_info"]["prompt_tokens"]
for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]:
print(f"{token_txt}\t{logprob}", flush=True)
@@ -50,6 +51,7 @@
prev = len(output)
print("")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")