diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py
index ef04b85c6c7c..6be8dc42c241 100644
--- a/python/sglang/srt/utils/hf_transformers_utils.py
+++ b/python/sglang/srt/utils/hf_transformers_utils.py
@@ -14,6 +14,7 @@
 """Utilities for Huggingface Transformers."""
 
 import contextlib
+import glob
 import json
 import logging
 import os
@@ -22,6 +23,7 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Type, Union
 
+import huggingface_hub
 import torch
 from huggingface_hub import snapshot_download
 
@@ -67,7 +69,14 @@
 from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
 from sglang.srt.multimodal.customized_mm_processor_utils import _CUSTOMIZED_MM_PROCESSOR
-from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset, mistral_utils
+from sglang.srt.utils import (
+    find_local_repo_dir,
+    is_remote_url,
+    logger,
+    lru_cache_frozenset,
+    mistral_utils,
+)
+from sglang.utils import is_in_ci
 
 _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
     ChatGLMConfig,
@@ -399,12 +408,197 @@ def get_context_length(config):
 _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
 
 
+def _validate_tokenizer_file(file_path: str) -> bool:
+    """
+    Validate that a tokenizer file is readable and not corrupted.
+
+    Args:
+        file_path: Path to the tokenizer file
+
+    Returns:
+        True if the file is valid, False if corrupted
+    """
+    try:
+        # For JSON files, validate they're parseable
+        if file_path.endswith(".json"):
+            with open(file_path, "r") as f:
+                json.load(f)
+            return True
+        # For .model files (SentencePiece), just check readability
+        elif file_path.endswith(".model"):
+            with open(file_path, "rb") as f:
+                # Read first few bytes to verify file is readable
+                _ = f.read(100)
+            return True
+        # For other files, just check they exist and are readable
+        else:
+            with open(file_path, "rb") as f:
+                _ = f.read(100)
+            return True
+    except Exception as e:
+        logger.warning(
+            "Corrupted tokenizer file detected: %s - %s: %s",
+            file_path,
+            type(e).__name__,
+            str(e),
+        )
+        return False
+
+
+def find_local_tokenizer_snapshot_dir(
+    model_name_or_path: str,
+    cache_dir: Optional[str],
+    allow_patterns: List[str],
+    revision: Optional[str] = None,
+) -> Optional[str]:
+    """If the tokenizer files are already local, skip downloading and return the path.
+    Only applied in CI.
+    """
+    if not is_in_ci():
+        return None
+
+    if os.path.isdir(model_name_or_path):
+        logger.info(
+            "Tokenizer path %s is already a local directory, skipping cache check",
+            model_name_or_path,
+        )
+        return None
+
+    logger.info("Checking for cached tokenizer: %s", model_name_or_path)
+    found_local_snapshot_dir = None
+
+    # Check custom cache_dir (if provided)
+    if cache_dir:
+        try:
+            repo_folder = os.path.join(
+                cache_dir,
+                huggingface_hub.constants.REPO_ID_SEPARATOR.join(
+                    ["models", *model_name_or_path.split("/")]
+                ),
+            )
+            rev_to_use = revision
+            if not rev_to_use:
+                ref_main = os.path.join(repo_folder, "refs", "main")
+                if os.path.isfile(ref_main):
+                    with open(ref_main) as f:
+                        rev_to_use = f.read().strip()
+            if rev_to_use:
+                rev_dir = os.path.join(repo_folder, "snapshots", rev_to_use)
+                if os.path.isdir(rev_dir):
+                    found_local_snapshot_dir = rev_dir
+        except Exception as e:
+            logger.warning(
+                "Failed to find local snapshot in custom cache_dir %s: %s",
+                cache_dir,
+                e,
+            )
+
+    # Check default HF cache as well
+    if not found_local_snapshot_dir:
+        try:
+            rev_dir = find_local_repo_dir(model_name_or_path, revision)
+            if rev_dir and os.path.isdir(rev_dir):
+                found_local_snapshot_dir = rev_dir
+        except Exception as e:
+            logger.warning("Failed to find local snapshot in default HF cache: %s", e)
+
+    # If local snapshot exists, validate it contains at least one tokenizer file
+    # matching allow_patterns before skipping download.
+    if found_local_snapshot_dir is None:
+        return None
+
+    # Layer 0: Check for incomplete files (corruption indicator)
+    repo_folder = os.path.abspath(os.path.join(found_local_snapshot_dir, "..", ".."))
+    blobs_dir = os.path.join(repo_folder, "blobs")
+    if os.path.isdir(blobs_dir) and glob.glob(os.path.join(blobs_dir, "*.incomplete")):
+        logger.info(
+            "Found .incomplete files in %s for %s. Considering local snapshot incomplete.",
+            blobs_dir,
+            model_name_or_path,
+        )
+        return None
+
+    local_tokenizer_files: List[str] = []
+    try:
+        for pattern in allow_patterns:
+            matched_files = glob.glob(os.path.join(found_local_snapshot_dir, pattern))
+            for f in matched_files:
+                # Layer 1: Check symlink target exists (broken symlink check)
+                if not os.path.exists(f):
+                    continue
+                # Layer 2: Validate file content is not corrupted
+                if not _validate_tokenizer_file(f):
+                    logger.info(
+                        "Found corrupted tokenizer file %s for %s. Will re-download.",
+                        f,
+                        model_name_or_path,
+                    )
+                    return None
+                local_tokenizer_files.append(f)
+    except Exception as e:
+        logger.warning(
+            "Failed to scan local snapshot %s with patterns %s: %s",
+            found_local_snapshot_dir,
+            allow_patterns,
+            e,
+        )
+        local_tokenizer_files = []
+
+    if len(local_tokenizer_files) > 0:
+        logger.info(
+            "Found local HF snapshot for tokenizer %s at %s; skipping download.",
+            model_name_or_path,
+            found_local_snapshot_dir,
+        )
+        return found_local_snapshot_dir
+    else:
+        logger.info(
+            "Local HF snapshot at %s has no files matching %s; will attempt download.",
+            found_local_snapshot_dir,
+            allow_patterns,
+        )
+        return None
+
+
 # Filter warnings like: https://github.com/sgl-project/sglang/issues/8082
 class TokenizerWarningsFilter(logging.Filter):
     def filter(self, record: logging.LogRecord) -> bool:
         return "Calling super().encode with" not in record.getMessage()
 
 
+def _check_tokenizer_cache(
+    tokenizer_name: str,
+    cache_dir: Optional[str],
+    revision: Optional[str],
+    include_processor_files: bool = False,
+) -> str:
+    """Check local cache for tokenizer files and return local path if found.
+
+    Args:
+        tokenizer_name: Model name or path
+        cache_dir: Optional custom cache directory
+        revision: Optional model revision
+        include_processor_files: Whether to include processor-specific files (*.py, preprocessor_config.json)
+
+    Returns:
+        Local path if found in cache, otherwise returns original tokenizer_name
+    """
+    allow_patterns = [
+        "*.json",
+        "*.model",
+        "*.txt",
+        "tokenizer.model",
+        "tokenizer_config.json",
+    ]
+    if include_processor_files:
+        allow_patterns.extend(["*.py", "preprocessor_config.json"])
+
+    local_path = find_local_tokenizer_snapshot_dir(
+        tokenizer_name, cache_dir, allow_patterns, revision
+    )
+    return local_path if local_path is not None else tokenizer_name
+
+
 def get_tokenizer(
     tokenizer_name: str,
     *args,
@@ -441,6 +635,11 @@
         client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
         tokenizer_name = client.get_local_dir()
 
+    # Check if tokenizer files are already in local cache (CI only)
+    tokenizer_name = _check_tokenizer_cache(
+        tokenizer_name, kwargs.get("cache_dir"), tokenizer_revision
+    )
+
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
@@ -507,6 +706,15 @@ def get_processor(
 ):
     # pop 'revision' from kwargs if present.
     revision = kwargs.pop("revision", tokenizer_revision)
+
+    # Check if processor/tokenizer files are already in local cache (CI only)
+    tokenizer_name = _check_tokenizer_cache(
+        tokenizer_name,
+        kwargs.get("cache_dir"),
+        revision,
+        include_processor_files=True,
+    )
+
     if "mistral-large-3" in str(tokenizer_name).lower():
         config = _load_mistral_large_3_for_causal_LM(
             tokenizer_name,
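
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): one way the CI-only cache
# shortcut could be exercised with pytest. The module path comes from the diff
# header; `is_in_ci` is monkeypatched so the shortcut activates, and the repo
# id "org/model", revision "abc123", and helper `_fake_hf_cache` are
# hypothetical names invented for this sketch.
# ---------------------------------------------------------------------------
import json
import os


def _fake_hf_cache(tmp_cache: str, repo: str = "org/model", rev: str = "abc123") -> str:
    # Lay out a minimal HF hub cache: models--org--model/{refs,snapshots}.
    # huggingface_hub.constants.REPO_ID_SEPARATOR is "--", so this matches
    # the repo_folder the patched code constructs from a custom cache_dir.
    repo_dir = os.path.join(tmp_cache, "models--" + repo.replace("/", "--"))
    snap = os.path.join(repo_dir, "snapshots", rev)
    os.makedirs(snap)
    os.makedirs(os.path.join(repo_dir, "refs"))
    with open(os.path.join(repo_dir, "refs", "main"), "w") as f:
        f.write(rev)  # refs/main resolves the revision, as the patch expects
    with open(os.path.join(snap, "tokenizer_config.json"), "w") as f:
        json.dump({"tokenizer_class": "PreTrainedTokenizerFast"}, f)
    return snap


def test_cache_hit_then_corruption(tmp_path, monkeypatch):
    from sglang.srt.utils import hf_transformers_utils as hf

    monkeypatch.setattr(hf, "is_in_ci", lambda: True)  # the shortcut is CI-only
    snap = _fake_hf_cache(str(tmp_path))

    # Cache hit: valid JSON passes _validate_tokenizer_file, so the snapshot
    # directory is returned and no download is attempted.
    assert (
        hf.find_local_tokenizer_snapshot_dir("org/model", str(tmp_path), ["*.json"])
        == snap
    )

    # Corrupt the JSON: validation fails and the function returns None, so the
    # caller falls back to a normal snapshot_download.
    with open(os.path.join(snap, "tokenizer_config.json"), "w") as f:
        f.write("{truncated")
    assert (
        hf.find_local_tokenizer_snapshot_dir("org/model", str(tmp_path), ["*.json"])
        is None
    )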