diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 7598eb7e9089..a534be5a6fad 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -8,7 +8,6 @@ import json import logging import os -import re import tempfile from collections import defaultdict from typing import ( @@ -41,6 +40,12 @@ ModelOptFp4Config, ModelOptFp8Config, ) +from sglang.srt.model_loader.weight_validation import ( + _cleanup_corrupted_files_selective, + _cleanup_corrupted_model_cache, + _validate_safetensors_file, + _validate_sharded_model, +) from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once from sglang.utils import is_in_ci @@ -304,21 +309,31 @@ def find_local_hf_snapshot_dir( except Exception as e: logger.warning("Failed to find local snapshot in default HF cache: %s", e) - # if any incomplete file exists, force re-download by returning None + # Check for incomplete files and clean up if found if found_local_snapshot_dir: repo_folder = os.path.abspath( os.path.join(found_local_snapshot_dir, "..", "..") ) blobs_dir = os.path.join(repo_folder, "blobs") - if os.path.isdir(blobs_dir) and glob.glob( - os.path.join(blobs_dir, "*.incomplete") - ): + + # Check for incomplete download markers + incomplete_files = [] + if os.path.isdir(blobs_dir): + incomplete_files = glob.glob(os.path.join(blobs_dir, "*.incomplete")) + + if incomplete_files: logger.info( - "Found .incomplete files in %s for %s. " - "Considering local snapshot incomplete.", + "Found %d .incomplete files in %s for %s. " + "Will clean up and re-download.", + len(incomplete_files), blobs_dir, model_name_or_path, ) + _cleanup_corrupted_model_cache( + model_name_or_path, + found_local_snapshot_dir, + f"Incomplete download detected ({len(incomplete_files)} incomplete files)", + ) return None # if local snapshot exists, validate it contains at least one weight file @@ -344,45 +359,51 @@ def find_local_hf_snapshot_dir( ) local_weight_files = [] - # After we have a list of valid files, check for sharded model completeness. - # Check if all safetensors with name model-{i}-of-{n}.safetensors exists - checked_sharded_model = False - for f in local_weight_files: - if checked_sharded_model: - break - base_name = os.path.basename(f) - # Regex for files like model-00001-of-00009.safetensors - match = re.match(r"(.*?)-([0-9]+)-of-([0-9]+)\.(.*)", base_name) - if match: - prefix = match.group(1) - shard_id_str = match.group(2) - total_shards_str = match.group(3) - suffix = match.group(4) - total_shards = int(total_shards_str) - - # Check if all shards are present - missing_shards = [] - for i in range(1, total_shards + 1): - # Reconstruct shard name, preserving padding of original shard id - shard_name = ( - f"{prefix}-{i:0{len(shard_id_str)}d}-of-{total_shards_str}.{suffix}" + # Validate sharded models and check for corruption + if local_weight_files: + is_valid, error_msg, corrupted_files = _validate_sharded_model( + found_local_snapshot_dir, local_weight_files + ) + if not is_valid: + if corrupted_files: + # Selective cleanup: only remove corrupted files + logger.info( + "Found %d corrupted file(s) for %s: %s. " + "Will selectively clean and re-download only these files.", + len(corrupted_files), + model_name_or_path, + error_msg, ) - expected_path = os.path.join(found_local_snapshot_dir, shard_name) - # os.path.exists returns False for broken symlinks, which is desired. - if not os.path.exists(expected_path): - missing_shards.append(shard_name) - - if missing_shards: + _cleanup_corrupted_files_selective(model_name_or_path, corrupted_files) + return None + else: + # Cannot selectively clean (e.g., missing shards) - remove entire cache logger.info( - "Found incomplete sharded model %s. Missing shards: %s. " - "Will attempt download.", + "Validation failed for %s: %s. " + "Will remove entire cache and re-download.", model_name_or_path, - missing_shards, + error_msg, + ) + _cleanup_corrupted_model_cache( + model_name_or_path, found_local_snapshot_dir, error_msg ) return None - # If we found and verified one set of shards, we are done. - checked_sharded_model = True + # Also validate single (non-sharded) safetensors files + for f in local_weight_files: + base_name = os.path.basename(f) + # Check if this is a single model file (not sharded) + if base_name in ["model.safetensors", "pytorch_model.safetensors"]: + if not _validate_safetensors_file(f): + logger.info( + "Corrupted model file %s for %s. " + "Will selectively clean and re-download this file.", + base_name, + model_name_or_path, + ) + # Selective cleanup for single file + _cleanup_corrupted_files_selective(model_name_or_path, [f]) + return None if len(local_weight_files) > 0: logger.info( diff --git a/python/sglang/srt/model_loader/weight_validation.py b/python/sglang/srt/model_loader/weight_validation.py new file mode 100644 index 000000000000..d83c5dae8744 --- /dev/null +++ b/python/sglang/srt/model_loader/weight_validation.py @@ -0,0 +1,220 @@ +import logging +import os +import re +import shutil +from typing import List, Optional, Tuple + +import safetensors + +logger = logging.getLogger(__name__) + + +def _validate_safetensors_file(file_path: str) -> bool: + """ + Validate that a safetensors file is readable and not corrupted. + + Args: + file_path: Path to the safetensors file + + Returns: + True if the file is valid, False if corrupted + """ + try: + # Attempt to open and read the header + # This will fail if the file is corrupted or incomplete + with safetensors.safe_open(file_path, framework="pt", device="cpu") as f: + # Just accessing the keys validates the header is readable + _ = list(f.keys()) + return True + except Exception as e: + logger.warning( + "Corrupted safetensors file detected: %s - %s: %s", + file_path, + type(e).__name__, + str(e), + ) + return False + + +def _validate_sharded_model( + snapshot_dir: str, weight_files: List[str] +) -> Tuple[bool, Optional[str], List[str]]: + """ + Validate that all model shards are present and not corrupted. + + Args: + snapshot_dir: Path to the model snapshot directory + weight_files: List of weight file paths + + Returns: + Tuple of (is_valid, error_message, corrupted_files) + - corrupted_files: List of file paths that are corrupted (for selective cleanup) + """ + # Pattern for sharded files: model-00001-of-00009.safetensors + shard_pattern = re.compile(r"(.*?)-(\d+)-of-(\d+)\.(safetensors|bin)") + + # Group files by shard pattern (prefix-*-of-N) + shard_groups = {} + for f in weight_files: + base_name = os.path.basename(f) + match = shard_pattern.match(base_name) + if match: + prefix = match.group(1) + total_shards_str = match.group(3) + suffix = match.group(4) + + group_key = f"{prefix}-of-{total_shards_str}.{suffix}" + if group_key not in shard_groups: + shard_groups[group_key] = { + "prefix": prefix, + "total": int(total_shards_str), + "suffix": suffix, + "found_shards": [], + "files": [], + } + + shard_id = int(match.group(2)) + shard_groups[group_key]["found_shards"].append(shard_id) + shard_groups[group_key]["files"].append(f) + + # Track corrupted files for selective cleanup + corrupted_files = [] + + # Validate each shard group + for group_key, group_info in shard_groups.items(): + total_shards = group_info["total"] + found_shards = set(group_info["found_shards"]) + expected_shards = set(range(1, total_shards + 1)) + + # Check for missing shards + missing_shards = expected_shards - found_shards + if missing_shards: + return ( + False, + f"Missing shards in {group_key}: {sorted(missing_shards)}", + [], + ) + + # Validate safetensors files for corruption + if group_info["suffix"] == "safetensors": + for f in group_info["files"]: + if not _validate_safetensors_file(f): + corrupted_files.append(f) + + # Check for required index file for safetensors shards + if group_info["suffix"] == "safetensors": + index_file = os.path.join( + snapshot_dir, f"{group_info['prefix']}.safetensors.index.json" + ) + if not os.path.exists(index_file): + return ( + False, + f"Missing index file: {os.path.basename(index_file)}", + [], + ) + + if corrupted_files: + return ( + False, + f"Corrupted shard files: {[os.path.basename(f) for f in corrupted_files]}", + corrupted_files, + ) + + return True, None, [] + + +def _cleanup_corrupted_files_selective( + model_name_or_path: str, corrupted_files: List[str] +) -> int: + """ + Selectively remove corrupted files and their blobs to force re-download. + + This is more efficient than removing the entire model cache as it only + re-downloads corrupted files rather than the entire model. + + Args: + model_name_or_path: Model identifier + corrupted_files: List of corrupted file paths (symlinks in snapshot) + + Returns: + Number of files successfully cleaned up + """ + cleaned_count = 0 + + for file_path in corrupted_files: + try: + # Resolve symlink to get blob path before deleting symlink + if os.path.islink(file_path): + blob_path = os.path.realpath(file_path) + + # Delete the symlink + os.remove(file_path) + logger.info( + "Removed corrupted symlink: %s", os.path.basename(file_path) + ) + + # Delete the blob (the actual corrupted data) + if os.path.exists(blob_path): + os.remove(blob_path) + logger.info( + "Removed corrupted blob: %s", os.path.basename(blob_path) + ) + + cleaned_count += 1 + elif os.path.exists(file_path): + # Not a symlink, just delete the file + os.remove(file_path) + logger.info("Removed corrupted file: %s", os.path.basename(file_path)) + cleaned_count += 1 + + except Exception as e: + logger.error( + "Failed to remove corrupted file %s: %s", + os.path.basename(file_path), + e, + ) + + if cleaned_count > 0: + logger.warning( + "Removed %d corrupted file(s) for %s. " + "These will be re-downloaded on next load.", + cleaned_count, + model_name_or_path, + ) + + return cleaned_count + + +def _cleanup_corrupted_model_cache( + model_name_or_path: str, snapshot_dir: str, reason: str +) -> None: + """ + Remove entire corrupted model cache directory to force a clean re-download. + + This is used when we cannot selectively clean (e.g., missing shards, incomplete + downloads with unknown affected files). + + Args: + model_name_or_path: Model identifier + snapshot_dir: Path to the snapshot directory + reason: Reason for cleanup + """ + # Navigate up to the model root directory: snapshots/hash -> snapshots -> model_root + repo_folder = os.path.abspath(os.path.join(snapshot_dir, "..", "..")) + + try: + logger.warning( + "Removing entire cache for %s at %s. Reason: %s", + model_name_or_path, + repo_folder, + reason, + ) + shutil.rmtree(repo_folder) + logger.info("Successfully removed corrupted cache directory") + except Exception as e: + logger.error( + "Failed to remove corrupted cache directory %s: %s. " + "Manual cleanup may be required.", + repo_folder, + e, + ) diff --git a/scripts/ci/prepare_runner.sh b/scripts/ci/prepare_runner.sh index 4e30d00acb57..fe0bf4400165 100755 --- a/scripts/ci/prepare_runner.sh +++ b/scripts/ci/prepare_runner.sh @@ -11,19 +11,4 @@ echo "" python3 "${SCRIPT_DIR}/cleanup_hf_cache.py" echo "" -# Validate model integrity for configured runners -echo "Validating model integrity..." - -# Enable accelerated HuggingFace downloads (10x faster on high-bandwidth networks) -export HF_HUB_ENABLE_HF_TRANSFER=1 - -python3 "${SCRIPT_DIR}/validate_and_download_models.py" -VALIDATION_EXIT_CODE=$? - -if [ $VALIDATION_EXIT_CODE -ne 0 ]; then - echo "Model validation failed with exit code: $VALIDATION_EXIT_CODE" - exit $VALIDATION_EXIT_CODE -fi - -echo "" echo "CI runner preparation complete!" diff --git a/scripts/ci/validate_and_download_models.py b/scripts/ci/validate_and_download_models.py deleted file mode 100755 index e60d6b10d96d..000000000000 --- a/scripts/ci/validate_and_download_models.py +++ /dev/null @@ -1,627 +0,0 @@ -#!/usr/bin/env python3 -""" -Validate model integrity for CI runners and download if needed. - -This script checks HuggingFace cache for model completeness and downloads -missing models. It exits with code 0 if models are present or successfully -downloaded (emitting a warning annotation if repairs were needed), and exits -with code 1 only if download attempts fail. -""" - -import os -import re -import shutil -import sys -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -try: - from huggingface_hub import constants, snapshot_download - - HF_HUB_AVAILABLE = True -except ImportError: - print( - "Warning: huggingface_hub not available. Install with: pip install huggingface_hub" - ) - HF_HUB_AVAILABLE = False - -try: - from safetensors import safe_open - - SAFETENSORS_AVAILABLE = True -except ImportError: - print("Warning: safetensors not available. Install with: pip install safetensors") - SAFETENSORS_AVAILABLE = False - - -# Mapping of runner labels to their required models -# Add new runner labels and models here as needed -RUNNER_LABEL_MODEL_MAP: Dict[str, List[str]] = { - "1-gpu-runner": [ - "Alibaba-NLP/gte-Qwen2-1.5B-instruct", - "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", - "deepseek-ai/DeepSeek-OCR", - "google/gemma-3-4b-it", - "intfloat/e5-mistral-7b-instruct", - "lmms-lab/llava-onevision-qwen2-0.5b-ov", - "lmsys/sglang-ci-dsv3-test", - "lmsys/sglang-EAGLE-llama2-chat-7B", - "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B", - "LxzGordon/URM-LLaMa-3.1-8B", - "marco/mcdse-2b-v1", - "meta-llama/Llama-2-7b-chat-hf", - "meta-llama/Llama-3.2-1B-Instruct", - "meta-llama/Llama-3.1-8B-Instruct", - "mistralai/Mixtral-8x7B-Instruct-v0.1", - "moonshotai/Kimi-VL-A3B-Instruct", - "nvidia/NVIDIA-Nemotron-Nano-9B-v2", - "nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8", - "openai/gpt-oss-20b", - "lmsys/gpt-oss-20b-bf16", - "OpenGVLab/InternVL2_5-2B", - "Qwen/Qwen1.5-MoE-A2.7B", - "Qwen/Qwen2.5-7B-Instruct", - "Qwen/Qwen3-8B", - "Qwen/Qwen3-Coder-30B-A3B-Instruct", - "Qwen/Qwen3-Embedding-8B", - "Qwen/QwQ-32B-AWQ", - "Qwen/Qwen3-30B-A3B", - "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", - "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8", - "lmms-lab/llava-onevision-qwen2-7b-ov", - # diffusion - "Qwen/Qwen-Image", - "Qwen/Qwen-Image-Edit", - "black-forest-labs/FLUX.1-dev", - "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", - "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", - "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers", - "Wan-AI/Wan2.2-TI2V-5B-Diffusers", - "Wan-AI/Wan2.2-I2V-A14B-Diffusers", - ], - "2-gpu-runner": [ - "mistralai/Mixtral-8x7B-Instruct-v0.1", - "moonshotai/Kimi-Linear-48B-A3B-Instruct", - "Qwen/Qwen2-57B-A14B-Instruct", - "Qwen/Qwen2.5-VL-7B-Instruct", - "Qwen/Qwen3-VL-30B-A3B-Instruct", - "neuralmagic/Qwen2-72B-Instruct-FP8", - "zai-org/GLM-4.5-Air-FP8", - ], - "8-gpu-h200": [ - "deepseek-ai/DeepSeek-V3-0324", - "deepseek-ai/DeepSeek-V3.2-Exp", - "moonshotai/Kimi-K2-Thinking", - ], - "8-gpu-b200": ["deepseek-ai/DeepSeek-V3.1", "deepseek-ai/DeepSeek-V3.2-Exp"], - "4-gpu-b200": ["nvidia/DeepSeek-V3-0324-FP4"], - "4-gpu-gb200": ["nvidia/DeepSeek-V3-0324-FP4"], - "4-gpu-h100": [ - "lmsys/sglang-ci-dsv3-test", - "lmsys/sglang-ci-dsv3-test-NextN", - "lmsys/gpt-oss-120b-bf16", - ], -} - - -def get_hf_cache_dir() -> str: - """Get the HuggingFace cache directory.""" - if HF_HUB_AVAILABLE: - return constants.HF_HUB_CACHE - - # Fallback to environment variable or default - hf_home = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface")) - return os.path.join(hf_home, "hub") - - -def get_model_cache_path(model_id: str, cache_dir: str) -> Optional[Path]: - """ - Find the model's cache directory in HuggingFace hub cache. - - Args: - model_id: Model identifier (e.g., "deepseek-ai/DeepSeek-V3-0324") - cache_dir: HuggingFace cache directory - - Returns: - Path to model's snapshot directory, or None if not found - """ - # Convert model_id to cache directory name format - # "deepseek-ai/DeepSeek-V3-0324" -> "models--deepseek-ai--DeepSeek-V3-0324" - cache_model_name = "models--" + model_id.replace("/", "--") - model_path = Path(cache_dir) / cache_model_name - - if not model_path.exists(): - return None - - # Find the most recent snapshot directory - snapshots_dir = model_path / "snapshots" - if not snapshots_dir.exists(): - return None - - # Get all snapshot directories (sorted by modification time, most recent first) - snapshot_dirs = sorted( - [d for d in snapshots_dir.iterdir() if d.is_dir()], - key=lambda x: x.stat().st_mtime, - reverse=True, - ) - - if not snapshot_dirs: - return None - - return snapshot_dirs[0] - - -def check_incomplete_files(model_path: Path, cache_dir: str) -> List[str]: - """ - Check for incomplete download marker files specific to this model. - - Args: - model_path: Path to model's snapshot directory - cache_dir: HuggingFace cache directory - - Returns: - List of incomplete files found for this specific model - """ - incomplete_in_snapshot = [] - - # Check if any files in the snapshot are symlinks to .incomplete blobs - # This ensures we only flag incomplete files for THIS specific model, - # not other models that might be downloading concurrently - # Use recursive glob to support Diffusers models with weights in subdirectories - for file_path in model_path.glob("**/*"): - if file_path.is_symlink(): - try: - target = file_path.resolve() - # Check if the symlink target has .incomplete suffix - if str(target).endswith(".incomplete"): - incomplete_in_snapshot.append(str(target)) - except (OSError, RuntimeError): - # Broken symlink - also indicates incomplete download - incomplete_in_snapshot.append(str(file_path)) - - return incomplete_in_snapshot - - -def validate_safetensors_file(file_path: Path) -> Tuple[bool, Optional[str]]: - """ - Validate that a safetensors file is readable and not corrupted. - - Args: - file_path: Path to the safetensors file - - Returns: - Tuple of (is_valid, error_message) - """ - if not SAFETENSORS_AVAILABLE: - # Skip validation if safetensors library is not available - return True, None - - try: - # Attempt to open and read the header - # This will fail if the file is corrupted or incomplete - with safe_open(file_path, framework="pt", device="cpu") as f: - # Just accessing the keys validates the header is readable - _ = f.keys() - return True, None - except Exception as e: - error_type = type(e).__name__ - error_msg = str(e) - # Return detailed error for debugging - return False, f"{error_type}: {error_msg}" - - -def validate_model_shards(model_path: Path) -> Tuple[bool, Optional[str], List[Path]]: - """ - Validate that all model shards are present and complete. - - Args: - model_path: Path to model's snapshot directory - - Returns: - Tuple of (is_valid, error_message, corrupted_files) - - corrupted_files: List of paths to corrupted shard files that should be removed - """ - # Pattern for sharded files: model-00001-of-00009.safetensors, pytorch_model-00001-of-00009.bin, - # or diffusion_pytorch_model-00001-of-00009.safetensors (for Diffusers models) - # Use word boundary to prevent matching files like tokenizer_model-* or optimizer_model-* - shard_pattern = re.compile( - r"(?:^|/)(?:model|pytorch_model|diffusion_pytorch_model)-(\d+)-of-(\d+)\.(safetensors|bin)" - ) - - # Find all shard files recursively (both .safetensors and .bin) - # This supports both standard models (weights in root) and Diffusers models (weights in subdirs) - shard_files = list(model_path.glob("**/*-*-of-*.safetensors")) + list( - model_path.glob("**/*-*-of-*.bin") - ) - - if not shard_files: - # No sharded files - check for any safetensors or bin files recursively - # Exclude non-model files like tokenizer, config, optimizer, etc. - all_safetensors = list(model_path.glob("**/*.safetensors")) - all_bins = list(model_path.glob("**/*.bin")) - - # Filter out non-model files - excluded_prefixes = ["tokenizer", "optimizer", "training_", "config"] - single_files = [ - f - for f in (all_safetensors or all_bins) - if not any(f.name.startswith(prefix) for prefix in excluded_prefixes) - and not f.name.endswith(".index.json") - ] - - if single_files: - # Validate all safetensors files, not just the first one - for model_file in single_files: - if model_file.suffix == ".safetensors": - is_valid, error_msg = validate_safetensors_file(model_file) - if not is_valid: - return ( - False, - f"Corrupted file {model_file.name}: {error_msg}", - [model_file], - ) - return True, None, [] - return False, "No model weight files found (safetensors or bin)", [] - - # Group shards by subdirectory and total count - # This handles Diffusers models where different components (transformer/, vae/) - # have different numbers of shards - shard_groups = {} - for shard_file in shard_files: - # Match against the full path string to get proper path separation - match = shard_pattern.search(str(shard_file)) - if match: - shard_num = int(match.group(1)) - total = int(match.group(2)) - parent = shard_file.parent - key = (str(parent.relative_to(model_path)), total) - - if key not in shard_groups: - shard_groups[key] = set() - shard_groups[key].add(shard_num) - - if not shard_groups: - return False, "Could not determine shard groups from filenames", [] - - # Validate each group separately - for (parent_path, total_shards), found_shards in shard_groups.items(): - expected_shards = set(range(1, total_shards + 1)) - missing_shards = expected_shards - found_shards - - if missing_shards: - missing_list = sorted(missing_shards) - location = f" in {parent_path}" if parent_path != "." else "" - # Missing shards - nothing to remove, let download handle it - return ( - False, - f"Missing shards{location}: {missing_list} (expected {total_shards} total)", - [], - ) - - # Check for index file (look for specific patterns matching the shard prefixes) - # Standard models: model.safetensors.index.json or pytorch_model.safetensors.index.json - # Diffusers models: diffusion_pytorch_model.safetensors.index.json in subdirs - valid_index_patterns = [ - "model.safetensors.index.json", - "pytorch_model.safetensors.index.json", - "**/diffusion_pytorch_model.safetensors.index.json", - ] - - index_files = [] - for pattern in valid_index_patterns: - index_files.extend(model_path.glob(pattern)) - - if not index_files: - return ( - False, - "Missing required index file (model/pytorch_model/diffusion_pytorch_model.safetensors.index.json)", - [], - ) - - # Validate each safetensors shard file for corruption - print(f" Validating {len(shard_files)} shard file(s) for corruption...") - corrupted_files = [] - for shard_file in shard_files: - if shard_file.suffix == ".safetensors": - is_valid, error_msg = validate_safetensors_file(shard_file) - if not is_valid: - corrupted_files.append(shard_file) - print(f" ✗ Corrupted: {shard_file.name} - {error_msg}") - - if corrupted_files: - return ( - False, - f"Corrupted shards: {[f.name for f in corrupted_files]}", - corrupted_files, - ) - - return True, None, [] - - -def validate_model( - model_id: str, cache_dir: str -) -> Tuple[bool, Optional[str], List[Path]]: - """ - Validate a model's cache integrity. - - Args: - model_id: Model identifier - cache_dir: HuggingFace cache directory - - Returns: - Tuple of (is_valid, error_message, corrupted_files) - - corrupted_files: List of paths to corrupted files that should be removed - """ - print(f"Validating model: {model_id}") - - # Find model in cache - model_path = get_model_cache_path(model_id, cache_dir) - if model_path is None: - return False, "Model not found in cache", [] - - print(f" Found in cache: {model_path}") - - # Check for incomplete files - incomplete_files = check_incomplete_files(model_path, cache_dir) - if incomplete_files: - return ( - False, - f"Found incomplete download files: {len(incomplete_files)} files", - [], - ) - - # Validate shards - is_valid, error_msg, corrupted_files = validate_model_shards(model_path) - if not is_valid: - return False, error_msg, corrupted_files - - print(f" ✓ Model validated successfully") - return True, None, [] - - -def download_model(model_id: str, cache_dir: str, corrupted_files: List[Path]) -> bool: - """ - Download a model from HuggingFace. - - Completely removes the model cache directory before downloading to ensure a clean download. - - Args: - model_id: Model identifier - cache_dir: HuggingFace cache directory - corrupted_files: List of specific file paths that are corrupted (unused, kept for compatibility) - - Returns: - True if download succeeded, False otherwise - """ - if not HF_HUB_AVAILABLE: - print(f"ERROR: Cannot download model - huggingface_hub not available") - return False - - print(f"Downloading model: {model_id}") - - # Completely remove the model directory from cache - cache_model_name = "models--" + model_id.replace("/", "--") - model_cache_path = Path(cache_dir) / cache_model_name - - if model_cache_path.exists(): - print(f" Removing entire model directory: {model_cache_path}") - try: - shutil.rmtree(model_cache_path) - print(f" ✓ Successfully removed model directory") - except Exception as e: - print(f" ✗ Failed to remove model directory: {e}") - print(f" Attempting download anyway...") - else: - print(f" Model directory not found in cache (will download fresh)") - - print(f" Downloading from HuggingFace (this may take a while for large models)...") - - try: - snapshot_download( - repo_id=model_id, - allow_patterns=["*.safetensors", "*.bin", "*.json", "*.txt", "*.model"], - ignore_patterns=["*.msgpack", "*.h5", "*.ot"], # codespell:ignore ot - ) - print(f" ✓ Download completed: {model_id}") - return True - except Exception as e: - print(f" ✗ Download failed: {e}") - return False - - -def get_runner_labels() -> List[str]: - """ - Get the runner labels from environment variables. - - GitHub Actions doesn't expose runner labels directly as environment variables. - Workflows should set the RUNNER_LABELS environment variable with a comma-separated - list of labels (e.g., "self-hosted,8-gpu-h200,linux"). - - Returns: - List of runner labels, empty list if not set - """ - labels_str = os.environ.get("RUNNER_LABELS", "") - if not labels_str: - return [] - - # Split by comma and strip whitespace - return [label.strip() for label in labels_str.split(",") if label.strip()] - - -def should_validate_runner(runner_labels: List[str]) -> bool: - """ - Check if the runner should have model validation based on its labels. - - Args: - runner_labels: List of runner labels - - Returns: - True if any label matches a configured label in RUNNER_LABEL_MODEL_MAP - """ - if not runner_labels: - return False - - # Check if any label is in the configured map - return any(label in RUNNER_LABEL_MODEL_MAP for label in runner_labels) - - -def get_required_models(runner_labels: List[str]) -> List[str]: - """ - Get list of models required based on runner labels. - - Args: - runner_labels: List of runner labels (e.g., ["self-hosted", "8-gpu-h200", "linux"]) - - Returns: - List of model identifiers to validate (deduplicated) - """ - all_models = [] - - for label in runner_labels: - if label in RUNNER_LABEL_MODEL_MAP: - models = RUNNER_LABEL_MODEL_MAP[label] - print( - f" ✓ Matched label configuration: '{label}' -> {len(models)} model(s)" - ) - all_models.extend(models) - - if not all_models: - print(f" ⚠ No configuration found for any label in: {runner_labels}") - - # Remove duplicates while preserving order - seen = set() - unique_models = [] - for model in all_models: - if model not in seen: - seen.add(model) - unique_models.append(model) - - return unique_models - - -def main() -> int: - """ - Main validation logic. - - Returns: - 0 if all models are valid, successfully downloaded, or runner doesn't need validation - 1 only if download attempts fail - """ - print("=" * 70) - print("Model Validation for CI Runners") - print("=" * 70) - - runner_labels = get_runner_labels() - print(f"Runner labels: {', '.join(runner_labels) if runner_labels else 'NOT SET'}") - - # Check if this runner needs validation - if not should_validate_runner(runner_labels): - print( - "Skipping validation: No runner labels match configured model requirements" - ) - return 0 - - print(f"Proceeding with model validation for this runner") - - # Get required models for these runner labels - required_models = get_required_models(runner_labels) - - if not required_models: - print(f"Warning: No models configured for labels: {runner_labels}") - return 0 - - print(f"Models to validate: {required_models}") - print("-" * 70) - - # Get cache directory - cache_dir = get_hf_cache_dir() - print(f"HuggingFace cache: {cache_dir}") - print("-" * 70) - - # Track validation results - # Maps model_id -> (error_msg, corrupted_files) - models_needing_download: Dict[str, Tuple[str, List[Path]]] = {} - - # Validate each required model - for model_id in required_models: - is_valid, error_msg, corrupted_files = validate_model(model_id, cache_dir) - - if not is_valid: - print(f" ✗ Validation failed: {error_msg}") - models_needing_download[model_id] = (error_msg, corrupted_files) - - print("-" * 70) - - # If all models are valid, exit successfully - if not models_needing_download: - print("✓ All models validated successfully!") - return 0 - - # Models need to be downloaded - print(f"⚠ Cache validation failed for {len(models_needing_download)} model(s)") - for model_id, (error_msg, _) in models_needing_download.items(): - print(f" - {model_id}: {error_msg}") - - print("-" * 70) - print("Attempting to download missing/corrupted models...") - print("-" * 70) - - download_failed = False - for model_id, (error_msg, corrupted_files) in models_needing_download.items(): - if not download_model(model_id, cache_dir, corrupted_files): - download_failed = True - - print("-" * 70) - - if download_failed: - print("✗ FAILED: Some models could not be downloaded") - return 1 - - # All downloads succeeded - now validate them again - print("✓ All models downloaded successfully!") - print("-" * 70) - print("Validating downloaded models...") - print("-" * 70) - - validation_failed = False - for model_id in models_needing_download.keys(): - is_valid, error_msg, _ = validate_model(model_id, cache_dir) - if not is_valid: - print(f" ✗ Post-download validation failed for {model_id}: {error_msg}") - validation_failed = True - - print("-" * 70) - - if validation_failed: - print("✗ FAILED: Some models failed validation after download") - return 1 - - # All validations passed - emit warning but exit successfully - print("✓ All downloaded models validated successfully!") - print("⚠ WARNING: Models were missing/corrupted in cache and have been repaired.") - print(f" Repaired models: {', '.join(models_needing_download.keys())}") - - # Emit GitHub Actions warning annotation for visibility - print( - f"::warning file=scripts/ci/validate_and_download_models.py::" - f"Cache validation failed for {len(models_needing_download)} model(s). " - f"Models were re-downloaded and validated successfully. " - f"This may indicate cache corruption or infrastructure issues." - ) - - return 0 - - -if __name__ == "__main__": - try: - exit_code = main() - sys.exit(exit_code) - except KeyboardInterrupt: - print("\nInterrupted by user") - sys.exit(1) - except Exception as e: - print(f"ERROR: Unexpected error: {e}") - import traceback - - traceback.print_exc() - sys.exit(1)