Merged
1 change: 1 addition & 0 deletions python/sglang/api.py
@@ -1,4 +1,5 @@
"""Public API"""

import re
from typing import Callable, List, Optional, Union

33 changes: 24 additions & 9 deletions python/sglang/backend/runtime_endpoint.py
@@ -19,7 +19,9 @@ def __init__(self, base_url, auth_token=None):
self.base_url = base_url
self.auth_token = auth_token

res = http_request(self.base_url + "/get_model_info", auth_token=self.auth_token)
res = http_request(
self.base_url + "/get_model_info", auth_token=self.auth_token
)
assert res.status_code == 200
self.model_info = res.json()

@@ -37,22 +39,24 @@ def cache_prefix(self, prefix_str: str):
res = http_request(
self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token
auth_token=self.auth_token,
)
assert res.status_code == 200

def commit_lazy_operations(self, s: StreamExecutor):
res = http_request(
self.base_url + "/generate",
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token
auth_token=self.auth_token,
)
assert res.status_code == 200

def fill_image(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token
)
assert res.status_code == 200

def generate(
@@ -82,7 +86,9 @@ def generate(

self._add_images(s, data)

res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token
)
obj = res.json()
comp = obj["text"]
return comp, obj["meta_info"]
@@ -115,7 +121,12 @@ def generate_stream(
data["stream"] = True
self._add_images(s, data)

response = http_request(self.base_url + "/generate", json=data, stream=True, auth_token=self.auth_token)
response = http_request(
self.base_url + "/generate",
json=data,
stream=True,
auth_token=self.auth_token,
)
pos = 0

incomplete_text = ""
@@ -145,7 +156,9 @@ def select(
# Cache common prefix
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token
)
assert res.status_code == 200
prompt_len = res.json()["meta_info"]["prompt_tokens"]

@@ -157,7 +170,9 @@ def select(
"logprob_start_len": max(prompt_len - 2, 0),
}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token
)
assert res.status_code == 200
obj = res.json()
normalized_prompt_logprob = [
@@ -172,7 +187,7 @@ def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
res = http_request(
self.base_url + "/concate_and_append_request",
json={"src_rids": src_rids, "dst_rid": dst_rid},
auth_token=self.auth_token
auth_token=self.auth_token,
)
assert res.status_code == 200
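
Note on the select() changes above: the backend now caches the shared prefix, reads back its prompt_tokens count, and then scores each candidate with max_new_tokens=0 and logprob_start_len clamped to max(prompt_len - 2, 0), comparing the returned normalized_prompt_logprob values. A minimal client-side sketch of the same scoring idea follows; the base URL, the return_logprob flag, and any response fields not visible in this diff are assumptions rather than the server's confirmed contract.

    import requests

    BASE_URL = "http://localhost:30000"  # assumed local SRT server

    def pick_best_choice(prefix, choices):
        """Score each candidate by normalized prompt logprob and return the best (sketch)."""
        # Cache the shared prefix and read back its token length.
        res = requests.post(
            BASE_URL + "/generate",
            json={"text": prefix, "sampling_params": {"max_new_tokens": 0}},
        )
        prompt_len = res.json()["meta_info"]["prompt_tokens"]

        scores = []
        for choice in choices:
            res = requests.post(
                BASE_URL + "/generate",
                json={
                    "text": prefix + choice,
                    "sampling_params": {"max_new_tokens": 0},
                    "return_logprob": True,  # assumed flag; not shown in this diff
                    "logprob_start_len": max(prompt_len - 2, 0),
                },
            )
            scores.append(res.json()["meta_info"]["normalized_prompt_logprob"])
        return choices[scores.index(max(scores))]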

19 changes: 18 additions & 1 deletion python/sglang/lang/chat_template.py
@@ -116,6 +121,21 @@ def get_chat_template_by_model_path(model_path):
)


register_chat_template(
ChatTemplate(
name="chatml-llava",
default_system_prompt="Answer the questions.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
"user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
image_token=" <image>\n",
)
)

register_chat_template(
ChatTemplate(
name="vicuna_v1.1",
@@ -168,7 +183,7 @@ def get_chat_template_by_model_path(model_path):
def match_vicuna(model_path: str):
if "vicuna" in model_path.lower():
return get_chat_template("vicuna_v1.1")
if "llava" in model_path.lower():
if "llava-v1.5" in model_path.lower():
return get_chat_template("vicuna_v1.1")


@@ -192,6 +207,8 @@ def match_chat_ml(model_path: str):
return get_chat_template("chatml")
if "qwen" in model_path and "chat" in model_path:
return get_chat_template("chatml")
if "llava-v1.6-34b" in model_path:
return get_chat_template("chatml-llava")


@register_chat_template_matching_function
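
Note: the chatml-llava template registered above is plain ChatML with "Answer the questions." as the system prompt and " <image>\n" as the image token, and it is now matched for llava-v1.6-34b checkpoints. A rough sketch of how a prompt would be assembled from that role_prefix_and_suffix table follows; the render helper is hypothetical and only mirrors the affixes registered above, not sglang's actual template engine.

    # Hypothetical helper mirroring the chatml-llava affixes registered above.
    ROLE_AFFIXES = {
        "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
        "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
        "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
    }

    def render(messages):
        """Concatenate (role, text) turns with the ChatML affixes."""
        out = []
        for role, text in messages:
            prefix, suffix = ROLE_AFFIXES[role]
            out.append(prefix + text + suffix)
        # Leave the assistant turn open so the model continues from here.
        out.append("<|im_start|>assistant\n")
        return "".join(out)

    prompt = render(
        [
            ("system", "Answer the questions."),
            ("user", " <image>\nWhat is in this picture?"),
        ]
    )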
6 changes: 3 additions & 3 deletions python/sglang/lang/ir.py
@@ -74,9 +74,9 @@ def to_anthropic_kwargs(self):
)
return {
"max_tokens_to_sample": self.max_new_tokens,
"stop_sequences": self.stop
if isinstance(self.stop, (list, tuple))
else [self.stop],
"stop_sequences": (
self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
),
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
1 change: 1 addition & 0 deletions python/sglang/lang/tracer.py
@@ -1,4 +1,5 @@
"""Tracing a program."""

import uuid
from typing import Any, Callable, Dict, List, Optional, Union

1 change: 1 addition & 0 deletions python/sglang/srt/backend_config.py
@@ -1,6 +1,7 @@
"""
Backend configurations, may vary with different serving platforms.
"""

from dataclasses import dataclass


3 changes: 2 additions & 1 deletion python/sglang/srt/conversation.py
@@ -366,7 +366,8 @@ def generate_chat_conv(
if content.type == "text":
real_content += content.text
elif content.type == "image_url":
real_content += "<image>"
# NOTE: Only works for llava
real_content += "<image>\n"
conv.append_image(content.image_url.url)
conv.append_message(conv.roles[0], real_content)
elif msg_role == "assistant":
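
Note: with this change, multimodal user content is flattened into a single string in which every image part becomes a literal "<image>\n" marker (llava-style), while the image URLs are collected separately via conv.append_image. A toy sketch of that flattening, with made-up dict-shaped message parts, follows.

    def flatten_llava_content(parts):
        """Flatten a list of content parts into (text, image_urls) -- illustrative only."""
        text, image_urls = "", []
        for part in parts:
            if part["type"] == "text":
                text += part["text"]
            elif part["type"] == "image_url":
                # NOTE: the <image> placeholder only works for llava-style models.
                text += "<image>\n"
                image_urls.append(part["image_url"]["url"])
        return text, image_urls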
14 changes: 10 additions & 4 deletions python/sglang/srt/managers/router/model_rpc.py
@@ -31,6 +31,7 @@
is_multimodal_model,
set_random_seed,
)
from vllm.logger import _default_handler as vllm_default_handler

logger = logging.getLogger("model_rpc")

@@ -50,6 +51,9 @@ def exposed_init_model(
self.tp_size = server_args.tp_size
self.schedule_heuristic = server_args.schedule_heuristic
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
vllm_default_handler.setLevel(
level=getattr(logging, server_args.log_level.upper())
)

# Init model and tokenizer
self.model_config = ModelConfig(
@@ -83,9 +87,11 @@ def exposed_init_model(
self.max_num_running_seq = self.max_total_num_token // 2
self.max_prefill_num_token = max(
self.model_config.context_len,
self.max_total_num_token // 6
if server_args.max_prefill_num_token is None
else server_args.max_prefill_num_token,
(
self.max_total_num_token // 6
if server_args.max_prefill_num_token is None
else server_args.max_prefill_num_token
),
)
self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
@@ -534,7 +540,7 @@ def handle_finished_requests(self, batch: Batch):
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)

# The length of input_ids grows during jump-forward decoding, so use the
# original input_ids length when computing the token usage info.
meta_info = {
12 changes: 9 additions & 3 deletions python/sglang/srt/managers/router/model_runner.py
@@ -112,7 +112,9 @@ def init_flashinfer_args(self, tp_size):
(self.batch_size,), dtype=torch.int32, device="cuda"
)

workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
workspace_buffer = torch.empty(
32 * 1024 * 1024, dtype=torch.int8, device="cuda"
)
if (
self.forward_mode == ForwardMode.PREFILL
or self.forward_mode == ForwardMode.EXTEND
@@ -121,7 +123,9 @@ def init_flashinfer_args(self, tp_size):
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD"
)
self.prefill_wrapper.begin_forward(
self.qo_indptr,
self.kv_indptr,
@@ -131,7 +135,9 @@ def init_flashinfer_args(self, tp_size):
self.model_runner.model_config.num_key_value_heads // tp_size,
)
else:
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, "NHD"
)
self.decode_wrapper.begin_forward(
self.kv_indptr,
self.kv_indices,
1 change: 1 addition & 0 deletions python/sglang/srt/memory_pool.py
@@ -1,4 +1,5 @@
"""Memory pool."""

import logging

import torch
2 changes: 1 addition & 1 deletion python/sglang/srt/models/llava.py
@@ -1,4 +1,5 @@
"""Inference-only LLaVa model compatible with HuggingFace weights."""

from typing import List, Optional

import numpy as np
@@ -269,7 +270,6 @@ def load_weights(
raise ValueError(f"Unexpected select feature: {self.select_feature}")

# load mm_projector
# TODO: support TP?
projector_weights = {
"model.mm_projector.0": "multi_modal_projector.linear_1",
"model.mm_projector.2": "multi_modal_projector.linear_2",
1 change: 1 addition & 0 deletions python/sglang/srt/models/mistral.py
@@ -1,4 +1,5 @@
"""Inference-only Mistral model."""

from sglang.srt.models.llama2 import LlamaForCausalLM


16 changes: 9 additions & 7 deletions python/sglang/srt/models/mixtral.py
@@ -97,14 +97,16 @@ def __init__(

self.experts = nn.ModuleList(
[
MixtralMLP(
self.num_total_experts,
config.hidden_size,
config.intermediate_size,
linear_method=linear_method,
(
MixtralMLP(
self.num_total_experts,
config.hidden_size,
config.intermediate_size,
linear_method=linear_method,
)
if idx in self.expert_indicies
else None
)
if idx in self.expert_indicies
else None
for idx in range(self.num_total_experts)
]
)
1 change: 1 addition & 0 deletions python/sglang/srt/models/yivl.py
@@ -1,4 +1,5 @@
"""Inference-only Yi-VL model."""

import os
from typing import List, Optional

1 change: 1 addition & 0 deletions python/sglang/srt/sampling_params.py
@@ -1,4 +1,5 @@
"""Sampling parameters for text generation."""

from typing import List, Optional, Union

_SAMPLING_EPS = 1e-6
7 changes: 4 additions & 3 deletions python/sglang/srt/server.py
@@ -1,4 +1,5 @@
"""SRT: SGLang Runtime"""

import asyncio
import json
import multiprocessing as mp
@@ -493,7 +494,7 @@ def _launch_server():

# Warmup
try:
print("Warmup...", flush=True)
# print("Warmup...", flush=True)
res = requests.post(
url + "/generate",
json={
@@ -505,8 +506,8 @@ def _launch_server():
},
timeout=60,
)
print(f"Warmup done. model response: {res.json()['text']}")
print("=" * 20, "Server is ready", "=" * 20, flush=True)
# print(f"Warmup done. model response: {res.json()['text']}")
# print("=" * 20, "Server is ready", "=" * 20, flush=True)
except requests.exceptions.RequestException as e:
if pipe_finish_writer is not None:
pipe_finish_writer.send(str(e))
4 changes: 1 addition & 3 deletions python/sglang/srt/utils.py
@@ -122,7 +122,7 @@ def handle_port_init(
# first check on server port
if not check_port(port):
new_port = alloc_usable_network_port(1, used_list=[port])[0]
print(f"Port {port} is not available, using {new_port} instead.")
print(f"WARNING: Port {port} is not available. Use {new_port} instead.")
port = new_port

# then we check on additional ports
@@ -157,8 +157,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
ss = tokenizer.decode([t_id]).strip()
if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
logit_bias[t_id] = -1e5
# else:
# print(ss, t_id)

return logit_bias

1 change: 1 addition & 0 deletions python/sglang/test/test_utils.py
@@ -1,4 +1,5 @@
"""Common utilities for testing and benchmarking"""

import numpy as np
import requests
from sglang.backend.openai import OpenAI
4 changes: 2 additions & 2 deletions python/sglang/utils.py
@@ -22,7 +22,7 @@ def get_available_gpu_memory(gpu_id, distributed=True):

if torch.cuda.current_device() != gpu_id:
print(
f"WARN: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
"which may cause useless memory allocation for torch CUDA context.",
)

@@ -95,7 +95,7 @@ def http_request(url, json=None, stream=False, auth_token=None):
return requests.post(url, json=json, stream=True)
headers = {
"Content-Type": "application/json",
"Authentication": f"Bearer {auth_token}"
"Authentication": f"Bearer {auth_token}",
}
return requests.post(url, json=json, stream=True, headers=headers)
else:
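
Note: the bearer token here is sent under an "Authentication" header (the conventional HTTP header is "Authorization"), so any server-side check has to read the same name. A minimal sketch of a matching check follows; it is illustrative only, and the helper name is made up.

    from typing import Optional

    def is_authorized(headers: dict, expected_token: Optional[str]) -> bool:
        """Accept the request if no token is configured or the bearer token matches (sketch)."""
        if expected_token is None:
            return True
        # The name must match the header sent by http_request() above.
        return headers.get("Authentication", "") == f"Bearer {expected_token}"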
1 change: 1 addition & 0 deletions test/lang/test_srt_backend.py
@@ -1,6 +1,7 @@
"""
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
"""

import json
import unittest
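
Note: the docstring above gives the launch command these tests expect. A minimal sketch of pointing client code at that server follows; it assumes RuntimeEndpoint and set_default_backend are exposed at the package level and that the port matches the launch command.

    import sglang as sgl

    # Assumes a server started with:
    #   python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

    @sgl.function
    def ask(s, question):
        s += sgl.user(question)
        s += sgl.assistant(sgl.gen("answer", max_tokens=64))

    state = ask.run(question="What is the capital of France?")
    print(state["answer"])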
