diff --git a/python/sglang/srt/model_loader/weight_validation.py b/python/sglang/srt/model_loader/weight_validation.py index d83c5dae8744..3c145360d592 100644 --- a/python/sglang/srt/model_loader/weight_validation.py +++ b/python/sglang/srt/model_loader/weight_validation.py @@ -1,3 +1,4 @@ +import json import logging import os import re @@ -36,6 +37,61 @@ def _validate_safetensors_file(file_path: str) -> bool: return False +def _check_index_files_exist(snapshot_dir: str) -> Tuple[bool, Optional[str]]: + """ + Check if all files listed in safetensors index files actually exist on disk. + + This catches cases where the snapshot directory exists but files are missing + (e.g., due to incomplete downloads or corrupted cache). + + Args: + snapshot_dir: Path to the model snapshot directory + + Returns: + Tuple of (all_exist, error_message) + """ + # Find all safetensors index files + index_files = [ + f for f in os.listdir(snapshot_dir) if f.endswith(".safetensors.index.json") + ] + + if not index_files: + # No index files means it's not a sharded model, skip this check + return True, None + + for index_file in index_files: + index_path = os.path.join(snapshot_dir, index_file) + try: + with open(index_path) as f: + index_data = json.load(f) + + weight_map = index_data.get("weight_map", {}) + if not weight_map: + continue + + # Check that all files in weight_map exist + required_files = set(weight_map.values()) + missing_files = [] + + for file_name in required_files: + file_path = os.path.join(snapshot_dir, file_name) + # Check both existence and that it's not a broken symlink + if not os.path.exists(file_path): + missing_files.append(file_name) + + if missing_files: + return ( + False, + f"Missing {len(missing_files)} file(s) from index {index_file}: {missing_files[:3]}{'...' if len(missing_files) > 3 else ''}", + ) + + except Exception as e: + logger.warning("Failed to read index file %s: %s", index_file, e) + continue + + return True, None + + def _validate_sharded_model( snapshot_dir: str, weight_files: List[str] ) -> Tuple[bool, Optional[str], List[str]]: @@ -50,6 +106,12 @@ def _validate_sharded_model( Tuple of (is_valid, error_message, corrupted_files) - corrupted_files: List of file paths that are corrupted (for selective cleanup) """ + # First, check if all files from the index actually exist + # This catches missing files that wouldn't be found by glob + index_check_valid, index_error = _check_index_files_exist(snapshot_dir) + if not index_check_valid: + return False, index_error, [] + # Pattern for sharded files: model-00001-of-00009.safetensors shard_pattern = re.compile(r"(.*?)-(\d+)-of-(\d+)\.(safetensors|bin)")