210 changes: 1 addition & 209 deletions python/sglang/srt/utils/hf_transformers_utils.py
@@ -14,7 +14,6 @@
"""Utilities for Huggingface Transformers."""

import contextlib
import glob
import json
import logging
import os
@@ -23,7 +22,6 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union

import huggingface_hub
import torch
from huggingface_hub import snapshot_download

@@ -69,14 +67,7 @@
from sglang.srt.configs.internvl import InternVLChatConfig
from sglang.srt.connector import create_remote_connector
from sglang.srt.multimodal.customized_mm_processor_utils import _CUSTOMIZED_MM_PROCESSOR
from sglang.srt.utils import (
    find_local_repo_dir,
    is_remote_url,
    logger,
    lru_cache_frozenset,
    mistral_utils,
)
from sglang.utils import is_in_ci
from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset, mistral_utils

_CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
    ChatGLMConfig,
@@ -408,197 +399,12 @@ def get_context_length(config):
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def _validate_tokenizer_file(file_path: str) -> bool:
    """
    Validate that a tokenizer file is readable and not corrupted.

    Args:
        file_path: Path to the tokenizer file

    Returns:
        True if the file is valid, False if corrupted
    """
    try:
        # For JSON files, validate they're parseable
        if file_path.endswith(".json"):
            with open(file_path, "r") as f:
                json.load(f)
            return True
        # For .model files (SentencePiece), just check readability
        elif file_path.endswith(".model"):
            with open(file_path, "rb") as f:
                # Read first few bytes to verify file is readable
                _ = f.read(100)
            return True
        # For other files, just check they exist and are readable
        else:
            with open(file_path, "rb") as f:
                _ = f.read(100)
            return True
    except Exception as e:
        logger.warning(
            "Corrupted tokenizer file detected: %s - %s: %s",
            file_path,
            type(e).__name__,
            str(e),
        )
        return False
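
For reference, a minimal sketch of how this removed validator could be exercised in isolation; the snapshot path below is hypothetical and not taken from this PR:

# Illustrative only: check a hypothetical cached tokenizer_config.json
# before trusting the local snapshot (the path is made up for the example).
candidate = "/root/.cache/huggingface/hub/models--org--model/snapshots/abc123/tokenizer_config.json"
if _validate_tokenizer_file(candidate):
    print("tokenizer file parses cleanly; safe to reuse the cached copy")
else:
    print("tokenizer file looks corrupted; fall back to downloading")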


def find_local_tokenizer_snapshot_dir(
    model_name_or_path: str,
    cache_dir: Optional[str],
    allow_patterns: List[str],
    revision: Optional[str] = None,
) -> Optional[str]:
    """If the tokenizer files are already local, skip downloading and return the path.
    Only applied in CI.
    """
    if not is_in_ci():
        return None

    if os.path.isdir(model_name_or_path):
        logger.info(
            "Tokenizer path %s is already a local directory, skipping cache check",
            model_name_or_path,
        )
        return None

    logger.info("Checking for cached tokenizer: %s", model_name_or_path)
    found_local_snapshot_dir = None

    # Check custom cache_dir (if provided)
    if cache_dir:
        try:
            repo_folder = os.path.join(
                cache_dir,
                huggingface_hub.constants.REPO_ID_SEPARATOR.join(
                    ["models", *model_name_or_path.split("/")]
                ),
            )
            rev_to_use = revision
            if not rev_to_use:
                ref_main = os.path.join(repo_folder, "refs", "main")
                if os.path.isfile(ref_main):
                    with open(ref_main) as f:
                        rev_to_use = f.read().strip()
            if rev_to_use:
                rev_dir = os.path.join(repo_folder, "snapshots", rev_to_use)
                if os.path.isdir(rev_dir):
                    found_local_snapshot_dir = rev_dir
        except Exception as e:
            logger.warning(
                "Failed to find local snapshot in custom cache_dir %s: %s",
                cache_dir,
                e,
            )

    # Check default HF cache as well
    if not found_local_snapshot_dir:
        try:
            rev_dir = find_local_repo_dir(model_name_or_path, revision)
            if rev_dir and os.path.isdir(rev_dir):
                found_local_snapshot_dir = rev_dir
        except Exception as e:
            logger.warning("Failed to find local snapshot in default HF cache: %s", e)

    # If local snapshot exists, validate it contains at least one tokenizer file
    # matching allow_patterns before skipping download.
    if found_local_snapshot_dir is None:
        return None

    # Layer 0: Check for incomplete files (corruption indicator)
    repo_folder = os.path.abspath(os.path.join(found_local_snapshot_dir, "..", ".."))
    blobs_dir = os.path.join(repo_folder, "blobs")
    if os.path.isdir(blobs_dir) and glob.glob(os.path.join(blobs_dir, "*.incomplete")):
        logger.info(
            "Found .incomplete files in %s for %s. Considering local snapshot incomplete.",
            blobs_dir,
            model_name_or_path,
        )
        return None

    local_tokenizer_files: List[str] = []
    try:
        for pattern in allow_patterns:
            matched_files = glob.glob(os.path.join(found_local_snapshot_dir, pattern))
            for f in matched_files:
                # Layer 1: Check symlink target exists (broken symlink check)
                if not os.path.exists(f):
                    continue
                # Layer 2: Validate file content is not corrupted
                if not _validate_tokenizer_file(f):
                    logger.info(
                        "Found corrupted tokenizer file %s for %s. Will re-download.",
                        f,
                        model_name_or_path,
                    )
                    return None
                local_tokenizer_files.append(f)
    except Exception as e:
        logger.warning(
            "Failed to scan local snapshot %s with patterns %s: %s",
            found_local_snapshot_dir,
            allow_patterns,
            e,
        )
        local_tokenizer_files = []

    if len(local_tokenizer_files) > 0:
        logger.info(
            "Found local HF snapshot for tokenizer %s at %s; skipping download.",
            model_name_or_path,
            found_local_snapshot_dir,
        )
        return found_local_snapshot_dir
    else:
        logger.info(
            "Local HF snapshot at %s has no files matching %s; will attempt download.",
            found_local_snapshot_dir,
            allow_patterns,
        )
        return None
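
A hedged sketch of the call pattern this removed helper was built for, assuming the CI check passes; the repo id and patterns are illustrative, not taken from this PR:

# Illustrative only: try the local snapshot first, download only on a miss.
snapshot_dir = find_local_tokenizer_snapshot_dir(
    "meta-llama/Llama-3.1-8B-Instruct",    # hypothetical repo id
    cache_dir=None,                        # fall back to the default HF cache
    allow_patterns=["*.json", "*.model"],  # tokenizer files of interest
    revision=None,                         # resolve refs/main from the cache
)
if snapshot_dir is None:
    snapshot_dir = snapshot_download(
        "meta-llama/Llama-3.1-8B-Instruct", allow_patterns=["*.json", "*.model"]
    )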


# Filter warnings like: https://github.com/sgl-project/sglang/issues/8082
class TokenizerWarningsFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        return "Calling super().encode with" not in record.getMessage()


def _check_tokenizer_cache(
    tokenizer_name: str,
    cache_dir: Optional[str],
    revision: Optional[str],
    include_processor_files: bool = False,
) -> str:
    """Check local cache for tokenizer files and return local path if found.

    Args:
        tokenizer_name: Model name or path
        cache_dir: Optional custom cache directory
        revision: Optional model revision
        include_processor_files: Whether to include processor-specific files (*.py, preprocessor_config.json)

    Returns:
        Local path if found in cache, otherwise returns original tokenizer_name
    """
    allow_patterns = [
        "*.json",
        "*.model",
        "*.txt",
        "tokenizer.model",
        "tokenizer_config.json",
    ]
    if include_processor_files:
        allow_patterns.extend(["*.py", "preprocessor_config.json"])

    local_path = find_local_tokenizer_snapshot_dir(
        tokenizer_name, cache_dir, allow_patterns, revision
    )
    return local_path if local_path is not None else tokenizer_name
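
A minimal sketch of the two call shapes this wrapper served, mirroring the invocations removed from get_tokenizer and get_processor further down in this diff; the model name is illustrative:

# Illustrative only: tokenizer-only lookup vs. processor lookup that also
# pulls *.py and preprocessor_config.json from the snapshot.
name = _check_tokenizer_cache("Qwen/Qwen2.5-VL-7B-Instruct", cache_dir=None, revision=None)
name = _check_tokenizer_cache(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    cache_dir=None,
    revision=None,
    include_processor_files=True,
)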


def get_tokenizer(
    tokenizer_name: str,
    *args,
@@ -635,11 +441,6 @@ def get_tokenizer(
        client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
        tokenizer_name = client.get_local_dir()

    # Check if tokenizer files are already in local cache (CI only)
    tokenizer_name = _check_tokenizer_cache(
        tokenizer_name, kwargs.get("cache_dir"), tokenizer_revision
    )

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
@@ -706,15 +507,6 @@ def get_processor(
):
    # pop 'revision' from kwargs if present.
    revision = kwargs.pop("revision", tokenizer_revision)

    # Check if processor/tokenizer files are already in local cache (CI only)
    tokenizer_name = _check_tokenizer_cache(
        tokenizer_name,
        kwargs.get("cache_dir"),
        revision,
        include_processor_files=True,
    )

    if "mistral-large-3" in str(tokenizer_name).lower():
        config = _load_mistral_large_3_for_causal_LM(
            tokenizer_name,