diff --git a/.github/workflows/nightly-test-nvidia.yml b/.github/workflows/nightly-test-nvidia.yml index 37d364392483..db5ea27328e7 100644 --- a/.github/workflows/nightly-test-nvidia.yml +++ b/.github/workflows/nightly-test-nvidia.yml @@ -47,6 +47,9 @@ concurrency: group: nightly-test-nvidia-${{ github.ref }} cancel-in-progress: true +env: + SGLANG_IS_IN_CI: true + jobs: # General tests - 1 GPU nightly-test-general-1-gpu-runner: diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 64f1171105a3..efc2957b1125 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -25,6 +25,9 @@ concurrency: group: pr-test-${{ github.ref }} cancel-in-progress: true +env: + SGLANG_IS_IN_CI: true + jobs: # =============================================== check changes ==================================================== check-changes: diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 4b02500e8a89..0bbc1c02757f 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -47,7 +47,6 @@ _validate_sharded_model, ) from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once -from sglang.utils import is_in_ci logger = logging.getLogger(__name__) @@ -421,6 +420,55 @@ def find_local_hf_snapshot_dir( return None +def _validate_weights_after_download( + hf_folder: str, + allow_patterns: List[str], + model_name_or_path: str, +) -> None: + """Validate downloaded weight files to catch corruption early. + + This function validates safetensors files after download to catch + corruption issues (truncated downloads, network errors, etc.) before + model loading fails with cryptic errors. + + Args: + hf_folder: Path to the downloaded model folder + allow_patterns: Patterns used to match weight files + model_name_or_path: Model identifier for error messages + + Raises: + RuntimeError: If any weight files are corrupted + """ + import glob as glob_module + + # Find all weight files that were downloaded + weight_files: List[str] = [] + for pattern in allow_patterns: + weight_files.extend(glob_module.glob(os.path.join(hf_folder, pattern))) + + if not weight_files: + return # No weight files to validate + + # Validate safetensors files + corrupted_files = [] + for f in weight_files: + if f.endswith(".safetensors") and os.path.exists(f): + if not _validate_safetensors_file(f): + corrupted_files.append(os.path.basename(f)) + + if corrupted_files: + # Clean up corrupted files so next attempt re-downloads them + _cleanup_corrupted_files_selective( + model_name_or_path, + [os.path.join(hf_folder, f) for f in corrupted_files], + ) + raise RuntimeError( + f"Downloaded model files are corrupted for {model_name_or_path}: " + f"{corrupted_files}. The corrupted files have been removed. " + "Please retry to re-download the model." + ) + + def download_weights_from_hf( model_name_or_path: str, cache_dir: Optional[str], @@ -446,17 +494,19 @@ def download_weights_from_hf( str: The path to the downloaded model weights. """ - if is_in_ci(): - # If the weights are already local, skip downloading and returns the path. - # This is used to skip too-many Huggingface API calls in CI. - path = find_local_hf_snapshot_dir( - model_name_or_path, cache_dir, allow_patterns, revision - ) - if path is not None: - return path + # Always check for valid local cache first. + # This validates cached files and cleans up corrupted ones. + path = find_local_hf_snapshot_dir( + model_name_or_path, cache_dir, allow_patterns, revision + ) + if path is not None: + # Valid local cache found, skip download + return path + # In CI, skip HF API calls if we're in offline mode or want to avoid rate limits + # But we already checked for local cache above, so if we're here we need to download if not huggingface_hub.constants.HF_HUB_OFFLINE: - # Before we download we look at that is available: + # Before we download we look at what is available: fs = HfFileSystem() file_list = fs.ls(model_name_or_path, detail=False, revision=revision) @@ -480,6 +530,10 @@ def download_weights_from_hf( revision=revision, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, ) + + # Validate downloaded files to catch corruption early + _validate_weights_after_download(hf_folder, allow_patterns, model_name_or_path) + return hf_folder