-
Notifications
You must be signed in to change notification settings - Fork 5k
fix: checking if tokenizer is in cache before downloading from HF #14698
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |||||||||||||||||
| """Utilities for Huggingface Transformers.""" | ||||||||||||||||||
|
|
||||||||||||||||||
| import contextlib | ||||||||||||||||||
| import glob | ||||||||||||||||||
| import json | ||||||||||||||||||
| import logging | ||||||||||||||||||
| import os | ||||||||||||||||||
|
|
@@ -22,6 +23,7 @@ | |||||||||||||||||
| from pathlib import Path | ||||||||||||||||||
| from typing import Any, Dict, List, Optional, Type, Union | ||||||||||||||||||
|
|
||||||||||||||||||
| import huggingface_hub | ||||||||||||||||||
| import torch | ||||||||||||||||||
| from huggingface_hub import snapshot_download | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -67,7 +69,14 @@ | |||||||||||||||||
| from sglang.srt.configs.internvl import InternVLChatConfig | ||||||||||||||||||
| from sglang.srt.connector import create_remote_connector | ||||||||||||||||||
| from sglang.srt.multimodal.customized_mm_processor_utils import _CUSTOMIZED_MM_PROCESSOR | ||||||||||||||||||
| from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset, mistral_utils | ||||||||||||||||||
| from sglang.srt.utils import ( | ||||||||||||||||||
| find_local_repo_dir, | ||||||||||||||||||
| is_remote_url, | ||||||||||||||||||
| logger, | ||||||||||||||||||
| lru_cache_frozenset, | ||||||||||||||||||
| mistral_utils, | ||||||||||||||||||
| ) | ||||||||||||||||||
| from sglang.utils import is_in_ci | ||||||||||||||||||
|
|
||||||||||||||||||
| _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [ | ||||||||||||||||||
| ChatGLMConfig, | ||||||||||||||||||
|
|
@@ -399,12 +408,197 @@ def get_context_length(config): | |||||||||||||||||
| _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def _validate_tokenizer_file(file_path: str) -> bool: | ||||||||||||||||||
| """ | ||||||||||||||||||
| Validate that a tokenizer file is readable and not corrupted. | ||||||||||||||||||
|
|
||||||||||||||||||
| Args: | ||||||||||||||||||
| file_path: Path to the tokenizer file | ||||||||||||||||||
|
|
||||||||||||||||||
| Returns: | ||||||||||||||||||
| True if the file is valid, False if corrupted | ||||||||||||||||||
| """ | ||||||||||||||||||
| try: | ||||||||||||||||||
| # For JSON files, validate they're parseable | ||||||||||||||||||
| if file_path.endswith(".json"): | ||||||||||||||||||
| with open(file_path, "r") as f: | ||||||||||||||||||
| json.load(f) | ||||||||||||||||||
| return True | ||||||||||||||||||
| # For .model files (SentencePiece), just check readability | ||||||||||||||||||
| elif file_path.endswith(".model"): | ||||||||||||||||||
| with open(file_path, "rb") as f: | ||||||||||||||||||
| # Read first few bytes to verify file is readable | ||||||||||||||||||
| _ = f.read(100) | ||||||||||||||||||
| return True | ||||||||||||||||||
| # For other files, just check they exist and are readable | ||||||||||||||||||
| else: | ||||||||||||||||||
| with open(file_path, "rb") as f: | ||||||||||||||||||
| _ = f.read(100) | ||||||||||||||||||
| return True | ||||||||||||||||||
| except Exception as e: | ||||||||||||||||||
| logger.warning( | ||||||||||||||||||
| "Corrupted tokenizer file detected: %s - %s: %s", | ||||||||||||||||||
| file_path, | ||||||||||||||||||
| type(e).__name__, | ||||||||||||||||||
| str(e), | ||||||||||||||||||
| ) | ||||||||||||||||||
| return False | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def find_local_tokenizer_snapshot_dir( | ||||||||||||||||||
| model_name_or_path: str, | ||||||||||||||||||
| cache_dir: Optional[str], | ||||||||||||||||||
| allow_patterns: List[str], | ||||||||||||||||||
| revision: Optional[str] = None, | ||||||||||||||||||
| ) -> Optional[str]: | ||||||||||||||||||
| """If the tokenizer files are already local, skip downloading and return the path. | ||||||||||||||||||
| Only applied in CI. | ||||||||||||||||||
| """ | ||||||||||||||||||
| if not is_in_ci(): | ||||||||||||||||||
| return None | ||||||||||||||||||
|
|
||||||||||||||||||
| if os.path.isdir(model_name_or_path): | ||||||||||||||||||
| logger.info( | ||||||||||||||||||
| "Tokenizer path %s is already a local directory, skipping cache check", | ||||||||||||||||||
| model_name_or_path, | ||||||||||||||||||
| ) | ||||||||||||||||||
| return None | ||||||||||||||||||
|
|
||||||||||||||||||
| logger.info("Checking for cached tokenizer: %s", model_name_or_path) | ||||||||||||||||||
| found_local_snapshot_dir = None | ||||||||||||||||||
|
|
||||||||||||||||||
| # Check custom cache_dir (if provided) | ||||||||||||||||||
| if cache_dir: | ||||||||||||||||||
| try: | ||||||||||||||||||
| repo_folder = os.path.join( | ||||||||||||||||||
| cache_dir, | ||||||||||||||||||
| huggingface_hub.constants.REPO_ID_SEPARATOR.join( | ||||||||||||||||||
| ["models", *model_name_or_path.split("/")] | ||||||||||||||||||
| ), | ||||||||||||||||||
| ) | ||||||||||||||||||
| rev_to_use = revision | ||||||||||||||||||
| if not rev_to_use: | ||||||||||||||||||
| ref_main = os.path.join(repo_folder, "refs", "main") | ||||||||||||||||||
| if os.path.isfile(ref_main): | ||||||||||||||||||
| with open(ref_main) as f: | ||||||||||||||||||
| rev_to_use = f.read().strip() | ||||||||||||||||||
| if rev_to_use: | ||||||||||||||||||
| rev_dir = os.path.join(repo_folder, "snapshots", rev_to_use) | ||||||||||||||||||
| if os.path.isdir(rev_dir): | ||||||||||||||||||
| found_local_snapshot_dir = rev_dir | ||||||||||||||||||
| except Exception as e: | ||||||||||||||||||
| logger.warning( | ||||||||||||||||||
| "Failed to find local snapshot in custom cache_dir %s: %s", | ||||||||||||||||||
| cache_dir, | ||||||||||||||||||
| e, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Check default HF cache as well | ||||||||||||||||||
| if not found_local_snapshot_dir: | ||||||||||||||||||
| try: | ||||||||||||||||||
| rev_dir = find_local_repo_dir(model_name_or_path, revision) | ||||||||||||||||||
| if rev_dir and os.path.isdir(rev_dir): | ||||||||||||||||||
| found_local_snapshot_dir = rev_dir | ||||||||||||||||||
| except Exception as e: | ||||||||||||||||||
| logger.warning("Failed to find local snapshot in default HF cache: %s", e) | ||||||||||||||||||
|
|
||||||||||||||||||
| # If local snapshot exists, validate it contains at least one tokenizer file | ||||||||||||||||||
| # matching allow_patterns before skipping download. | ||||||||||||||||||
| if found_local_snapshot_dir is None: | ||||||||||||||||||
| return None | ||||||||||||||||||
|
|
||||||||||||||||||
| # Layer 0: Check for incomplete files (corruption indicator) | ||||||||||||||||||
| repo_folder = os.path.abspath(os.path.join(found_local_snapshot_dir, "..", "..")) | ||||||||||||||||||
| blobs_dir = os.path.join(repo_folder, "blobs") | ||||||||||||||||||
| if os.path.isdir(blobs_dir) and glob.glob(os.path.join(blobs_dir, "*.incomplete")): | ||||||||||||||||||
| logger.info( | ||||||||||||||||||
| "Found .incomplete files in %s for %s. Considering local snapshot incomplete.", | ||||||||||||||||||
| blobs_dir, | ||||||||||||||||||
| model_name_or_path, | ||||||||||||||||||
| ) | ||||||||||||||||||
| return None | ||||||||||||||||||
|
|
||||||||||||||||||
| local_tokenizer_files: List[str] = [] | ||||||||||||||||||
| try: | ||||||||||||||||||
| for pattern in allow_patterns: | ||||||||||||||||||
| matched_files = glob.glob(os.path.join(found_local_snapshot_dir, pattern)) | ||||||||||||||||||
| for f in matched_files: | ||||||||||||||||||
| # Layer 1: Check symlink target exists (broken symlink check) | ||||||||||||||||||
| if not os.path.exists(f): | ||||||||||||||||||
| continue | ||||||||||||||||||
| # Layer 2: Validate file content is not corrupted | ||||||||||||||||||
| if not _validate_tokenizer_file(f): | ||||||||||||||||||
| logger.info( | ||||||||||||||||||
| "Found corrupted tokenizer file %s for %s. Will re-download.", | ||||||||||||||||||
| f, | ||||||||||||||||||
| model_name_or_path, | ||||||||||||||||||
| ) | ||||||||||||||||||
| return None | ||||||||||||||||||
| local_tokenizer_files.append(f) | ||||||||||||||||||
| except Exception as e: | ||||||||||||||||||
| logger.warning( | ||||||||||||||||||
| "Failed to scan local snapshot %s with patterns %s: %s", | ||||||||||||||||||
| found_local_snapshot_dir, | ||||||||||||||||||
| allow_patterns, | ||||||||||||||||||
| e, | ||||||||||||||||||
| ) | ||||||||||||||||||
| local_tokenizer_files = [] | ||||||||||||||||||
|
|
||||||||||||||||||
| if len(local_tokenizer_files) > 0: | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to make this condition much more restrictive. Otherwise, it will try to load corrupted weighs. Some ideas: check hash, check incomplete files, etc
sglang/scripts/ci/cleanup_hf_cache.py Lines 49 to 55 in 0f8bd55
|
||||||||||||||||||
| logger.info( | ||||||||||||||||||
| "Found local HF snapshot for tokenizer %s at %s; skipping download.", | ||||||||||||||||||
| model_name_or_path, | ||||||||||||||||||
| found_local_snapshot_dir, | ||||||||||||||||||
| ) | ||||||||||||||||||
| return found_local_snapshot_dir | ||||||||||||||||||
| else: | ||||||||||||||||||
| logger.info( | ||||||||||||||||||
| "Local HF snapshot at %s has no files matching %s; will attempt download.", | ||||||||||||||||||
| found_local_snapshot_dir, | ||||||||||||||||||
| allow_patterns, | ||||||||||||||||||
| ) | ||||||||||||||||||
| return None | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| # Filter warnings like: https://github.com/sgl-project/sglang/issues/8082 | ||||||||||||||||||
| class TokenizerWarningsFilter(logging.Filter): | ||||||||||||||||||
| def filter(self, record: logging.LogRecord) -> bool: | ||||||||||||||||||
| return "Calling super().encode with" not in record.getMessage() | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def _check_tokenizer_cache( | ||||||||||||||||||
| tokenizer_name: str, | ||||||||||||||||||
| cache_dir: Optional[str], | ||||||||||||||||||
| revision: Optional[str], | ||||||||||||||||||
| include_processor_files: bool = False, | ||||||||||||||||||
| ) -> str: | ||||||||||||||||||
| """Check local cache for tokenizer files and return local path if found. | ||||||||||||||||||
|
|
||||||||||||||||||
| Args: | ||||||||||||||||||
| tokenizer_name: Model name or path | ||||||||||||||||||
| cache_dir: Optional custom cache directory | ||||||||||||||||||
| revision: Optional model revision | ||||||||||||||||||
| include_processor_files: Whether to include processor-specific files (*.py, preprocessor_config.json) | ||||||||||||||||||
|
|
||||||||||||||||||
| Returns: | ||||||||||||||||||
| Local path if found in cache, otherwise returns original tokenizer_name | ||||||||||||||||||
| """ | ||||||||||||||||||
| allow_patterns = [ | ||||||||||||||||||
| "*.json", | ||||||||||||||||||
| "*.model", | ||||||||||||||||||
| "*.txt", | ||||||||||||||||||
| "tokenizer.model", | ||||||||||||||||||
| "tokenizer_config.json", | ||||||||||||||||||
| ] | ||||||||||||||||||
| if include_processor_files: | ||||||||||||||||||
| allow_patterns.extend(["*.py", "preprocessor_config.json"]) | ||||||||||||||||||
|
|
||||||||||||||||||
| local_path = find_local_tokenizer_snapshot_dir( | ||||||||||||||||||
| tokenizer_name, cache_dir, allow_patterns, revision | ||||||||||||||||||
| ) | ||||||||||||||||||
| return local_path if local_path is not None else tokenizer_name | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def get_tokenizer( | ||||||||||||||||||
| tokenizer_name: str, | ||||||||||||||||||
| *args, | ||||||||||||||||||
|
|
@@ -441,6 +635,11 @@ def get_tokenizer( | |||||||||||||||||
| client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) | ||||||||||||||||||
| tokenizer_name = client.get_local_dir() | ||||||||||||||||||
|
|
||||||||||||||||||
| # Check if tokenizer files are already in local cache (CI only) | ||||||||||||||||||
| tokenizer_name = _check_tokenizer_cache( | ||||||||||||||||||
| tokenizer_name, kwargs.get("cache_dir"), tokenizer_revision | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| try: | ||||||||||||||||||
| tokenizer = AutoTokenizer.from_pretrained( | ||||||||||||||||||
| tokenizer_name, | ||||||||||||||||||
|
|
@@ -507,6 +706,15 @@ def get_processor( | |||||||||||||||||
| ): | ||||||||||||||||||
| # pop 'revision' from kwargs if present. | ||||||||||||||||||
| revision = kwargs.pop("revision", tokenizer_revision) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Check if processor/tokenizer files are already in local cache (CI only) | ||||||||||||||||||
| tokenizer_name = _check_tokenizer_cache( | ||||||||||||||||||
| tokenizer_name, | ||||||||||||||||||
| kwargs.get("cache_dir"), | ||||||||||||||||||
| revision, | ||||||||||||||||||
| include_processor_files=True, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| if "mistral-large-3" in str(tokenizer_name).lower(): | ||||||||||||||||||
| config = _load_mistral_large_3_for_causal_LM( | ||||||||||||||||||
| tokenizer_name, | ||||||||||||||||||
|
|
||||||||||||||||||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding three layers of Tokenizer validation: