Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,10 @@ def __init__(
"Llama4ForConditionalGeneration",
"Step3VLForConditionalGeneration",
]
if self.hf_config.architectures[0] in mm_disabled_models:
if (
self.hf_config.architectures[0] in mm_disabled_models
and self.model_impl != ModelImpl.TRANSFORMERS
):
enable_multimodal = False
logger.info(
f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
Expand All @@ -177,8 +180,14 @@ def __init__(
self.is_generation = is_generation_model(
self.hf_config.architectures, is_embedding
)
self.is_multimodal = enable_multimodal and is_multimodal_model(
self.hf_config.architectures
has_multimodal_subconfig = (
self.hf_config is not self.hf_text_config
or hasattr(self.hf_config, "vision_config")
or hasattr(self.hf_config, "audio_config")
)
self.is_multimodal = enable_multimodal and (
is_multimodal_model(self.hf_config.architectures)
or has_multimodal_subconfig
)
self.is_audio_model = enable_multimodal and is_audio_model(
self.hf_config.architectures
Expand Down
5 changes: 5 additions & 0 deletions python/sglang/srt/disaggregation/encode_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,11 @@ def __init__(
server_args,
_processor,
transport_mode,
model_config=(
getattr(self.scheduler, "model_config", None)
if self.scheduler is not None
else None
),
skip_mm_pool=not enable_adaptive_dispatch_to_encoder,
)

Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,8 @@ class TokenizedGenerateReqInput(BaseReq):
# Whether to return entropy
return_entropy: bool = False

token_type_ids: Optional[List[int]] = None

need_wait_for_mm_inputs: bool = False
num_items_assigned: Optional[Dict[Modality, List[int]]] = None

Expand Down
32 changes: 30 additions & 2 deletions python/sglang/srt/managers/multimodal_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import pkgutil

from sglang.srt.configs.model_config import ModelImpl
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
from sglang.srt.server_args import ServerArgs

Expand Down Expand Up @@ -41,14 +42,41 @@ def import_processors(package_name: str, overwrite: bool = False):


def get_mm_processor(
hf_config, server_args: ServerArgs, processor, transport_mode, **kwargs
hf_config,
server_args: ServerArgs,
processor,
transport_mode,
model_config=None,
**kwargs,
) -> BaseMultimodalProcessor:
model_impl = str(getattr(server_args, "model_impl", "auto")).lower()
uses_transformers_backend = model_impl == "transformers"
if model_impl == "auto" and model_config is not None:
from sglang.srt.model_loader.utils import get_resolved_model_impl

uses_transformers_backend = (
get_resolved_model_impl(model_config) == ModelImpl.TRANSFORMERS
)

for model_cls, processor_cls in PROCESSOR_MAPPING.items():
if model_cls.__name__ in hf_config.architectures:
if model_cls.__name__ not in hf_config.architectures:
continue
if not uses_transformers_backend or getattr(
processor_cls, "supports_transformers_backend", False
):
return processor_cls(
hf_config, server_args, processor, transport_mode, **kwargs
)

if uses_transformers_backend:
from sglang.srt.multimodal.processors.transformers_auto import (
TransformersAutoMultimodalProcessor,
)

return TransformersAutoMultimodalProcessor(
hf_config, server_args, processor, transport_mode, **kwargs
)

Comment thread
adarshxs marked this conversation as resolved.
raise ValueError(
f"No processor registered for architecture: {hf_config.architectures}.\n"
f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
Expand Down
57 changes: 43 additions & 14 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from torch.distributed import barrier

from sglang.jit_kernel.ngram_embedding import update_token_table
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.configs.model_config import ModelConfig, ModelImpl
from sglang.srt.constants import HEALTH_CHECK_RID_PREFIX
from sglang.srt.constrained.grammar_manager import GrammarManager
from sglang.srt.disaggregation.decode import (
Expand Down Expand Up @@ -185,6 +185,7 @@
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.session_aware_cache import SessionAwareCache
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.model_loader.utils import get_resolved_model_impl
from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin
from sglang.srt.observability.req_time_stats import (
real_time,
Expand Down Expand Up @@ -699,6 +700,9 @@ def init_model_worker(self):

def init_cache_with_memory_pool(self):
server_args = self.server_args
uses_transformers_backend = (
get_resolved_model_impl(self.model_config) == ModelImpl.TRANSFORMERS
)

# Hybrid memory pool
self.is_hybrid_swa = self.tp_worker.is_hybrid_swa
Expand All @@ -718,9 +722,21 @@ def init_cache_with_memory_pool(self):
self.tp_worker.get_memory_pool()
)

# Create cache
self.disable_radix_cache = server_args.disable_radix_cache or (
self.model_config.is_multimodal and uses_transformers_backend
)
if self.disable_radix_cache and not server_args.disable_radix_cache:
logger.warning(
"Radix cache is disabled for multimodal models with the "
"Transformers backend to avoid multimodal prefix-cache mismatches."
)

effective_chunked_prefill_size = server_args.chunked_prefill_size
if self.model_config.is_multimodal and uses_transformers_backend:
effective_chunked_prefill_size = None

params = CacheInitParams(
disable=server_args.disable_radix_cache,
disable=self.disable_radix_cache,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
Expand All @@ -736,14 +752,11 @@ def init_cache_with_memory_pool(self):
enable_mamba_extra_buffer=server_args.enable_mamba_extra_buffer(),
pp_rank=self.pp_rank,
pp_size=self.pp_size,
chunked_prefill_size=server_args.chunked_prefill_size,
chunked_prefill_size=effective_chunked_prefill_size,
sliding_window_size=self.sliding_window_size,
)

if (
server_args.chunked_prefill_size is not None
and server_args.disable_radix_cache
):
if effective_chunked_prefill_size is not None and self.disable_radix_cache:
if not self.is_hybrid_swa:
from sglang.srt.mem_cache.chunk_cache import ChunkCache

Expand Down Expand Up @@ -844,9 +857,22 @@ def init_running_status(self):
self._engine_paused = False

def init_chunked_prefill(self):
# Init chunked prefill
self.chunked_prefill_size = self.server_args.chunked_prefill_size
if self.chunked_prefill_size <= 0: # -1 means disable
uses_transformers_backend = (
get_resolved_model_impl(self.model_config) == ModelImpl.TRANSFORMERS
)
if (
self.chunked_prefill_size is not None
and self.chunked_prefill_size > 0
and self.model_config.is_multimodal
and uses_transformers_backend
):
logger.warning(
"Chunked prefill is disabled for multimodal models with the "
"Transformers backend to avoid partial multimodal chunk mismatches."
)
self.chunked_prefill_size = None
elif self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
self.chunked_prefill_size = None
self.chunked_req = None
self.is_mixed_chunk = (
Expand Down Expand Up @@ -1724,6 +1750,7 @@ def handle_generate_request(
stream=recv_req.stream,
lora_id=recv_req.lora_id,
input_embeds=recv_req.input_embeds,
token_type_ids=recv_req.token_type_ids,
custom_logit_processor=recv_req.custom_logit_processor,
require_reasoning=recv_req.require_reasoning,
return_hidden_states=recv_req.return_hidden_states,
Expand Down Expand Up @@ -1806,10 +1833,12 @@ def handle_generate_request(
SessionController.adjust_mm_offsets(recv_req, req, image_inputs)

# The following steps are already fast, execute locally on each rank.
# Expand a single image token into multiple dummy tokens for receiving image embeddings
req.origin_input_ids = self.pad_input_ids_func(
req.origin_input_ids, image_inputs
)
# Expand a single image token into multiple dummy tokens for receiving image embeddings.
# The pad function is model-specific and can be None for some backends.
if self.pad_input_ids_func:
req.origin_input_ids = self.pad_input_ids_func(
req.origin_input_ids, image_inputs
)
req.extend_image_inputs(image_inputs)
self._maybe_compute_mrope_positions(req)

Expand Down
11 changes: 10 additions & 1 deletion python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,11 @@ def init_tokenizer_and_processor(self):
# We create mm_processor for any skip_tokenizer_init to make sure we still encode
# images even with skip_tokenizer_init=False.
self.mm_processor = get_mm_processor(
self.model_config.hf_config, server_args, _processor, transport_mode
self.model_config.hf_config,
server_args,
_processor,
transport_mode,
model_config=self.model_config,
)

if server_args.skip_tokenizer_init:
Expand Down Expand Up @@ -747,6 +751,10 @@ async def _tokenize_one_request(

if mm_inputs and "input_ids" in mm_inputs:
input_ids = mm_inputs["input_ids"]
if mm_inputs and "token_type_ids" in mm_inputs:
token_type_ids = mm_inputs.pop("token_type_ids")
if not isinstance(token_type_ids, list):
token_type_ids = token_type_ids.flatten().tolist()
if (
envs.SGLANG_MM_PRECOMPUTE_HASH.get()
and mm_inputs
Expand Down Expand Up @@ -971,6 +979,7 @@ def _create_tokenized_object(
priority=obj.priority,
extra_key=obj.extra_key,
routing_key=obj.routing_key,
token_type_ids=token_type_ids,
need_wait_for_mm_inputs=obj.need_wait_for_mm_inputs,
num_items_assigned=obj.num_items_assigned,
)
Expand Down
10 changes: 10 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2118,6 +2118,16 @@ def _dummy_run(self, batch_size: int, run_ctx=None):

if self.server_args.enable_torch_compile:
set_torch_compile_config()
should_disable_torch_compile = not getattr(
self.model, "_can_torch_compile", True
)
if should_disable_torch_compile:
log_info_on_rank0(
logger,
"Transformers backend model reports it is not torch.compile "
"compatible (e.g. dynamic rope scaling). Disabling torch.compile.",
)
self.server_args.enable_torch_compile = False

if self.eagle_use_aux_hidden_state:
self.model.set_eagle3_layers_to_capture()
Expand Down
Loading
Loading