diff --git a/docs/advanced_features/vlm_query.ipynb b/docs/advanced_features/vlm_query.ipynb index dd35f5eaef8f..24bd7a90bc9f 100644 --- a/docs/advanced_features/vlm_query.ipynb +++ b/docs/advanced_features/vlm_query.ipynb @@ -182,9 +182,8 @@ "from transformers import Qwen2_5_VLForConditionalGeneration\n", "\n", "processor = AutoProcessor.from_pretrained(model_path, use_fast=True)\n", - "vision = (\n", - " Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path).eval().visual.cuda()\n", - ")" + "model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path).eval()\n", + "vision = model.model.visual.cuda()" ] }, { @@ -203,6 +202,7 @@ "precomputed_embeddings = vision(\n", " processor_output[\"pixel_values\"].cuda(), processor_output[\"image_grid_thw\"].cuda()\n", ")\n", + "precomputed_embeddings = precomputed_embeddings.pooler_output\n", "\n", "multi_modal_item = dict(\n", " processor_output,\n", diff --git a/python/pyproject.toml b/python/pyproject.toml index e0e6642cc862..5a61ec1c8af9 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -30,8 +30,6 @@ dependencies = [ "flashinfer_python==0.6.6", # keep it aligned with jit-cache version in Dockerfile "flashinfer_cubin==0.6.6", "gguf", - "hf_transfer", - "huggingface_hub", "interegular", "llguidance>=0.7.11,<0.8.0", "modelscope", @@ -72,7 +70,8 @@ dependencies = [ "av ; sys_platform == 'linux' and (platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'armv7l')", "torchvision", "tqdm", - "transformers==4.57.1", + "mistral_common>=1.9.0", + "transformers==5.3.0", "uvicorn", "uvloop", "watchfiles", @@ -132,6 +131,7 @@ tracing = [ test = [ "accelerate", + "addict", "bitsandbytes", "expecttest", "jsonlines", @@ -139,7 +139,7 @@ test = [ "matplotlib", "pandas", "parameterized", - "peft", + "peft>=0.18.0", "pytest", "pytest-cov", "diff-cover", diff --git a/python/pyproject_cpu.toml b/python/pyproject_cpu.toml index ae22b112752c..5febf1a68f0b 100644 --- 
a/python/pyproject_cpu.toml +++ b/python/pyproject_cpu.toml @@ -26,8 +26,6 @@ dependencies = [ "einops", "fastapi", "gguf", - "hf_transfer", - "huggingface_hub", "intel-openmp; platform_machine == 'x86_64'", "interegular", "llguidance>=0.7.11,<0.8.0", @@ -62,7 +60,8 @@ dependencies = [ "torchaudio==2.9.0", "torchvision==0.24.0", "tqdm", - "transformers==4.57.1", + "mistral_common>=1.9.0", + "transformers==5.3.0", "triton==3.5.0", "uvicorn", "uvloop", @@ -83,7 +82,7 @@ test = [ "jsonlines", "matplotlib", "pandas", - "peft", + "peft>=0.18.0", "pytest", "sentence_transformers", ] diff --git a/python/pyproject_npu.toml b/python/pyproject_npu.toml index 94417f6d97be..2722ea7309a3 100644 --- a/python/pyproject_npu.toml +++ b/python/pyproject_npu.toml @@ -26,8 +26,6 @@ dependencies = [ "einops", "fastapi", "gguf", - "hf_transfer", - "huggingface_hub", "interegular", "llguidance>=0.7.11,<0.8.0", "modelscope", @@ -57,7 +55,8 @@ dependencies = [ "timm==1.0.16", "torchao==0.9.0", "tqdm", - "transformers==4.57.1", + "mistral_common>=1.9.0", + "transformers==5.3.0", "uvicorn", "uvloop", "xgrammar==0.1.27", @@ -96,7 +95,7 @@ test = [ "jsonlines", "matplotlib", "pandas", - "peft", + "peft>=0.18.0", "pytest", "sentence_transformers", "tabulate", diff --git a/python/pyproject_other.toml b/python/pyproject_other.toml index c7a68e913c27..cbf20082eea0 100755 --- a/python/pyproject_other.toml +++ b/python/pyproject_other.toml @@ -28,8 +28,6 @@ runtime_common = [ "einops", "fastapi", "gguf", - "hf_transfer", - "huggingface_hub", "interegular", "llguidance>=0.7.11,<0.8.0", "modelscope", @@ -59,7 +57,8 @@ runtime_common = [ "timm==1.0.16", "torchao==0.9.0", "tqdm", - "transformers==4.57.1", + "mistral_common>=1.9.0", + "transformers==5.3.0", "uvicorn", "uvloop", "xgrammar==0.1.27", @@ -164,7 +163,7 @@ test = [ "jsonlines", "matplotlib", "pandas", - "peft", + "peft>=0.18.0", "pytest", "sentence_transformers", "tabulate", diff --git a/python/pyproject_xpu.toml b/python/pyproject_xpu.toml 
index 113bf3eda476..561c4263fe70 100644 --- a/python/pyproject_xpu.toml +++ b/python/pyproject_xpu.toml @@ -31,8 +31,6 @@ dependencies = [ "einops", "fastapi", "gguf", - "hf_transfer", - "huggingface_hub", "interegular", "llguidance>=0.7.11,<0.8.0", "modelscope", @@ -62,7 +60,8 @@ dependencies = [ "timm==1.0.16", "torchao==0.9.0", "tqdm", - "transformers==4.57.1", + "mistral_common>=1.9.0", + "transformers==5.3.0", "uvicorn", "uvloop", # "xgrammar==0.1.24", , xgrammar depends on CUDA PyTorch and Triton only @@ -85,7 +84,7 @@ test = [ "matplotlib", "pandas", "parameterized", - "peft", + "peft>=0.18.0", "pytest", "sentence_transformers", "tabulate", diff --git a/python/sglang/check_env.py b/python/sglang/check_env.py index 43fb5180b32b..78bd7423890e 100644 --- a/python/sglang/check_env.py +++ b/python/sglang/check_env.py @@ -30,7 +30,6 @@ def is_cuda_v2(): "numpy", "aiohttp", "fastapi", - "hf_transfer", "huggingface_hub", "interegular", "modelscope", diff --git a/python/sglang/multimodal_gen/runtime/loader/weight_utils.py b/python/sglang/multimodal_gen/runtime/loader/weight_utils.py index 7507dc10833d..bdeb2ac45c28 100644 --- a/python/sglang/multimodal_gen/runtime/loader/weight_utils.py +++ b/python/sglang/multimodal_gen/runtime/loader/weight_utils.py @@ -12,7 +12,6 @@ from pathlib import Path import filelock -import huggingface_hub.constants import torch from safetensors.torch import safe_open from torch.distributed.tensor import DTensor @@ -37,21 +36,6 @@ temp_dir = tempfile.gettempdir() -def enable_hf_transfer() -> None: - """automatically activates hf_transfer""" - if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: - try: - # enable hf hub transfer if available - import hf_transfer # type: ignore # noqa - - huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True - except ImportError: - pass - - -enable_hf_transfer() - - class DisabledTqdm(tqdm): def __init__(self, *args, **kwargs): diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/llama.py 
b/python/sglang/multimodal_gen/runtime/models/encoders/llama.py index a9d209231fc9..27ad1a518b87 100644 --- a/python/sglang/multimodal_gen/runtime/models/encoders/llama.py +++ b/python/sglang/multimodal_gen/runtime/models/encoders/llama.py @@ -227,8 +227,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters if rope_scaling is not None and getattr( config, "original_max_position_embeddings", None ): diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/qwen2_5vl.py b/python/sglang/multimodal_gen/runtime/models/encoders/qwen2_5vl.py index 364b72d59fa5..fb04d9ba53b6 100644 --- a/python/sglang/multimodal_gen/runtime/models/encoders/qwen2_5vl.py +++ b/python/sglang/multimodal_gen/runtime/models/encoders/qwen2_5vl.py @@ -798,6 +798,11 @@ def get_image_features( """ pixel_values = pixel_values.type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + if not isinstance(image_embeds, torch.Tensor): + # In transformers v5, the visual encoder returns BaseModelOutputWithPooling. + # pooler_output contains the spatially merged embeddings (what we need), + # while last_hidden_state contains the raw unmerged output. 
+ image_embeds = image_embeds.pooler_output split_sizes = ( image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2 ).tolist() diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/qwen3.py b/python/sglang/multimodal_gen/runtime/models/encoders/qwen3.py index b8132e4041c1..2373a31ff714 100644 --- a/python/sglang/multimodal_gen/runtime/models/encoders/qwen3.py +++ b/python/sglang/multimodal_gen/runtime/models/encoders/qwen3.py @@ -204,8 +204,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 1000000.0) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters max_position_embeddings = getattr(config, "max_position_embeddings", 40960) attention_bias = getattr(config, "attention_bias", False) diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py b/python/sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py index bdbcdcc144e3..155bbec04c99 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py @@ -324,7 +324,7 @@ def load_modules( ) logger.debug( - "Memory usage of loaded modules (GiB): %s. Available memory: %s", + "Memory usage of loaded modules (GiB): %s. 
avail mem: %s GB", self.memory_usages, round(current_platform.get_available_gpu_memory(), 2), ) diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py index 861bdda97beb..9f84db8bf44d 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py @@ -7,6 +7,8 @@ This module contains implementations of image encoding stages for diffusion pipelines. """ +import inspect + import PIL import torch from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution @@ -119,12 +121,21 @@ def forward( all_prompt_embeds = [] all_neg_prompt_embeds = [] + image_processor_call_params = inspect.signature( + self.image_processor.__call__ + ).parameters + image_processor_kwargs = { + k: v + for k, v in image_processor_kwargs.items() + if k in image_processor_call_params + } + for idx, prompt_images in enumerate(per_prompt_images): if not prompt_images: continue cur_kwargs = image_processor_kwargs.copy() - if texts and idx < len(texts): + if texts and idx < len(texts) and "text" in image_processor_call_params: cur_kwargs["text"] = [texts[idx]] image_inputs = self.image_processor( diff --git a/python/sglang/srt/configs/internvl.py b/python/sglang/srt/configs/internvl.py index 3ba9c61c10e0..eaa3f4c6af4e 100644 --- a/python/sglang/srt/configs/internvl.py +++ b/python/sglang/srt/configs/internvl.py @@ -593,7 +593,6 @@ def convert_tokens_to_string(self, tokens): current_sub_tokens.append(token) prev_is_special = False out_string += self.sp_model.decode(current_sub_tokens) - out_string = self.clean_up_tokenization(out_string) out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string) return out_string[1:] diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 099370c3e42f..f302dc4c03c9 100644 
--- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -51,10 +51,15 @@ class ModelImpl(str, Enum): MINDSPORE = "mindspore" -def is_deepseek_nsa(config: PretrainedConfig) -> bool: +def is_deepseek_nsa(config) -> bool: + architectures = ( + config.get("architectures") + if isinstance(config, dict) + else getattr(config, "architectures", None) + ) return ( - config.architectures is not None - and config.architectures[0] + architectures is not None + and architectures[0] in [ "DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM", @@ -63,7 +68,12 @@ def is_deepseek_nsa(config: PretrainedConfig) -> bool: "PixtralForConditionalGeneration", "GlmMoeDsaForCausalLM", ] - and getattr(config, "index_topk", None) is not None + and ( + config.get("index_topk") + if isinstance(config, dict) + else getattr(config, "index_topk", None) + ) + is not None ) @@ -458,7 +468,13 @@ def _derive_model_shapes(self): ) if rope_type != "default": mscale_all_dim = rope_scaling.get("mscale_all_dim", False) - scaling_factor = rope_scaling["factor"] + if "factor" not in rope_scaling: + logger.warning( + "rope_scaling (type=%s) missing 'factor', defaulting to 1.0. " + "Check model accuracy.", + rope_type, + ) + scaling_factor = rope_scaling.get("factor", 1.0) mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale elif "MiniCPM3ForCausalLM" in self.hf_config.architectures: @@ -504,7 +520,12 @@ def _derive_model_shapes(self): mscale_all_dim = self.hf_config.rope_scaling.get( "mscale_all_dim", False ) - scaling_factor = self.hf_config.rope_scaling["factor"] + if "factor" not in self.hf_config.rope_scaling: + logger.warning( + "BailingMoe rope_scaling missing 'factor', defaulting to 1.0. 
" + "Check model accuracy.", + ) + scaling_factor = self.hf_config.rope_scaling.get("factor", 1.0) mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale elif "SarvamMLAForCausalLM" in self.hf_config.architectures: @@ -521,7 +542,12 @@ def _derive_model_shapes(self): mscale_all_dim = self.hf_config.rope_scaling.get( "mscale_all_dim", False ) - scaling_factor = self.hf_config.rope_scaling["factor"] + if "factor" not in self.hf_config.rope_scaling: + logger.warning( + "SarvamMLA rope_scaling missing 'factor', defaulting to 1.0. " + "Check model accuracy.", + ) + scaling_factor = self.hf_config.rope_scaling.get("factor", 1.0) mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale else: diff --git a/python/sglang/srt/layers/rotary_embedding/factory.py b/python/sglang/srt/layers/rotary_embedding/factory.py index 9e24a2e5ed60..e95e9543f7f6 100644 --- a/python/sglang/srt/layers/rotary_embedding/factory.py +++ b/python/sglang/srt/layers/rotary_embedding/factory.py @@ -2,6 +2,7 @@ from __future__ import annotations +import logging from typing import Any, Dict, Optional, Tuple import torch @@ -26,6 +27,29 @@ from sglang.srt.layers.rotary_embedding.yarn import YaRNScalingRotaryEmbedding from sglang.srt.utils import get_bool_env_var, is_hip +logger = logging.getLogger(__name__) + + +def _get_rope_param(rope_scaling, key, default, scaling_type): + """Get a parameter from rope_scaling dict, warn if missing. + + In transformers v5, config.rope_scaling is an alias for rope_parameters + which may be non-None even for models with no actual scaling (rope_type=default). + When a required key is missing, this logs a warning instead of silently + defaulting, to make config mismatches easier to debug. + """ + if key in rope_scaling: + return rope_scaling[key] + logger.warning( + "rope_scaling (type=%s) missing key '%s', defaulting to %s. 
" + "This may indicate a v5 config issue — check model accuracy.", + scaling_type, + key, + default, + ) + return default + + _is_hip = is_hip() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip @@ -111,10 +135,19 @@ def get_rope( ) if scaling_type == "llama3": - scaling_factor = rope_scaling["factor"] - low_freq_factor = rope_scaling["low_freq_factor"] - high_freq_factor = rope_scaling["high_freq_factor"] - original_max_position = rope_scaling["original_max_position_embeddings"] + scaling_factor = _get_rope_param(rope_scaling, "factor", 1.0, scaling_type) + low_freq_factor = _get_rope_param( + rope_scaling, "low_freq_factor", 1.0, scaling_type + ) + high_freq_factor = _get_rope_param( + rope_scaling, "high_freq_factor", 4.0, scaling_type + ) + original_max_position = _get_rope_param( + rope_scaling, + "original_max_position_embeddings", + max_position, + scaling_type, + ) rotary_emb = Llama3RotaryEmbedding( head_size, rotary_dim, @@ -162,7 +195,7 @@ def get_rope( dtype, ) elif scaling_type == "linear": - scaling_factor = rope_scaling["factor"] + scaling_factor = _get_rope_param(rope_scaling, "factor", 1.0, scaling_type) rotary_emb = LinearScalingRotaryEmbedding( head_size, rotary_dim, @@ -173,7 +206,7 @@ def get_rope( dtype, ) elif scaling_type == "dynamic": - scaling_factor = rope_scaling["factor"] + scaling_factor = _get_rope_param(rope_scaling, "factor", 1.0, scaling_type) if "alpha" in rope_scaling: rotary_emb = DynamicNTKAlphaRotaryEmbedding( head_size, @@ -195,8 +228,13 @@ def get_rope( dtype, ) elif scaling_type == "yarn": - scaling_factor = rope_scaling["factor"] - original_max_position = rope_scaling["original_max_position_embeddings"] + scaling_factor = _get_rope_param(rope_scaling, "factor", 1.0, scaling_type) + original_max_position = _get_rope_param( + rope_scaling, + "original_max_position_embeddings", + max_position, + scaling_type, + ) extra_kwargs = { k: v for k, v in rope_scaling.items() @@ -229,8 +267,13 @@ def get_rope( 
**extra_kwargs, ) elif scaling_type == "deepseek_yarn": - scaling_factor = rope_scaling["factor"] - original_max_position = rope_scaling["original_max_position_embeddings"] + scaling_factor = _get_rope_param(rope_scaling, "factor", 1.0, scaling_type) + original_max_position = _get_rope_param( + rope_scaling, + "original_max_position_embeddings", + max_position, + scaling_type, + ) extra_kwargs = { k: v for k, v in rope_scaling.items() @@ -257,7 +300,12 @@ def get_rope( elif scaling_type == "longrope": short_factor = rope_scaling["short_factor"] long_factor = rope_scaling["long_factor"] - original_max_position = rope_scaling["original_max_position_embeddings"] + original_max_position = _get_rope_param( + rope_scaling, + "original_max_position_embeddings", + max_position, + scaling_type, + ) extra_kwargs = { k: v for k, v in rope_scaling.items() @@ -321,8 +369,10 @@ def get_rope_cpu( scaling_type == "deepseek_yarn" ), "Only deepseek_yarn is supported for CPU for now" - scaling_factor = rope_scaling["factor"] - original_max_position = rope_scaling["original_max_position_embeddings"] + scaling_factor = _get_rope_param(rope_scaling, "factor", 1.0, scaling_type) + original_max_position = _get_rope_param( + rope_scaling, "original_max_position_embeddings", max_position, scaling_type + ) extra_kwargs = { k: v for k, v in rope_scaling.items() diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index f0de7bef28e0..02500bbacdb8 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1838,7 +1838,11 @@ def detokenize_logprob_tokens( ] else: assert self.tokenizer is not None - token_texts = self.tokenizer.batch_decode(token_logprobs_idx) + # Wrap each token ID in its own list for batch_decode to decode them separately + # batch_decode([1, 2, 3]) concatenates tokens, batch_decode([[1], [2], [3]]) decodes separately + token_texts = self.tokenizer.batch_decode( 
+ [[idx] for idx in token_logprobs_idx] + ) return list(zip(token_logprobs_val, token_logprobs_idx, token_texts)) def detokenize_top_logprobs_tokens( diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 6e83184b08af..66b144f6194f 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -475,11 +475,12 @@ def init_new( batch.extend_input_logprob_token_ids.to(device, non_blocking=True) ) + num_tokens = len(batch.input_ids) if batch.input_ids is not None else 0 if enable_num_token_non_padded(model_runner.server_args): - ret.num_token_non_padded = torch.tensor( - len(batch.input_ids), dtype=torch.int32 - ).to(device, non_blocking=True) - ret.num_token_non_padded_cpu = len(batch.input_ids) + ret.num_token_non_padded = torch.tensor(num_tokens, dtype=torch.int32).to( + device, non_blocking=True + ) + ret.num_token_non_padded_cpu = num_tokens # For MLP sync if batch.global_num_tokens is not None: diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 576d02e3abfd..f13b82eb4639 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -595,11 +595,19 @@ def _load_modelopt_base_model(self, model_config: ModelConfig) -> nn.Module: "Please install it with: pip install accelerate" ) - hf_config = AutoConfig.from_pretrained( - model_config.model_path, - trust_remote_code=True, - local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, - ) + try: + hf_config = AutoConfig.from_pretrained( + model_config.model_path, + trust_remote_code=True, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + except (KeyError, ValueError): + from sglang.srt.utils.hf_transformers_utils import get_config + + hf_config = get_config( + model_config.model_path, + trust_remote_code=True, + ) with init_empty_weights(): torch_dtype = 
getattr(hf_config, "torch_dtype", torch.float16) model = AutoModelForCausalLM.from_config( @@ -628,6 +636,7 @@ def _load_modelopt_base_model(self, model_config: ModelConfig) -> nn.Module: model = AutoModelForCausalLM.from_pretrained( model_config.model_path, + config=hf_config, device_map=device_map, **model_kwargs, trust_remote_code=True, diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 4be11541a924..4746fc8ccd67 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -68,21 +68,6 @@ logger = logging.getLogger(__name__) -def enable_hf_transfer(): - """automatically activates hf_transfer""" - if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: - try: - # enable hf hub transfer if available - import hf_transfer # type: ignore # noqa - - huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True - except ImportError: - pass - - -enable_hf_transfer() - - # use system-level temp directory for file locks, so that multiple users # can share the same lock without error. 
# lock files in the temp directory will be automatically deleted when the diff --git a/python/sglang/srt/models/afmoe.py b/python/sglang/srt/models/afmoe.py index 92a11b09af03..b5ce2afbbcd9 100644 --- a/python/sglang/srt/models/afmoe.py +++ b/python/sglang/srt/models/afmoe.py @@ -314,8 +314,8 @@ def __init__( self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) self.rotary_dim = int(self.head_dim * partial_rotary_factor) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) diff --git a/python/sglang/srt/models/apertus.py b/python/sglang/srt/models/apertus.py index ca84264b9362..7a831732e6a4 100644 --- a/python/sglang/srt/models/apertus.py +++ b/python/sglang/srt/models/apertus.py @@ -217,8 +217,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters if rope_scaling is not None and getattr( config, "original_max_position_embeddings", None ): diff --git a/python/sglang/srt/models/arcee.py b/python/sglang/srt/models/arcee.py index 5afd5f34f5dd..9ee50f02c3a7 100644 --- a/python/sglang/srt/models/arcee.py +++ b/python/sglang/srt/models/arcee.py @@ -199,8 +199,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters if rope_scaling is not None and getattr( config, "original_max_position_embeddings", None 
): diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py index 6d060b88182e..ff4698a424fd 100644 --- a/python/sglang/srt/models/baichuan.py +++ b/python/sglang/srt/models/baichuan.py @@ -229,7 +229,7 @@ def __init__( ): super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) + rope_theta = config.rope_parameters["rope_theta"] max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = BaiChuanAttention( hidden_size=self.hidden_size, diff --git a/python/sglang/srt/models/bailing_moe.py b/python/sglang/srt/models/bailing_moe.py index 38eb17aba673..4b9e39b8295b 100644 --- a/python/sglang/srt/models/bailing_moe.py +++ b/python/sglang/srt/models/bailing_moe.py @@ -498,8 +498,8 @@ def __init__( self.head_dim, rotary_dim=self.rotary_dim, max_position=config.max_position_embeddings, - base=config.rope_theta, - rope_scaling=config.rope_scaling, + base=config.rope_parameters["rope_theta"], + rope_scaling=config.rope_parameters, ) self.attn = RadixAttention( diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index 7c799f5f8400..e23a31b0c60f 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -171,8 +171,8 @@ def __init__( self.max_position_embeddings = getattr( config, "model_max_length", None ) or getattr(config, "max_position_embeddings", 8192) - self.rope_theta = config.rope_theta - self.rope_scaling = getattr(config, "rope_scaling", None) + self.rope_theta = config.rope_parameters["rope_theta"] + self.rope_scaling = config.rope_parameters self.use_qk_norm = getattr(config, "use_qk_norm", False) self.qkv_proj = QKVParallelLinear( self.hidden_size, diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 74de384b3395..12d67d5d8cf9 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -205,7 +205,7 @@ def 
__init__( self.head_dim = self.d_model // self.total_num_heads self.total_num_kv_heads = config.attn_config.kv_n_heads self.clip_qkv = config.attn_config.clip_qkv - self.rope_theta = config.attn_config.rope_theta + self.rope_theta = config.attn_config.rope_parameters["rope_theta"] self.max_position = config.max_seq_len # pylint: disable=invalid-name diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index ef431e00d460..675bb1673d3b 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -288,8 +288,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = DeepseekAttention( hidden_size=self.hidden_size, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 8f061714223b..6aa4020e9228 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1200,7 +1200,12 @@ def __init__( if rope_scaling: mscale_all_dim = rope_scaling.get("mscale_all_dim", False) - scaling_factor = rope_scaling["factor"] + if "factor" not in rope_scaling: + logger.warning( + "DeepSeek rope_scaling missing 'factor', defaulting to 1.0. 
" + "Check model accuracy.", + ) + scaling_factor = rope_scaling.get("factor", 1.0) mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale else: @@ -1495,7 +1500,7 @@ def __init__( self.hidden_size = config.hidden_size self.config = config if hasattr(config, "rope_parameters"): - rope_theta = config.rope_parameters.get("rope_theta") + rope_theta = config.rope_parameters["rope_theta"] assert rope_theta is not None, f"rope_theta not found in config: {config}" rope_type = config.rope_parameters.get("rope_type") rope_scaling = config.rope_parameters if rope_type != "default" else None diff --git a/python/sglang/srt/models/ernie4.py b/python/sglang/srt/models/ernie4.py index 3a61d8fdce0c..73c1558d5ff5 100644 --- a/python/sglang/srt/models/ernie4.py +++ b/python/sglang/srt/models/ernie4.py @@ -155,8 +155,8 @@ def __init__( is_mtp: bool = False, ): super().__init__() - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters rope_is_neox_style = getattr(config, "rope_is_neox_style", False) # Self attention. 
self.self_attn = Ernie4Attention( diff --git a/python/sglang/srt/models/ernie45_moe_vl.py b/python/sglang/srt/models/ernie45_moe_vl.py index 24b37ce2b666..265cca20ee05 100644 --- a/python/sglang/srt/models/ernie45_moe_vl.py +++ b/python/sglang/srt/models/ernie45_moe_vl.py @@ -368,8 +368,8 @@ def __init__( prefix: str = "", ): super().__init__() - rope_theta = getattr(config, "rope_theta", 500000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters rope_is_neox_style = getattr(config, "rope_is_neox_style", False) freq_allocation = getattr(config, "freq_allocation", 20) max_position_embeddings = getattr(config, "max_position_embeddings", 131072) diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py index 1e4dfb3df217..0ad8398e6f81 100644 --- a/python/sglang/srt/models/exaone.py +++ b/python/sglang/srt/models/exaone.py @@ -182,8 +182,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 500000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters if rope_scaling is not None and getattr( config, "original_max_position_embeddings", None ): diff --git a/python/sglang/srt/models/falcon_h1.py b/python/sglang/srt/models/falcon_h1.py index 628f99c6e46e..72f684c2bb9c 100644 --- a/python/sglang/srt/models/falcon_h1.py +++ b/python/sglang/srt/models/falcon_h1.py @@ -133,9 +133,9 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.rope_theta = getattr(config, "rope_theta", 10000) + self.rope_theta = config.rope_parameters["rope_theta"] self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - self.rope_scaling = getattr(config, "rope_scaling", None) + 
self.rope_scaling = config.rope_parameters self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) self.layer_id = layer_id diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 1ecb5011f71c..af217582fb7d 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -172,7 +172,7 @@ def __init__( head_dim=config.head_dim, layer_id=layer_id, max_position_embeddings=config.max_position_embeddings, - rope_theta=config.rope_theta, + rope_theta=config.rope_parameters["rope_theta"], quant_config=quant_config, prefix=add_prefix("self_attn", prefix), ) diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 883eec81fe68..ce9733ed397c 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -217,7 +217,7 @@ def __init__( num_kv_heads=config.num_key_value_heads, head_dim=config.head_dim, max_position_embeddings=config.max_position_embeddings, - rope_theta=config.rope_theta, + rope_theta=config.rope_parameters["rope_theta"], quant_config=quant_config, prefix=add_prefix("self_attn", prefix), ) diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py index 17c535d73d3f..acd6def10231 100644 --- a/python/sglang/srt/models/gemma3_causal.py +++ b/python/sglang/srt/models/gemma3_causal.py @@ -166,18 +166,36 @@ def __init__( self.is_sliding = config.layer_types[layer_id] == "sliding_attention" + # In transformers v5, rope_parameters is nested per layer type: + # {"sliding_attention": {"rope_theta": 10000}, "full_attention": {"rope_theta": 1000000}} + # In v4 it was flat: {"rope_type": "default", "rope_theta": ...} + rope_params = config.rope_parameters + is_nested = isinstance(rope_params, dict) and "full_attention" in rope_params + # Initialize the rotary embedding. if self.is_sliding: # Local attention. Override the values in config.json. 
- self.rope_theta = config.rope_local_base_freq + if is_nested: + self.rope_theta = rope_params["sliding_attention"].get( + "rope_theta", 10000.0 + ) + else: + self.rope_theta = getattr(config, "rope_local_base_freq", 10000.0) self.rope_scaling = {"rope_type": "default"} # FIXME(mick): idk why vllm does this # self.sliding_window = config.interleaved_sliding_window self.sliding_window = get_attention_sliding_window_size(config) else: # Global attention. Use the values in config.json. - self.rope_theta = config.rope_theta - self.rope_scaling = config.rope_scaling + if is_nested: + self.rope_theta = rope_params["full_attention"].get( + "rope_theta", 1000000.0 + ) + else: + self.rope_theta = ( + rope_params.get("rope_theta", 10000.0) if rope_params else 10000.0 + ) + self.rope_scaling = {"rope_type": "default"} self.sliding_window = None self.attn = RadixAttention( @@ -325,9 +343,10 @@ class Gemma3RotaryEmbedding(nn.Module): def __init__(self, config: Gemma3TextConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get( - "rope_type", config.rope_scaling.get("type", "default") + rope_scaling = config.rope_parameters + if rope_scaling is not None: + self.rope_type = rope_scaling.get( + "rope_type", rope_scaling.get("type", "default") ) else: @@ -341,7 +360,10 @@ def __init__(self, config: Gemma3TextConfig, device=None): self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + if self.rope_type == "default": + self.rope_init_fn = self.compute_default_rope_parameters + else: + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) @@ -373,6 +395,35 @@ def _dynamic_frequency_update(self, position_ids, device): self.register_buffer("inv_freq", self.original_inv_freq, 
persistent=False) self.max_seq_len_cached = self.original_max_seq_len + @staticmethod + def compute_default_rope_parameters(config, device=None, seq_len=None): + """Standard RoPE: no scaling, just base frequency.""" + rope_params = config.rope_parameters + if isinstance(rope_params, dict) and "rope_theta" not in rope_params: + # Nested per-layer-type format; pick the first available theta + for v in rope_params.values(): + if isinstance(v, dict) and "rope_theta" in v: + base = v["rope_theta"] + break + else: + base = 10000.0 + else: + base = rope_params.get("rope_theta", 10000.0) if rope_params else 10000.0 + dim = ( + getattr(config, "head_dim", None) + or config.hidden_size // config.num_attention_heads + ) + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, dim, 2, dtype=torch.int64).to( + device=device, dtype=torch.float + ) + / dim + ) + ) + return inv_freq, 1.0 + @torch.no_grad() def forward(self, x, position_ids): if "dynamic" in self.rope_type: @@ -447,14 +498,36 @@ def __init__( ) self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Gemma3RotaryEmbedding(config=config) + + # In transformers v5, rope_parameters is nested per layer type: + # {"sliding_attention": {"rope_type": ..., "rope_theta": 10000}, + # "full_attention": {"rope_type": ..., "rope_theta": 1000000}} + # Flatten into the format Gemma3RotaryEmbedding expects. 
+ rope_params = config.rope_parameters + if isinstance(rope_params, dict) and "full_attention" in rope_params: + global_theta = rope_params["full_attention"].get("rope_theta", 1000000.0) + local_theta = rope_params["sliding_attention"].get("rope_theta", 10000.0) + else: + # v4 flat format fallback + global_theta = ( + rope_params.get("rope_theta", 10000.0) if rope_params else 10000.0 + ) + local_theta = getattr(config, "rope_local_base_freq", 10000.0) + + global_config = copy.deepcopy(config) + global_config.rope_parameters = { + "rope_type": "default", + "rope_theta": global_theta, + } + self.rotary_emb = Gemma3RotaryEmbedding(config=global_config) self.gradient_checkpointing = False - # when we want to create a local RoPE layer. Config defaults should hold values for global RoPE - config = copy.deepcopy(config) - config.rope_theta = config.rope_local_base_freq - config.rope_scaling = {"rope_type": "default"} - self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) + local_config = copy.deepcopy(config) + local_config.rope_parameters = { + "rope_type": "default", + "rope_theta": local_theta, + } + self.rotary_emb_local = Gemma3RotaryEmbedding(config=local_config) self.layers = make_layers( config.num_hidden_layers, @@ -506,7 +579,7 @@ def forward( class Gemma3ForCausalLM(PreTrainedModel): config_class = Gemma3TextConfig - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config_class = Gemma3TextConfig diff --git a/python/sglang/srt/models/gemma3_mm.py b/python/sglang/srt/models/gemma3_mm.py index 317cc71a3104..94431edaace7 100644 --- a/python/sglang/srt/models/gemma3_mm.py +++ b/python/sglang/srt/models/gemma3_mm.py @@ -420,8 +420,8 @@ def should_apply_lora(self, module_name: str) -> bool: """Skip vision tower and multi_modal_projector for LoRA.""" return bool(self.lora_pattern.match(module_name)) - def 
tie_weights(self): - return self.language_model.tie_weights() + def tie_weights(self, **kwargs): + return self.language_model.tie_weights(**kwargs) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/gemma3n_causal.py b/python/sglang/srt/models/gemma3n_causal.py index c92f70971498..fdab079caca6 100644 --- a/python/sglang/srt/models/gemma3n_causal.py +++ b/python/sglang/srt/models/gemma3n_causal.py @@ -397,8 +397,8 @@ def __init__( self.head_dim, rotary_dim=self.head_dim, max_position=config.max_position_embeddings, - base=config.rope_theta, - rope_scaling=config.rope_scaling, + base=config.rope_parameters["rope_theta"], + rope_scaling=config.rope_parameters, ) self.sliding_window = config.sliding_window if self.is_sliding else None diff --git a/python/sglang/srt/models/glm4.py b/python/sglang/srt/models/glm4.py index ba40a1f7446a..d6e81f619e2e 100644 --- a/python/sglang/srt/models/glm4.py +++ b/python/sglang/srt/models/glm4.py @@ -217,20 +217,9 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - - rp = getattr(config, "rope_parameters", None) - if isinstance(rp, dict): - rope_theta = rp.get("rope_theta", getattr(config, "rope_theta", 1000000)) - partial_rotary_factor = rp.get( - "partial_rotary_factor", - getattr(config, "partial_rotary_factor", 0.5), - ) - rope_scaling = getattr(config, "rope_scaling", None) - else: - rope_theta = getattr(config, "rope_theta", 1000000) - rope_scaling = getattr(config, "rope_scaling", None) - partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) - + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters + partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 0.5) bias = getattr(config, "attention_bias", True) max_position_embeddings = getattr(config, "max_position_embeddings", 32768) head_dim = getattr(config, "head_dim", None) diff --git 
a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index 85f13132c929..0c1e1d00084a 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -684,8 +684,8 @@ def __init__( nn.Module.__init__(self) self.hidden_size = config.hidden_size self.config = config - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters partial_rotary_factor = getattr( getattr(config, "rope_parameters", None), "partial_rotary_factor", None ) or getattr(config, "partial_rotary_factor", 0.5) diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 96caaa65b57c..04f4e4e7c304 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -374,8 +374,8 @@ def __init__( super().__init__() self.config = config self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters max_position_embeddings = getattr(config, "max_position_embeddings", 8192) head_dim = getattr( config, "head_dim", config.hidden_size // config.num_attention_heads diff --git a/python/sglang/srt/models/granite.py b/python/sglang/srt/models/granite.py index 19252dc8db62..63a9ebec5f3a 100644 --- a/python/sglang/srt/models/granite.py +++ b/python/sglang/srt/models/granite.py @@ -187,8 +187,8 @@ def __init__( super().__init__() self.hidden_size = config.hidden_size self.residual_multiplier = config.residual_multiplier - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters if rope_scaling is not None and getattr( config, 
"original_max_position_embeddings", None ): diff --git a/python/sglang/srt/models/granitemoe.py b/python/sglang/srt/models/granitemoe.py index d65b9ec06d31..ffeb13742c86 100644 --- a/python/sglang/srt/models/granitemoe.py +++ b/python/sglang/srt/models/granitemoe.py @@ -187,7 +187,7 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) + rope_theta = config.rope_parameters["rope_theta"] self.self_attn = GraniteMoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index f82642777447..9a096cc13079 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -477,7 +477,7 @@ def __init__( self.layer_id = layer_id self.alt_stream = alt_stream or torch.cuda.Stream() - rope_theta = getattr(config, "rope_theta", 10000) + rope_theta = config.rope_parameters["rope_theta"] self.self_attn = Grok1Attention( config=config, hidden_size=self.hidden_size, diff --git a/python/sglang/srt/models/hunyuan.py b/python/sglang/srt/models/hunyuan.py index 0128d009564d..7473062b0b3f 100644 --- a/python/sglang/srt/models/hunyuan.py +++ b/python/sglang/srt/models/hunyuan.py @@ -402,8 +402,8 @@ def __init__( if isinstance(config.intermediate_size, int) else config.intermediate_size[layer_id] ) - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters if rope_scaling is not None and getattr( config, "original_max_position_embeddings", None ): diff --git a/python/sglang/srt/models/iquest_loopcoder.py b/python/sglang/srt/models/iquest_loopcoder.py index 240aa5306a29..b69a5332a80a 100644 --- a/python/sglang/srt/models/iquest_loopcoder.py +++ b/python/sglang/srt/models/iquest_loopcoder.py @@ -166,8 +166,8 @@ def __init__( 
prefix=add_prefix("o_proj", prefix), ) - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters max_position_embeddings = getattr( config, "max_position_embeddings", max_position ) diff --git a/python/sglang/srt/models/jet_nemotron.py b/python/sglang/srt/models/jet_nemotron.py index 513f2ce3759a..1e6d2ec87e1c 100644 --- a/python/sglang/srt/models/jet_nemotron.py +++ b/python/sglang/srt/models/jet_nemotron.py @@ -374,8 +374,8 @@ def __init__( self.head_dim, rotary_dim=self.head_dim, max_position=self.config.max_position_embeddings, - base=int(self.config.rope_theta), - rope_scaling=self.config.rope_scaling, + base=int(self.config.rope_parameters["rope_theta"]), + rope_scaling=self.config.rope_parameters, ) match self.config.layer_types[layer_id]: diff --git a/python/sglang/srt/models/lfm2.py b/python/sglang/srt/models/lfm2.py index 8a271c33606e..b6205f302540 100644 --- a/python/sglang/srt/models/lfm2.py +++ b/python/sglang/srt/models/lfm2.py @@ -124,13 +124,13 @@ def __init__( if rope_parameters is not None and "rope_theta" in rope_parameters: rope_theta = rope_parameters["rope_theta"] else: - rope_theta = getattr(config, "rope_theta", 10000) + rope_theta = config.rope_parameters["rope_theta"] self.rotary_emb = get_rope( head_size=self.head_dim, rotary_dim=self.head_dim, max_position=getattr(config, "max_position_embeddings", 8192), - rope_scaling=getattr(config, "rope_scaling", None), + rope_scaling=config.rope_parameters, base=rope_theta, is_neox_style=True, dtype=torch.get_default_dtype(), diff --git a/python/sglang/srt/models/llada2.py b/python/sglang/srt/models/llada2.py index 7294524e3495..041f42db4716 100644 --- a/python/sglang/srt/models/llada2.py +++ b/python/sglang/srt/models/llada2.py @@ -490,8 +490,8 @@ def __init__( self.head_dim, rotary_dim=self.rotary_dim, max_position=config.max_position_embeddings, - 
base=config.rope_theta, - rope_scaling=config.rope_scaling, + base=config.rope_parameters["rope_theta"], + rope_scaling=config.rope_parameters, ) self.attn = RadixAttention( diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 01e934dcc096..d8810c508c48 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -252,8 +252,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters if rope_scaling is not None and getattr( config, "original_max_position_embeddings", None ): diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py index 22749a16333a..f0a237aa6e7a 100644 --- a/python/sglang/srt/models/llama4.py +++ b/python/sglang/srt/models/llama4.py @@ -366,8 +366,8 @@ def __init__( super().__init__() self.layer_id = layer_id self.hidden_size = config.hidden_size - rope_theta = config.rope_theta - rope_scaling = config.rope_scaling + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters max_position_embeddings = config.max_position_embeddings self.attn_tp_size = get_attention_tp_size() self.attn_tp_rank = get_attention_tp_rank() diff --git a/python/sglang/srt/models/llama_eagle3.py b/python/sglang/srt/models/llama_eagle3.py index 49f938a1c5fe..45737dad2f8d 100644 --- a/python/sglang/srt/models/llama_eagle3.py +++ b/python/sglang/srt/models/llama_eagle3.py @@ -111,14 +111,13 @@ def __init__( super().__init__() self.config = config + rope_scaling = config.rope_parameters self.is_mrope_enabled = ( - hasattr(config, "rope_scaling") - and config.rope_scaling is not None - and "mrope_section" in config.rope_scaling + rope_scaling is not None and "mrope_section" in rope_scaling ) # fix rope_scaling for qwen2.5-vl if 
self.is_mrope_enabled: - config.rope_scaling["rope_type"] = "default" + config.rope_parameters["rope_scaling"]["rope_type"] = "default" self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( diff --git a/python/sglang/srt/models/longcat_flash.py b/python/sglang/srt/models/longcat_flash.py index 7a21f5d7b425..c55c741add99 100644 --- a/python/sglang/srt/models/longcat_flash.py +++ b/python/sglang/srt/models/longcat_flash.py @@ -329,8 +329,8 @@ def __init__( v_head_dim=config.v_head_dim, q_lora_rank=config.q_lora_rank, kv_lora_rank=config.kv_lora_rank, - rope_theta=config.rope_theta, - rope_scaling=getattr(config, "rope_scaling", None), + rope_theta=config.rope_parameters["rope_theta"], + rope_scaling=None, max_position_embeddings=config.max_position_embeddings, quant_config=( None diff --git a/python/sglang/srt/models/longcat_flash_nextn.py b/python/sglang/srt/models/longcat_flash_nextn.py index 12c9cb13fae9..abaf27855611 100644 --- a/python/sglang/srt/models/longcat_flash_nextn.py +++ b/python/sglang/srt/models/longcat_flash_nextn.py @@ -132,7 +132,7 @@ def __init__( v_head_dim=config.v_head_dim, q_lora_rank=config.q_lora_rank, kv_lora_rank=config.kv_lora_rank, - rope_theta=config.rope_theta, + rope_theta=config.rope_parameters["rope_theta"], rope_scaling=None, max_position_embeddings=config.max_position_embeddings, quant_config=quant_config, diff --git a/python/sglang/srt/models/midashenglm.py b/python/sglang/srt/models/midashenglm.py index bc758a2c3086..9d349716f5af 100644 --- a/python/sglang/srt/models/midashenglm.py +++ b/python/sglang/srt/models/midashenglm.py @@ -476,18 +476,14 @@ def __init__( ) -> None: super().__init__() self.config = config - if ( - hasattr(config.text_config, "rope_scaling") - and config.text_config.rope_scaling - ): - if "mrope_section" in config.text_config.rope_scaling: + rope_scaling = config.text_config.rope_parameters + if rope_scaling: + if "mrope_section" in rope_scaling: new_rope_scaling = { - k: v - for 
k, v in config.text_config.rope_scaling.items() - if k != "mrope_section" + k: v for k, v in rope_scaling.items() if k != "mrope_section" } - config.text_config.rope_scaling = ( + config.text_config.rope_parameters["rope_scaling"] = ( new_rope_scaling if new_rope_scaling else None ) self.audio_encoder = DashengAudioTransformer( diff --git a/python/sglang/srt/models/mimo_v2_flash.py b/python/sglang/srt/models/mimo_v2_flash.py index d6f5eb07f4de..9902b63d487f 100644 --- a/python/sglang/srt/models/mimo_v2_flash.py +++ b/python/sglang/srt/models/mimo_v2_flash.py @@ -573,8 +573,16 @@ def __init__( self.hidden_size = config.hidden_size self.layer_id = layer_id - rope_theta = getattr(config, "rope_theta", 1000000) + rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) + # In v5, rope_scaling is a property alias for rope_parameters and returns + # a standardized dict even when there's no actual scaling. Treat the + # "default" (no-op) type as None so factory.py uses plain RotaryEmbedding. 
+ if ( + isinstance(rope_scaling, dict) + and rope_scaling.get("rope_type") == "default" + ): + rope_scaling = None max_position_embeddings = getattr(config, "max_position_embeddings", 32768) if self.is_swa_layer(): @@ -792,7 +800,7 @@ def __init__( ) -> None: super().__init__() self.config = config - self.padding_idx = config.pad_token_id + self.padding_idx = getattr(config, "pad_token_id", None) self.vocab_size = config.vocab_size self.pp_group = get_pp_group() diff --git a/python/sglang/srt/models/mimo_v2_flash_nextn.py b/python/sglang/srt/models/mimo_v2_flash_nextn.py index 18b5453953c0..098649f956f2 100644 --- a/python/sglang/srt/models/mimo_v2_flash_nextn.py +++ b/python/sglang/srt/models/mimo_v2_flash_nextn.py @@ -64,8 +64,13 @@ def __init__( self.config = config self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 1000000) + rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) + if ( + isinstance(rope_scaling, dict) + and rope_scaling.get("rope_type") == "default" + ): + rope_scaling = None max_position_embeddings = getattr(config, "max_position_embeddings", 32768) self.self_attn = MiMoV2Attention( diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index e7c94c85d0b2..27c2ba2e8b83 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -176,8 +176,8 @@ def __init__( super().__init__() self.config = config self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = MiniCPMAttention( hidden_size=self.hidden_size, diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index 95dca19da009..f03165053f84 100644 --- 
a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -305,8 +305,8 @@ def __init__( super().__init__() self.config = config self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = MiniCPM3AttentionMLA( config=config, diff --git a/python/sglang/srt/models/minimax_m2.py b/python/sglang/srt/models/minimax_m2.py index 7aad87c640a0..9b50d3d070d1 100644 --- a/python/sglang/srt/models/minimax_m2.py +++ b/python/sglang/srt/models/minimax_m2.py @@ -566,7 +566,8 @@ def __init__( self.scaling = self.head_dim**-0.5 # RoPE settings - support partial RoPE - self.rope_theta = getattr(config, "rope_theta", 10000) + # FIXME: minimax_m2 config use external config that not compatible with transformers v5 + self.rope_theta = config.rope_theta self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.rotary_dim = getattr( config, "rotary_dim", self.head_dim diff --git a/python/sglang/srt/models/ministral3.py b/python/sglang/srt/models/ministral3.py index 460c7b30fb5e..8d678ec543d1 100644 --- a/python/sglang/srt/models/ministral3.py +++ b/python/sglang/srt/models/ministral3.py @@ -54,11 +54,7 @@ def __init__( bias, ) # Ministral3 specific: llama 4 style scaling beta - self.llama_4_scaling_beta = None - if hasattr(config, "rope_parameters") and config.rope_parameters: - self.llama_4_scaling_beta = config.rope_parameters.get( - "llama_4_scaling_beta" - ) + self.llama_4_scaling_beta = config.rope_parameters.get("llama_4_scaling_beta") # sliding window self.sliding_window = getattr(config, "sliding_window", None) @@ -107,12 +103,8 @@ def __init__(self, config, layer_id=0, quant_config=None, prefix=""): num_heads=config.num_attention_heads, 
num_kv_heads=config.num_key_value_heads, layer_id=layer_id, - rope_theta=getattr(config, "rope_parameters", {}).get( - "rope_theta", 1000000.0 - ), - rope_scaling=getattr( - config, "rope_parameters", {} - ), # rope_scaling is rope_parameters in Ministral3Config + rope_theta=config.rope_parameters["rope_theta"], + rope_scaling=config.rope_parameters, # rope_scaling is rope_parameters in Ministral3Config max_position_embeddings=getattr( config, "original_max_position_embeddings", 16384 ), diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index c4f3e4c446f7..16d3c3e7c5db 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -208,7 +208,7 @@ def __init__( super().__init__() self.hidden_size = config.hidden_size # Requires transformers > 4.32.0 - rope_theta = getattr(config, "rope_theta", 10000) + rope_theta = config.rope_parameters["rope_theta"] self.self_attn = MixtralAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index 10bfa5068de2..7423aa08534a 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -261,7 +261,7 @@ def __init__( super().__init__() self.hidden_size = config.hidden_size # Requires transformers > 4.32.0 - rope_theta = getattr(config, "rope_theta", 10000) + rope_theta = config.rope_parameters["rope_theta"] self.self_attn = MixtralAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py index 0913f9adfd2d..bfb618e758f7 100644 --- a/python/sglang/srt/models/mllama4.py +++ b/python/sglang/srt/models/mllama4.py @@ -305,7 +305,7 @@ def __init__(self, config): frequencies_y = img_idx // idx # get the coordinates of the 2d matrix along y freq_dim = config.hidden_size // config.num_attention_heads 
// 2 rope_freq = 1.0 / ( - config.rope_theta + config.rope_parameters["rope_theta"] ** (torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim) ) freqs_x = ( diff --git a/python/sglang/srt/models/nemotron_nas.py b/python/sglang/srt/models/nemotron_nas.py index ac1ccd231dc3..904d9d26361c 100644 --- a/python/sglang/srt/models/nemotron_nas.py +++ b/python/sglang/srt/models/nemotron_nas.py @@ -70,8 +70,8 @@ def __init__( self._is_no_op_ffn = block_config.ffn.no_op self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters if rope_scaling is not None and getattr( config, "original_max_position_embeddings", None ): diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py index 091b08e8ac1c..0a9b2d525dac 100644 --- a/python/sglang/srt/models/olmo.py +++ b/python/sglang/srt/models/olmo.py @@ -68,7 +68,7 @@ def __init__( self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.head_dim = self.hidden_size // self.total_num_heads self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta + self.rope_theta = config.rope_parameters["rope_theta"] self.clip_qkv = config.clip_qkv # Attention input projection. Projects x -> (q, k, v) diff --git a/python/sglang/srt/models/olmo2.py b/python/sglang/srt/models/olmo2.py index 2b1c6fa89fa0..512ed0b64290 100644 --- a/python/sglang/srt/models/olmo2.py +++ b/python/sglang/srt/models/olmo2.py @@ -99,7 +99,7 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta + self.rope_theta = config.rope_parameters["rope_theta"] # Attention input projection. 
Projects x -> (q, k, v) self.qkv_proj = QKVParallelLinear( diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index a74a2968daef..33c57b80f5b1 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -204,8 +204,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters max_position_embeddings = getattr(config, "max_position_embeddings", 4096) self.self_attn = OlmoeAttention( diff --git a/python/sglang/srt/models/orion.py b/python/sglang/srt/models/orion.py index 2aee4b9aaf44..510f0ca2bddb 100644 --- a/python/sglang/srt/models/orion.py +++ b/python/sglang/srt/models/orion.py @@ -165,8 +165,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = OrionAttention( hidden_size=self.hidden_size, diff --git a/python/sglang/srt/models/persimmon.py b/python/sglang/srt/models/persimmon.py index 5f8885e716e5..5d2585c63031 100644 --- a/python/sglang/srt/models/persimmon.py +++ b/python/sglang/srt/models/persimmon.py @@ -65,7 +65,7 @@ def __init__( self.num_heads = self.total_num_heads // tensor_parallel_world_size self.head_dim = self.hidden_size // self.total_num_heads self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta + self.rope_theta = config.rope_parameters["rope_theta"] self.partial_rotary_factor = config.partial_rotary_factor self.is_causal = True diff --git a/python/sglang/srt/models/phi.py 
b/python/sglang/srt/models/phi.py index 5679bc987812..55188be8d254 100644 --- a/python/sglang/srt/models/phi.py +++ b/python/sglang/srt/models/phi.py @@ -63,7 +63,7 @@ def __init__( ) assert rotary_dim % 2 == 0 - rope_theta = getattr(config, "rope_theta", 10000.0) + rope_theta = config.rope_parameters["rope_theta"] max_position_embeddings = getattr(config, "max_position_embeddings", 2048) self.rotary_emb = get_rope( self.head_size, diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py index 9ac855c492f6..cf049c43e13d 100644 --- a/python/sglang/srt/models/phi3_small.py +++ b/python/sglang/srt/models/phi3_small.py @@ -153,8 +153,8 @@ def __init__( prefix=add_prefix("o_proj", prefix), ) - if getattr(self.config, "rope_scaling", None) is not None: - rope_scaling = self.config.rope_scaling + rope_scaling = self.config.rope_parameters + if rope_scaling is not None: for key in rope_scaling: if isinstance(rope_scaling[key], list): rope_scaling[key] = tuple(rope_scaling[key]) diff --git a/python/sglang/srt/models/phimoe.py b/python/sglang/srt/models/phimoe.py index 0d147c2b1783..a359483de3ef 100644 --- a/python/sglang/srt/models/phimoe.py +++ b/python/sglang/srt/models/phimoe.py @@ -336,7 +336,7 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) + rope_theta = config.rope_parameters["rope_theta"] self.self_attn = PhiMoEAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -349,7 +349,7 @@ def __init__( layer_id=layer_id, attention_bias=config.attention_bias, quant_config=quant_config, - rope_scaling=config.rope_scaling, + rope_scaling=config.rope_parameters, prefix=add_prefix("self_attn", prefix), ) self.block_sparse_moe = PhiMoE( diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 206908b49001..43944b576e8f 100644 --- a/python/sglang/srt/models/qwen.py +++ 
b/python/sglang/srt/models/qwen.py @@ -162,8 +162,8 @@ def __init__( super().__init__() self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters self.attn = QWenAttention( config.hidden_size, config.num_attention_heads, diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index a3cfde4d6301..9574186e1caa 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -52,6 +52,7 @@ ) from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, make_layers +from sglang.srt.utils.hf_transformers_utils import get_rope_config Qwen2Config = None @@ -201,8 +202,7 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 1000000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta, rope_scaling = get_rope_config(config) max_position_embeddings = getattr(config, "max_position_embeddings", 32768) head_dim = getattr(config, "head_dim", None) dual_chunk_attention_config = getattr( @@ -269,7 +269,6 @@ def __init__( ) -> None: super().__init__() self.config = config - self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.pp_group = get_pp_group() diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index ed94416e5e4b..f64782e49aa0 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -82,6 +82,7 @@ make_layers, use_intel_amx_backend, ) +from sglang.srt.utils.hf_transformers_utils import get_rope_config logger = logging.getLogger(__name__) @@ -449,8 +450,7 @@ def __init__( super().__init__() self.config = config self.hidden_size = config.hidden_size - rope_theta = getattr(config, 
"rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta, rope_scaling = get_rope_config(config) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) qkv_bias = getattr(config, "qkv_bias", True) dual_chunk_attention_config = getattr( @@ -571,8 +571,6 @@ def __init__( ) -> None: super().__init__() self.config = config - - self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.pp_group = get_pp_group() diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index b6b955ae6c7c..4f018ea52d2f 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -216,8 +216,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 1000000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters max_position_embeddings = getattr(config, "max_position_embeddings", 32768) head_dim = getattr(config, "head_dim", None) self.self_attn = Qwen3Attention( diff --git a/python/sglang/srt/models/qwen3_5.py b/python/sglang/srt/models/qwen3_5.py index 73f22f9be2c7..2ff7c1c2bf78 100644 --- a/python/sglang/srt/models/qwen3_5.py +++ b/python/sglang/srt/models/qwen3_5.py @@ -80,7 +80,7 @@ make_layers, set_weight_attrs, ) -from sglang.srt.utils.hf_transformers_utils import get_processor +from sglang.srt.utils.hf_transformers_utils import get_processor, get_rope_config logger = logging.getLogger(__name__) _is_cuda = is_cuda() @@ -444,15 +444,14 @@ def __init__( self.scaling = self.head_dim**-0.5 self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - if hasattr(config, "rope_parameters"): - self.rope_scaling = getattr(config, "rope_parameters", None) - else: - self.rope_scaling = getattr(config, "rope_scaling", None) - - self.rope_theta = self.rope_scaling.get("rope_theta", 10000) 
- self.partial_rotary_factor = self.rope_scaling.get("partial_rotary_factor", 1.0) + self.rope_theta, rope_scaling = get_rope_config(config) + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) self.layer_id = layer_id + # If rope_scaling doesn't specify a scaling type, treat as no scaling + if rope_scaling and not ("rope_type" in rope_scaling or "type" in rope_scaling): + rope_scaling = None + self.attn_output_gate = getattr(config, "attn_output_gate", True) if self.attn_output_gate: logger.warning_once("using attn output gate!") @@ -461,7 +460,7 @@ def __init__( head_size=self.head_dim, rotary_dim=self.head_dim, max_position=self.max_position_embeddings, - rope_scaling=self.rope_scaling, + rope_scaling=rope_scaling, base=self.rope_theta, partial_rotary_factor=self.partial_rotary_factor, is_neox_style=True, diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 048a000d5d99..a1fc685fa38f 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -115,12 +115,19 @@ def compute_yarn_parameters( attention_factor: float, the post-processing scaling factor applied to the computed cos/sin """ - # The config does not contain rope_scaling, which means the model is not using yarn - rope_scaling = getattr(config, "rope_scaling", None) + # The config does not contain rope_scaling, which means the model is not using yarn. + # In transformers v5, rope_parameters is never None (even for default rope), so also + # check rope_type to distinguish actual yarn configs from plain rotary embeddings. 
+ rope_scaling = getattr(config, "rope_parameters", None) + if rope_scaling is None: + rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is None: return 1.0, 0, 0, 1.0 + rope_type = rope_scaling.get("rope_type") or rope_scaling.get("type") or "default" + if rope_type == "default": + return 1.0, 0, 0, 1.0 - base = config.rope_theta + base = rope_scaling.get("rope_theta") or getattr(config, "rope_theta", 10000) partial_rotary_factor = ( config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") @@ -130,7 +137,7 @@ def compute_yarn_parameters( config, "head_dim", config.hidden_size // config.num_attention_heads ) dim = int(head_dim * partial_rotary_factor) - factor = getattr(rope_scaling, "factor", 1.0) + factor = rope_scaling.get("factor", 1.0) attention_factor = rope_scaling.get("attention_factor") mscale = rope_scaling.get("mscale") mscale_all_dim = rope_scaling.get("mscale_all_dim") @@ -559,7 +566,7 @@ def forward_prepare_native( def apply_qk_norm_rope(self, qkv, positions, forward_batch): use_fused = self.use_fused_qk_norm_rope and qkv.dtype == torch.bfloat16 if use_fused: - theta = getattr(self.config, "rope_theta", 10000.0) + theta = self.config.rope_parameters["rope_theta"] positions = ( positions.view(-1).to(dtype=torch.int32, device=qkv.device).contiguous() ) @@ -684,8 +691,8 @@ def __init__( super().__init__() self.config = config self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters max_position_embeddings = getattr(config, "max_position_embeddings", 8192) head_dim = getattr( config, "head_dim", config.hidden_size // config.num_attention_heads diff --git a/python/sglang/srt/models/solar.py b/python/sglang/srt/models/solar.py index 8f85ad587ab0..623dcb7d6a79 100644 --- a/python/sglang/srt/models/solar.py +++ b/python/sglang/srt/models/solar.py 
@@ -194,8 +194,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters if rope_scaling is not None and getattr( config, "original_max_position_embeddings", None diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 2adcfe92ffc5..328d25394000 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -144,14 +144,14 @@ def __init__( self.head_dim, rotary_dim=self.rotary_ndims, max_position=self.config.max_position_embeddings, - base=self.config.rope_theta, + base=self.config.rope_parameters["rope_theta"], ) else: self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.rotary_ndims, max_position=self.config.max_position_embeddings, - base=self.config.rope_theta, + base=self.config.rope_parameters["rope_theta"], dtype=torch.float32, ) self.attn = RadixAttention( diff --git a/python/sglang/srt/models/starcoder2.py b/python/sglang/srt/models/starcoder2.py index 2ad4351e9ecf..e5cba190deb8 100644 --- a/python/sglang/srt/models/starcoder2.py +++ b/python/sglang/srt/models/starcoder2.py @@ -81,7 +81,7 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.rope_theta = config.rope_theta + self.rope_theta = config.rope_parameters["rope_theta"] self.max_position_embeddings = config.max_position_embeddings self.use_bias = config.use_bias diff --git a/python/sglang/srt/models/step3_vl.py b/python/sglang/srt/models/step3_vl.py index 3ab2354637e8..d839d8fa8e7b 100644 --- a/python/sglang/srt/models/step3_vl.py +++ b/python/sglang/srt/models/step3_vl.py @@ -290,8 +290,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = 
getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters max_position_embeddings = getattr(config, "max_position_embeddings", 8192) head_dim = getattr( config, "head_dim", config.hidden_size // config.num_attention_heads diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 14b327bd1a2c..ce9612b8c208 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -274,8 +274,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters if rope_scaling is not None and getattr( config, "original_max_position_embeddings", None ): diff --git a/python/sglang/srt/models/xverse.py b/python/sglang/srt/models/xverse.py index f84755b03635..817028a10948 100644 --- a/python/sglang/srt/models/xverse.py +++ b/python/sglang/srt/models/xverse.py @@ -181,8 +181,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters if rope_scaling is not None and getattr( config, "original_max_position_embeddings", None ): diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py index e4489055f636..e96721b91ea4 100644 --- a/python/sglang/srt/models/xverse_moe.py +++ b/python/sglang/srt/models/xverse_moe.py @@ -291,8 +291,8 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = 
getattr(config, "rope_scaling", None) + rope_theta = config.rope_parameters["rope_theta"] + rope_scaling = config.rope_parameters max_position_embeddings = getattr(config, "max_position_embeddings", 8192) num_key_value_heads = getattr( config, "num_key_value_heads", config.num_attention_heads diff --git a/python/sglang/srt/multimodal/mm_utils.py b/python/sglang/srt/multimodal/mm_utils.py index 2e402ccf50ca..4d8d8d0369c7 100644 --- a/python/sglang/srt/multimodal/mm_utils.py +++ b/python/sglang/srt/multimodal/mm_utils.py @@ -237,10 +237,11 @@ def process_anyres_image(image, processor, grid_pinpoints): best_resolution = select_best_resolution(image.size, possible_resolutions) image_padded = resize_and_pad_image(image, best_resolution) - # For Siglip processor, only have size but no crop size + # For Siglip processor, only have size but no crop size. + # In transformers v5, crop_size may exist but be None. crop_size = ( processor.crop_size["height"] - if "crop_size" in processor.__dict__ + if getattr(processor, "crop_size", None) is not None else processor.size["height"] ) shortest_edge = ( @@ -257,6 +258,10 @@ def process_anyres_image(image, processor, grid_pinpoints): processor.preprocess(image_patch.convert("RGB"))["pixel_values"][0] for image_patch in image_patches ] + # In transformers v5, image processors may return torch.Tensor instead of numpy arrays + image_patches = [ + p.numpy() if isinstance(p, torch.Tensor) else p for p in image_patches + ] return np.stack(image_patches, axis=0) diff --git a/python/sglang/srt/multimodal/processors/llava.py b/python/sglang/srt/multimodal/processors/llava.py index 83afdcb97655..350ac9a06732 100644 --- a/python/sglang/srt/multimodal/processors/llava.py +++ b/python/sglang/srt/multimodal/processors/llava.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional, Union import numpy as np +import torch from transformers.models.auto.processing_auto import ( PROCESSOR_MAPPING_NAMES as HF_MAPPING_NAMES, ) @@ -50,8 +51,11 @@ def 
_process_single_image_task( # It is a video with multiple images image_hash = hash(url) pixel_values = image_processor(image)["pixel_values"] - for _ in range(len(pixel_values)): - pixel_values[_] = pixel_values[_].astype(np.float16) + for i in range(len(pixel_values)): + v = pixel_values[i] + if isinstance(v, torch.Tensor): + v = v.numpy() + pixel_values[i] = v.astype(np.float16) pixel_values = np.stack(pixel_values, axis=0) return pixel_values, image_hash, image_size else: @@ -75,6 +79,8 @@ def _process_single_image_task( else: pixel_values = image_processor(image)["pixel_values"][0] + if isinstance(pixel_values, torch.Tensor): + pixel_values = pixel_values.numpy() if isinstance(pixel_values, np.ndarray): pixel_values = pixel_values.astype(np.float16) diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 1b9a709fc7a6..9f30bee01e99 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -129,6 +129,59 @@ def download_from_hf( return snapshot_download(model_path, allow_patterns=allow_patterns) +def get_rope_config(config): + """Get (rope_theta, rope_scaling) from config, supporting both v4 and v5. + + In transformers v5, rope_theta/rope_scaling are accessed via the computed + property config.rope_parameters. Trust-remote-code configs or parent configs + passed to sub-models may not have this property or may return None. + Falls back to the v4-style config.rope_theta / config.rope_scaling attributes. + """ + rope_params = getattr(config, "rope_parameters", None) + if rope_params is not None: + return rope_params["rope_theta"], rope_params + return config.rope_theta, getattr(config, "rope_scaling", None) + + +def _patch_text_config(parent_config: PretrainedConfig, text_config): + """Synchronize standard attributes between parent config and text sub-config. 
+ + In transformers v5, the "untangle config" refactor removed automatic + inheritance of top-level PretrainedConfig attributes (pad_token_id, + tie_word_embeddings, etc.) from sub-configs. Downstream code expects + these attributes to be present on both configs (some models pass the + parent directly to the language model, others pass the text sub-config), + so we propagate in both directions when an attribute is missing. + (See https://github.com/huggingface/transformers/pull/41541) + """ + # Some models store text_config as a plain dict rather than a + # PretrainedConfig object. Convert to PretrainedConfig so downstream + # code can use attribute access uniformly (e.g. config.hidden_size). + if isinstance(text_config, dict): + text_config = PretrainedConfig(**text_config) + # Propagate any parent-level torch_dtype so weight loading uses the + # correct precision. + if not hasattr(text_config, "torch_dtype") and hasattr( + parent_config, "torch_dtype" + ): + text_config.torch_dtype = parent_config.torch_dtype + + _ATTRS_TO_PROPAGATE = [ + "pad_token_id", + "bos_token_id", + "eos_token_id", + "tie_word_embeddings", + ] + for attr in _ATTRS_TO_PROPAGATE: + parent_has = hasattr(parent_config, attr) + text_has = hasattr(text_config, attr) + if parent_has and not text_has: + setattr(text_config, attr, getattr(parent_config, attr)) + elif text_has and not parent_has: + setattr(parent_config, attr, getattr(text_config, attr)) + return text_config + + def get_hf_text_config(config: PretrainedConfig): """Get the "sub" config relevant to llm for multi modal models. No op for pure text models. @@ -143,20 +196,36 @@ def get_hf_text_config(config: PretrainedConfig): setattr(config, "dtype", torch.float16) return config + text_config = None + + # Some models (e.g. DeepSeek-OCR) store sub-configs as plain dicts. + # Convert to PretrainedConfig early so hasattr() checks and asserts work. 
+ for _attr in ("text_config", "llm_config", "language_config"): + _sub = getattr(config, _attr, None) + if isinstance(_sub, dict): + _converted = PretrainedConfig(**_sub) + # Propagate torch_dtype from parent so weight loading uses correct precision. + if ( + getattr(_converted, "torch_dtype", None) is None + and getattr(config, "torch_dtype", None) is not None + ): + _converted.torch_dtype = config.torch_dtype + setattr(config, _attr, _converted) + if hasattr(config, "text_config"): # The code operates under the assumption that text_config should have # `num_attention_heads` (among others). Assert here to fail early # if transformers config doesn't align with this assumption. assert hasattr(config.text_config, "num_attention_heads") - return config.text_config + text_config = config.text_config if hasattr(config, "llm_config"): # PointsV1.5 Chat Model assert hasattr(config.llm_config, "num_attention_heads") - return config.llm_config + text_config = config.llm_config if hasattr(config, "language_config"): - return config.language_config + text_config = config.language_config if hasattr(config, "thinker_config"): # qwen2.5 omni thinker_config = config.thinker_config @@ -166,12 +235,19 @@ def get_hf_text_config(config: PretrainedConfig): "torch_dtype", getattr(thinker_config, "torch_dtype", None), ) - return thinker_config.text_config - return thinker_config + text_config = thinker_config.text_config + else: + text_config = thinker_config + if hasattr(config, "llm_config"): - return config.llm_config - else: - return config + text_config = config.llm_config + + # Ensure rope_scaling dicts have "type" for remote-code compat (v5). 
+ normalize_rope_scaling_compat(config) + + if text_config is not None: + return _patch_text_config(config, text_config) + return config # Temporary hack for DeepSeek-V3.2 model @@ -255,6 +331,13 @@ def _override_deepseek_ocr_v_head_dim(config: DeepseekVLV2Config) -> None: if config.text_config.v_head_dim == 0: V_HEAD_DIM_PATCH = 128 config.text_config.v_head_dim = V_HEAD_DIM_PATCH + # Also fix language_config so get_hf_text_config (which may prefer it + # over text_config) stays consistent. + lc = getattr(config, "language_config", None) + if isinstance(lc, dict): + lc["v_head_dim"] = V_HEAD_DIM_PATCH + elif hasattr(lc, "v_head_dim"): + lc.v_head_dim = V_HEAD_DIM_PATCH logger.warning( f"Overriding deepseek-ocr's v_head_dim from 0 to {V_HEAD_DIM_PATCH} to avoid potential issues." ) @@ -273,6 +356,92 @@ def _override_v_head_dim_if_zero(config: PretrainedConfig, patch: int = 128) -> ) +def _ensure_clean_up_tokenization_compat() -> None: + """Re-add ``clean_up_tokenization`` removed in transformers v5. + + Remote-code tokenizers (e.g. InternLM2Tokenizer) call + ``self.clean_up_tokenization()`` which was a static method on + ``PreTrainedTokenizerBase`` in v4 but removed in v5. Patch it back + so existing HuggingFace Hub tokenizer code keeps working. + """ + if hasattr(PreTrainedTokenizerBase, "clean_up_tokenization"): + return + + @staticmethod + def clean_up_tokenization(out_string: str) -> str: + out_string = ( + out_string.replace(" .", ".") + .replace(" ?", "?") + .replace(" !", "!") + .replace(" ,", ",") + .replace(" ' ", "'") + .replace(" n't", "n't") + .replace(" 'm", "'m") + .replace(" 's", "'s") + .replace(" 've", "'ve") + .replace(" 're", "'re") + ) + return out_string + + PreTrainedTokenizerBase.clean_up_tokenization = clean_up_tokenization + + +# Apply immediately so all code paths (get_tokenizer, get_processor, +# and any external callers) benefit without needing an explicit call. 
+_ensure_clean_up_tokenization_compat() + + +def _ensure_is_torch_fx_available_compat() -> None: + """Re-add ``is_torch_fx_available`` removed in transformers v5. + + Remote-code models (e.g. MiniCPM-V) import ``is_torch_fx_available`` + from ``transformers.utils.import_utils``. The function was removed + in v5. Patch it back so existing HuggingFace Hub model code keeps + working. torch.fx is always available in PyTorch >= 2.0. + """ + import transformers.utils.import_utils as _import_utils + + if hasattr(_import_utils, "is_torch_fx_available"): + return + + _import_utils.is_torch_fx_available = lambda: True + + +_ensure_is_torch_fx_available_compat() + + +def normalize_rope_scaling_compat(config: "PretrainedConfig") -> None: + """Ensure rope_scaling dicts have ``"type"`` alongside ``"rope_type"``. + + Transformers v5 standardises rope_scaling to use ``"rope_type"`` and may + omit the legacy ``"type"`` key. Remote-code models (e.g. Kimi-VL) still + read ``rope_scaling["type"]``, causing a ``KeyError``. This helper adds + ``"type"`` from ``"rope_type"`` whenever it is missing, recursively across + the config and all its sub-configs. 
+ """ + + def _patch(cfg): + try: + rs = getattr(cfg, "rope_scaling", None) + except AttributeError: + rs = None + if isinstance(rs, dict) and "rope_type" in rs and "type" not in rs: + rs["type"] = rs["rope_type"] + # Recurse into sub-configs + for attr in ( + "text_config", + "llm_config", + "language_config", + "vision_config", + "thinker_config", + ): + sub = getattr(cfg, attr, None) + if sub is not None: + _patch(sub) + + _patch(config) + + def _ensure_llama_flash_attention2_compat() -> None: """Ensure LlamaFlashAttention2 symbol exists for remote code compatibility.""" try: @@ -284,6 +453,24 @@ def _ensure_llama_flash_attention2_compat() -> None: modeling_llama.LlamaFlashAttention2 = modeling_llama.LlamaAttention +def _ensure_gguf_version(): + """Workaround for transformers v5 bug where is_gguf_available() fails + when the gguf package lacks __version__ and metadata lookup also fails, + resulting in packaging.version.InvalidVersion: Invalid version: 'N/A'.""" + try: + import gguf + + if not hasattr(gguf, "__version__"): + import importlib.metadata + + try: + gguf.__version__ = importlib.metadata.version("gguf") + except Exception: + gguf.__version__ = "0.0.0" + except ImportError: + pass + + @lru_cache_frozenset(maxsize=32) def get_config( model: str, @@ -294,6 +481,7 @@ def get_config( ): is_gguf = check_gguf_file(model) if is_gguf: + _ensure_gguf_version() kwargs["gguf_file"] = model model = Path(model).parent @@ -321,6 +509,33 @@ def get_config( config = _load_deepseek_v32_model( model, trust_remote_code=trust_remote_code, revision=revision, **kwargs ) + except KeyError as e: + # Transformers v5 may register a built-in config class that + # conflicts with sglang's custom one (e.g. NemotronHConfig + # doesn't handle '-' in hybrid_override_pattern). Fall back + # to loading the raw config dict and using sglang's class. + # Also handle deepseek_v32 which v5 doesn't recognize. 
+ if "deepseek_v32" in str(e): + config = _load_deepseek_v32_model( + model, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + else: + config_dict, _ = PretrainedConfig.get_config_dict( + model, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + model_type = config_dict.get("model_type") + if model_type in _CONFIG_REGISTRY: + config = _CONFIG_REGISTRY[model_type].from_pretrained( + model, revision=revision, **kwargs + ) + else: + raise if ( config.architectures is not None @@ -509,6 +724,12 @@ def get_tokenizer( if kwargs.get("use_fast", False): raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False + elif tokenizer_mode == "auto": + # In Transformers v5, the default for use_fast changed from True to False. + # Explicitly set use_fast=True for "auto" mode to maintain previous behavior + # and avoid issues with models that have incorrect tokenizer_class values. + if "use_fast" not in kwargs: + kwargs["use_fast"] = True # TODO(Xinyuan): Remove this once we have a proper tokenizer for Devstral if tokenizer_name == "mistralai/Devstral-Small-2505": @@ -516,6 +737,7 @@ def get_tokenizer( is_gguf = check_gguf_file(tokenizer_name) if is_gguf: + _ensure_gguf_version() kwargs["gguf_file"] = tokenizer_name tokenizer_name = Path(tokenizer_name).parent @@ -565,17 +787,218 @@ def get_tokenizer( else: raise e + # Transformers v5 may silently fall back to a generic TokenizersBackend + # when trust_remote_code=False and the model requires a custom tokenizer. + # Detect this and auto-retry with trust_remote_code=True. 
+ if not trust_remote_code and type(tokenizer).__name__ == "TokenizersBackend": + logger.info( + "Detected generic TokenizersBackend for %s, " + "retrying with trust_remote_code=True", + tokenizer_name, + ) + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=True, + tokenizer_revision=tokenizer_revision, + clean_up_tokenization_spaces=False, + **kwargs, + ) + + _fix_v5_tokenizer_components(tokenizer, tokenizer_name, tokenizer_revision) + _fix_v5_add_bos_eos_token(tokenizer, tokenizer_name, tokenizer_revision) + if not isinstance(tokenizer, PreTrainedTokenizerFast): warnings.warn( "Using a slow tokenizer. This might cause a significant " "slowdown. Consider using a fast tokenizer instead." ) + _fix_special_tokens_pattern(tokenizer) attach_additional_stop_token_ids(tokenizer) tokenizer = patch_tokenizer(tokenizer) return tokenizer +def _fix_v5_tokenizer_components(tokenizer, model_name_or_path, revision=None): + """Fix pre_tokenizer/decoder when a v5 tokenizer class overwrites them. + + In transformers v5, some tokenizer classes (e.g. LlamaTokenizer) have a + custom __init__ that rebuilds the pre_tokenizer and decoder from scratch + with class-specific components, discarding the originals from tokenizer.json. + This breaks models that specify LlamaTokenizerFast but actually use a + different tokenizer architecture (e.g. DeepSeek-V3.2 uses ByteLevel). + + Detects the mismatch by comparing against the raw tokenizer.json and + restores the original components when they differ. 
+ """ + backend = getattr(tokenizer, "_tokenizer", None) + if backend is None: + return + + try: + from huggingface_hub import hf_hub_download + from tokenizers import Tokenizer as RawTokenizer + + tok_file = hf_hub_download( + model_name_or_path, + "tokenizer.json", + revision=revision, + local_files_only=True, + ) + raw = RawTokenizer.from_file(tok_file) + except Exception: + return + + raw_pre = type(raw.pre_tokenizer).__name__ if raw.pre_tokenizer else None + loaded_pre = type(backend.pre_tokenizer).__name__ if backend.pre_tokenizer else None + + if raw_pre and loaded_pre and raw_pre != loaded_pre: + logger.info( + "Fixing v5 tokenizer component mismatch for %s: " + "pre_tokenizer %s -> %s, decoder %s -> %s", + model_name_or_path, + loaded_pre, + raw_pre, + type(backend.decoder).__name__ if backend.decoder else None, + type(raw.decoder).__name__ if raw.decoder else None, + ) + backend.pre_tokenizer = raw.pre_tokenizer + backend.decoder = raw.decoder + + +def _fix_v5_add_bos_eos_token(tokenizer, model_name_or_path, revision=None): + """Restore add_bos_token/add_eos_token stripped by transformers v5. + + In transformers v5, _from_pretrained() strips add_bos_token and + add_eos_token from init kwargs when a tokenizer.json file is present, + assuming the tokenizer.json post-processor handles BOS/EOS addition. + However, many models (e.g. DeepSeek-V3) have a tokenizer.json whose + post-processor does NOT add BOS/EOS, and rely on the add_bos_token flag + from tokenizer_config.json instead. This causes silent accuracy regressions. + + This function reads the tokenizer_config.json and restores the values. 
+ """ + try: + local_path = Path(model_name_or_path) / "tokenizer_config.json" + if local_path.is_file(): + config_file = str(local_path) + else: + from huggingface_hub import hf_hub_download + + config_file = hf_hub_download( + model_name_or_path, + "tokenizer_config.json", + revision=revision, + local_files_only=True, + ) + + with open(config_file) as f: + config = json.load(f) + except Exception as e: + logger.debug( + "_fix_v5_add_bos_eos_token: could not read tokenizer_config.json " + "for %s: %s", + model_name_or_path, + e, + ) + return + + changed = False + for attr in ("add_bos_token", "add_eos_token"): + if attr not in config: + continue + config_val = config[attr] + current_val = getattr(tokenizer, attr, None) + if current_val != config_val: + logger.info( + "Restoring %s=%s for %s (was %s after v5 loading)", + attr, + config_val, + model_name_or_path, + current_val, + ) + setattr(tokenizer, f"_{attr}", config_val) + changed = True + + # Rebuild the post-processor so it respects the restored flags + if changed and hasattr(tokenizer, "update_post_processor"): + tokenizer.update_post_processor() + + +def _fix_special_tokens_pattern(tokenizer): + """Fix https://github.com/huggingface/transformers/pull/42563 which defaults + special_tokens_pattern to "cls_sep", inserting None into token IDs when + cls_token/sep_token are undefined (e.g. Kimi-VL's TikTokenTokenizer). + """ + pattern = getattr(tokenizer, "special_tokens_pattern", None) + if pattern == "cls_sep" and ( + tokenizer.cls_token_id is None or tokenizer.sep_token_id is None + ): + tokenizer.special_tokens_pattern = "none" + + +def _fix_added_tokens_encoding(tokenizer): + """Ensure special tokens encode as single tokens in transformers v5. + + Some model tokenizers (e.g. MiniCPM-V-4) define special tokens like , + as attributes on the tokenizer class with corresponding IDs in the + vocabulary (via tokenizer.json's added_tokens). 
In transformers v5, these + tokens may not appear in get_added_vocab() and encode() splits them into + subwords, breaking multimodal pipelines that rely on finding them in input_ids. + + This function discovers such tokens by scanning tokenizer attributes, checks + if they encode correctly, and re-registers any that don't. + """ + # Discover special token strings from tokenizer attributes. + # Model tokenizers (e.g. MiniCPMVTokenizerFast) store them as attributes + # like im_start="", slice_start="", etc. + candidates = {} + for attr in dir(tokenizer): + if attr.startswith("_"): + continue + try: + val = getattr(tokenizer, attr) + except Exception: + continue + if ( + not isinstance(val, str) + or not val.startswith("<") + or not val.endswith(">") + or len(val) > 20 + ): + continue + token_id = tokenizer.convert_tokens_to_ids(val) + if token_id is not None and token_id != tokenizer.unk_token_id: + candidates[val] = token_id + + if not candidates: + return + + # Check which tokens fail to encode as single tokens. 
+ broken = [] + for token_str, expected_id in candidates.items(): + try: + ids = tokenizer.encode(token_str, add_special_tokens=False) + if len(ids) != 1 or ids[0] != expected_id: + broken.append(token_str) + except Exception: + broken.append(token_str) + + if not broken: + return + + from transformers import AddedToken + + tokens_to_add = [AddedToken(tok, special=True, normalized=False) for tok in broken] + tokenizer.add_tokens(tokens_to_add, special_tokens=True) + logger.info( + "Re-registered %d special tokens for correct v5 encoding: %s", + len(broken), + broken[:10], + ) + + # Some models doesn't have an available processor, e.g.: InternVL def get_tokenizer_from_processor(processor): if isinstance(processor, PreTrainedTokenizerBase): @@ -583,6 +1006,71 @@ def get_tokenizer_from_processor(processor): return processor.tokenizer +def _build_processor_manually( + model_path, config, trust_remote_code, revision, **kwargs +): + """Build processor when AutoProcessor fails to resolve feature_extractor_type. + + In transformers v5, AutoProcessor.from_pretrained calls + AutoFeatureExtractor.from_pretrained which fails if + preprocessor_config.json lacks 'feature_extractor_type'. This loads the + processor class from the hub and constructs it with individually-loaded + components. + """ + import transformers + from transformers import AutoImageProcessor, AutoTokenizer + from transformers.dynamic_module_utils import get_class_from_dynamic_module + + # Resolve processor class from auto_map — check both the model config + # and the preprocessor_config.json (some models like MiniCPM-o only + # declare AutoProcessor in the latter). 
+ auto_map = getattr(config, "auto_map", None) or {} + proc_ref = auto_map.get("AutoProcessor") + if not proc_ref: + try: + pp_config_path = snapshot_download( + model_path, + allow_patterns=["preprocessor_config.json"], + revision=revision, + ) + pp_file = os.path.join(pp_config_path, "preprocessor_config.json") + if os.path.isfile(pp_file): + with open(pp_file) as f: + pp_auto_map = json.load(f).get("auto_map", {}) + proc_ref = pp_auto_map.get("AutoProcessor") + except Exception: + pass + if not proc_ref: + raise ValueError(f"Cannot determine processor class for {model_path}") + + proc_cls = get_class_from_dynamic_module( + proc_ref, model_path, code_revision=revision + ) + + # Load sub-components individually (these succeed) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=trust_remote_code, revision=revision + ) + init_kwargs = {"tokenizer": tokenizer} + + if "image_processor" in getattr(proc_cls, "attributes", []): + try: + init_kwargs["image_processor"] = AutoImageProcessor.from_pretrained( + model_path, trust_remote_code=trust_remote_code, revision=revision + ) + except Exception: + pass + + # Instantiate feature extractor from its declared class + fe_class_name = getattr(proc_cls, "feature_extractor_class", None) + if fe_class_name: + fe_class = getattr(transformers, fe_class_name, None) + if fe_class is not None: + init_kwargs["feature_extractor"] = fe_class() + + return proc_cls(**init_kwargs) + + def get_processor( tokenizer_name: str, *args, @@ -667,10 +1155,25 @@ def get_processor( revision=revision, **kwargs, ) + elif "Unrecognized feature extractor" in error_message: + logger.info( + "AutoProcessor failed on feature extractor for %s, " + "constructing processor manually", + tokenizer_name, + ) + processor = _build_processor_manually( + tokenizer_name, + config, + trust_remote_code, + revision, + **kwargs, + ) else: raise e tokenizer = get_tokenizer_from_processor(processor) + _fix_special_tokens_pattern(tokenizer) + 
_fix_added_tokens_encoding(tokenizer) attach_additional_stop_token_ids(tokenizer) return processor diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 1f10c0cd06bf..36a408fb05b4 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -25,7 +25,7 @@ AutoConfig, AutoModel, AutoModelForCausalLM, - AutoModelForVision2Seq, + AutoModelForImageTextToText, AutoProcessor, GenerationConfig, ) @@ -108,12 +108,26 @@ def _get_sentence_transformer_embedding_model( model = SentenceTransformer( model_path, model_kwargs={"torch_dtype": torch_dtype}, + # Force causal attention to match SGLang's RadixAttention behavior. + # In transformers v5, models with config.is_causal=false use + # bidirectional attention, but SGLang always uses causal attention. + config_kwargs={"is_causal": True}, truncate_dim=matryoshka_dim, ) else: # if no pre-trained sentence-transformers model from sentence_transformers import models word_embedding_model = models.Transformer(model_path).to(dtype=torch_dtype) + # In transformers v5, composite configs (e.g. Qwen2VLConfig) may not + # expose hidden_size at the top level. Patch it from the text sub-config + # so sentence_transformers' get_word_embedding_dimension() works. 
+ _cfg = word_embedding_model.auto_model.config + if not hasattr(_cfg, "hidden_size"): + for _sub_attr in ("text_config", "language_config", "llm_config"): + _sub = getattr(_cfg, _sub_attr, None) + if _sub and hasattr(_sub, "hidden_size"): + _cfg.hidden_size = _sub.hidden_size + break pooling_model = models.Pooling( word_embedding_model.get_word_embedding_dimension(), pooling_mode="lasttoken", @@ -274,7 +288,7 @@ def start_model_process( ).to(get_device()) elif self.model_type == "embedding": if "gme-qwen2-vl" in model_path.lower(): - self.model = AutoModelForVision2Seq.from_pretrained( + self.model = AutoModelForImageTextToText.from_pretrained( model_path, torch_dtype=torch_dtype, trust_remote_code=False, @@ -338,20 +352,18 @@ def start_model_process( images=image[0], return_tensors="pt" ) logits = self.model.get_image_features( - pixel_values=inputs.data["pixel_values"].to( - get_device() - ), - ).tolist() + pixel_values=inputs.data["pixel_values"].cuda(), + return_dict=True, + ).pooler_output.tolist() else: inputs = self.tokenizer( prompts, padding=True, return_tensors="pt" ) logits = self.model.get_text_features( - input_ids=inputs.data["input_ids"].to(get_device()), - attention_mask=inputs.data["attention_mask"].to( - get_device() - ), - ).tolist() + input_ids=inputs.data["input_ids"].cuda(), + attention_mask=inputs.data["attention_mask"].cuda(), + return_dict=True, + ).pooler_output.tolist() else: logits = self.model.encode(prompts).tolist() out_queue.put(ModelOutput(embed_logits=logits)) diff --git a/test/registered/core/test_score_api.py b/test/registered/core/test_score_api.py index 4110337ee419..5d15d564f139 100644 --- a/test/registered/core/test_score_api.py +++ b/test/registered/core/test_score_api.py @@ -85,7 +85,7 @@ def _get_token_ids(self, tokens): try: label_token_ids = [] for token in tokens: - encoding = tokenizer.encode_plus(token, add_special_tokens=False) + encoding = tokenizer(token, add_special_tokens=False) token_ids = encoding["input_ids"] 
label_token_ids.append(token_ids[0]) return label_token_ids diff --git a/test/registered/quant/test_awq.py b/test/registered/quant/test_awq.py index cc640b8515a2..005d7bd84ba3 100644 --- a/test/registered/quant/test_awq.py +++ b/test/registered/quant/test_awq.py @@ -13,7 +13,7 @@ popen_launch_server, ) -register_cuda_ci(est_time=163, suite="stage-b-test-large-1-gpu") +register_cuda_ci(est_time=700, suite="stage-b-test-large-1-gpu") register_amd_ci(est_time=200, suite="stage-b-test-large-1-gpu-amd") diff --git a/test/registered/rl/test_multi_instance_release_memory_occupation.py b/test/registered/rl/test_multi_instance_release_memory_occupation.py index da66b6ca659c..172f9bf93bcd 100644 --- a/test/registered/rl/test_multi_instance_release_memory_occupation.py +++ b/test/registered/rl/test_multi_instance_release_memory_occupation.py @@ -1,3 +1,4 @@ +import gc import multiprocessing import os import time @@ -216,10 +217,20 @@ def _run_sglang_subprocess( # 5 - release hf model _mem_usage = get_gpu_memory_gb(rank) print(f"GPU{rank} Memory usage after resuming Sgl weights: {_mem_usage}") + # In transformers v5, from_pretrained with device_map attaches accelerate + # dispatch hooks that hold strong refs to parameters. Remove them first. 
+ try: + from accelerate.hooks import remove_hook_from_submodules + + remove_hook_from_submodules(hf_model) + except Exception: + pass del hf_model hf_model = None + gc.collect() torch.cuda.empty_cache() time.sleep(3) + gc.collect() torch.cuda.empty_cache() _curr_usage = get_gpu_memory_gb(rank) assert ( diff --git a/test/registered/unit/function_call/test_function_call_parser.py b/test/registered/unit/function_call/test_function_call_parser.py index 5e8c1928b606..c917e1b8b972 100644 --- a/test/registered/unit/function_call/test_function_call_parser.py +++ b/test/registered/unit/function_call/test_function_call_parser.py @@ -1295,9 +1295,9 @@ def setUp(self): ), ] self.detector = DeepSeekV32Detector() - from transformers import AutoTokenizer + from sglang.srt.utils.hf_transformers_utils import get_tokenizer - self.tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3.2") + self.tokenizer = get_tokenizer("deepseek-ai/DeepSeek-V3.2") self.interval = 1 def test_detect_and_parse_xml_format(self): diff --git a/test/registered/vlm/test_vlm_input_format.py b/test/registered/vlm/test_vlm_input_format.py index 57a2c9a98de4..92477c799039 100644 --- a/test/registered/vlm/test_vlm_input_format.py +++ b/test/registered/vlm/test_vlm_input_format.py @@ -35,6 +35,7 @@ def forward(self, x): from sglang import Engine from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest from sglang.srt.parser.conversation import generate_chat_conv +from sglang.srt.utils.hf_transformers_utils import _fix_added_tokens_encoding register_cuda_ci(est_time=447, suite="stage-b-test-large-1-gpu") @@ -61,6 +62,7 @@ def setUpClass(cls): cls.processor = AutoProcessor.from_pretrained( cls.model_path, trust_remote_code=True, use_fast=True ) + _fix_added_tokens_encoding(cls.processor.tokenizer) cls._init_visual() @classmethod @@ -199,16 +201,22 @@ class TestQwenVLUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestC @classmethod def _init_visual(cls): - 
cls.visual_model = ( - Qwen2_5_VLForConditionalGeneration.from_pretrained( - cls.model_path, torch_dtype=torch.bfloat16 + model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + cls.model_path, torch_dtype=torch.bfloat16 + ).eval() + # In transformers v5, .visual moved under .model + visual = model.model.visual + cls.visual_model = visual.to(cls.device) + + # In transformers v5, the visual encoder returns BaseModelOutputWithPooling; + # pooler_output has the spatially-merged embeddings we need. + def visual(processor_output): + out = cls.visual_model( + processor_output["pixel_values"], processor_output["image_grid_thw"] ) - .eval() - .visual.to(cls.device) - ) - cls.visual = lambda processor_output: cls.visual_model( - processor_output["pixel_values"], processor_output["image_grid_thw"] - ) + return out.pooler_output if hasattr(out, "pooler_output") else out + + cls.visual = visual def _processor_output_image_data(self, processor_output): return dict(processor_output, format="processor_output") @@ -251,13 +259,47 @@ class TestKimiVLImageUnderstandsImage( @classmethod def _init_visual(cls): - model = AutoModel.from_pretrained(cls.model_path, trust_remote_code=True) + import inspect + + from transformers import AutoConfig + from transformers.dynamic_module_utils import get_class_from_dynamic_module + + config = AutoConfig.from_pretrained(cls.model_path, trust_remote_code=True) + + # Transformers v5 auto-populates rope_scaling with + # {"rope_theta": ..., "rope_type": "default"} even when the original + # config had rope_scaling: null. The remote KimiVL code branches on + # `if self.config.rope_scaling is None` so we must reset it. + tc = getattr(config, "text_config", None) + if tc is not None: + rs = getattr(tc, "rope_scaling", None) + if isinstance(rs, dict) and rs.get("rope_type") == "default": + tc.rope_scaling = None + + # Transformers v5 calls tie_weights(recompute_mapping=False) in + # post_init, but KimiVL's tie_weights doesn't accept that kwarg. 
+ auto_map = getattr(config, "auto_map", {}) + model_ref = auto_map.get("AutoModel") + if model_ref: + model_cls = get_class_from_dynamic_module(model_ref, cls.model_path) + orig_tie = model_cls.tie_weights + if "recompute_mapping" not in inspect.signature(orig_tie).parameters: + + def _patched_tie(self, **kwargs): + return orig_tie(self) + + model_cls.tie_weights = _patched_tie + + model = AutoModel.from_pretrained( + cls.model_path, config=config, trust_remote_code=True + ) cls.vision_tower = model.vision_tower.eval().to(cls.device) cls.mm_projector = model.multi_modal_projector.eval().to(cls.device) + _vt_dtype = next(cls.vision_tower.parameters()).dtype cls.visual = lambda tokenizer_output: cls.mm_projector( cls.vision_tower( - pixel_values=tokenizer_output["pixel_values"], + pixel_values=tokenizer_output["pixel_values"].to(_vt_dtype), grid_hws=tokenizer_output["image_grid_hws"], ) ) @@ -376,9 +418,41 @@ def setUpClass(cls): @classmethod def _init_visual(cls): - model = AutoModel.from_pretrained( - cls.model_path, trust_remote_code=True, torch_dtype=torch.bfloat16 - ) + try: + model = AutoModel.from_pretrained( + cls.model_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=False, + ) + except RuntimeError as e: + if "meta" not in str(e): + raise + # Transformers v5 always uses meta tensors for init, which breaks + # models calling .item() in __init__ (e.g. InternVL's drop_path_rate). + # Fall back to from_config + manual weight loading. 
+ import gc + import glob + import os + + from huggingface_hub import snapshot_download + from safetensors.torch import load_file + from transformers import AutoConfig + + config = AutoConfig.from_pretrained(cls.model_path, trust_remote_code=True) + with torch.device("cpu"): + model = AutoModel.from_config( + config, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) + model_dir = snapshot_download(cls.model_path) + for f in sorted(glob.glob(os.path.join(model_dir, "*.safetensors"))): + shard = load_file(f) + model.load_state_dict(shard, strict=False) + del shard + gc.collect() + cls.vision_model = model.vision_model.eval().to(cls.device) cls.mlp1 = model.mlp1.eval().to(cls.device) @@ -520,13 +594,44 @@ def setUpClass(cls): cls.processor = AutoProcessor.from_pretrained( cls.model_path, trust_remote_code=True ) + _fix_added_tokens_encoding(cls.processor.tokenizer) cls._init_visual() @classmethod def _init_visual(cls): - model = AutoModel.from_pretrained( - cls.model_path, trust_remote_code=True, torch_dtype=torch.bfloat16 - ) + try: + model = AutoModel.from_pretrained( + cls.model_path, trust_remote_code=True, torch_dtype=torch.bfloat16 + ) + except (AttributeError, RuntimeError) as e: + err = str(e) + if "all_tied_weights_keys" not in err and "meta" not in err: + raise + # Transformers v5: remote model code may lack all_tied_weights_keys + # or meta-tensor init may break .item() calls. Fall back to + # from_config + manual weight loading. 
+ import gc + import glob + import os + + from huggingface_hub import snapshot_download + from safetensors.torch import load_file + from transformers import AutoConfig + + config = AutoConfig.from_pretrained(cls.model_path, trust_remote_code=True) + with torch.device("cpu"): + model = AutoModel.from_config( + config, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) + model_dir = snapshot_download(cls.model_path) + for f in sorted(glob.glob(os.path.join(model_dir, "*.safetensors"))): + shard = load_file(f) + model.load_state_dict(shard, strict=False) + del shard + gc.collect() + cls.vpm_model = model.vpm.eval().to(cls.device) cls.resampler_model = model.resampler.eval().to(cls.device) del model