Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 61 additions & 40 deletions python/sglang/srt/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import json
import logging
import os
import re
import tempfile
from collections import defaultdict
from typing import (
Expand Down Expand Up @@ -41,6 +40,12 @@
ModelOptFp4Config,
ModelOptFp8Config,
)
from sglang.srt.model_loader.weight_validation import (
_cleanup_corrupted_files_selective,
_cleanup_corrupted_model_cache,
_validate_safetensors_file,
_validate_sharded_model,
)
from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once
from sglang.utils import is_in_ci

Expand Down Expand Up @@ -304,21 +309,31 @@ def find_local_hf_snapshot_dir(
except Exception as e:
logger.warning("Failed to find local snapshot in default HF cache: %s", e)

# if any incomplete file exists, force re-download by returning None
# Check for incomplete files and clean up if found
if found_local_snapshot_dir:
repo_folder = os.path.abspath(
os.path.join(found_local_snapshot_dir, "..", "..")
)
blobs_dir = os.path.join(repo_folder, "blobs")
if os.path.isdir(blobs_dir) and glob.glob(
os.path.join(blobs_dir, "*.incomplete")
):

# Check for incomplete download markers
incomplete_files = []
if os.path.isdir(blobs_dir):
incomplete_files = glob.glob(os.path.join(blobs_dir, "*.incomplete"))

if incomplete_files:
logger.info(
"Found .incomplete files in %s for %s. "
"Considering local snapshot incomplete.",
"Found %d .incomplete files in %s for %s. "
"Will clean up and re-download.",
len(incomplete_files),
blobs_dir,
model_name_or_path,
)
_cleanup_corrupted_model_cache(
model_name_or_path,
found_local_snapshot_dir,
f"Incomplete download detected ({len(incomplete_files)} incomplete files)",
)
return None

# if local snapshot exists, validate it contains at least one weight file
Expand All @@ -344,45 +359,51 @@ def find_local_hf_snapshot_dir(
)
local_weight_files = []

# After we have a list of valid files, check for sharded model completeness.
# Check if all safetensors with name model-{i}-of-{n}.safetensors exists
checked_sharded_model = False
for f in local_weight_files:
if checked_sharded_model:
break
base_name = os.path.basename(f)
# Regex for files like model-00001-of-00009.safetensors
match = re.match(r"(.*?)-([0-9]+)-of-([0-9]+)\.(.*)", base_name)
if match:
prefix = match.group(1)
shard_id_str = match.group(2)
total_shards_str = match.group(3)
suffix = match.group(4)
total_shards = int(total_shards_str)

# Check if all shards are present
missing_shards = []
for i in range(1, total_shards + 1):
# Reconstruct shard name, preserving padding of original shard id
shard_name = (
f"{prefix}-{i:0{len(shard_id_str)}d}-of-{total_shards_str}.{suffix}"
# Validate sharded models and check for corruption
if local_weight_files:
is_valid, error_msg, corrupted_files = _validate_sharded_model(
found_local_snapshot_dir, local_weight_files
)
if not is_valid:
if corrupted_files:
# Selective cleanup: only remove corrupted files
logger.info(
"Found %d corrupted file(s) for %s: %s. "
"Will selectively clean and re-download only these files.",
len(corrupted_files),
model_name_or_path,
error_msg,
)
expected_path = os.path.join(found_local_snapshot_dir, shard_name)
# os.path.exists returns False for broken symlinks, which is desired.
if not os.path.exists(expected_path):
missing_shards.append(shard_name)

if missing_shards:
_cleanup_corrupted_files_selective(model_name_or_path, corrupted_files)
return None
else:
# Cannot selectively clean (e.g., missing shards) - remove entire cache
logger.info(
"Found incomplete sharded model %s. Missing shards: %s. "
"Will attempt download.",
"Validation failed for %s: %s. "
"Will remove entire cache and re-download.",
model_name_or_path,
missing_shards,
error_msg,
)
_cleanup_corrupted_model_cache(
model_name_or_path, found_local_snapshot_dir, error_msg
)
return None

# If we found and verified one set of shards, we are done.
checked_sharded_model = True
# Also validate single (non-sharded) safetensors files
for f in local_weight_files:
base_name = os.path.basename(f)
# Check if this is a single model file (not sharded)
if base_name in ["model.safetensors", "pytorch_model.safetensors"]:
if not _validate_safetensors_file(f):
logger.info(
"Corrupted model file %s for %s. "
"Will selectively clean and re-download this file.",
base_name,
model_name_or_path,
)
# Selective cleanup for single file
_cleanup_corrupted_files_selective(model_name_or_path, [f])
return None
Comment on lines +363 to +406
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Following the refactoring of _validate_sharded_model into _validate_model_weights as suggested in my other comment, this block can be significantly simplified. The separate validation logic for sharded and single-file models is no longer needed here, as the new function handles both cases. This change reduces code duplication and improves readability.

    if local_weight_files:
        is_valid, error_msg = _validate_model_weights(
            found_local_snapshot_dir, local_weight_files
        )
        if not is_valid:
            logger.info(
                "Validation failed for %s: %s. Will clean up and re-download.",
                model_name_or_path,
                error_msg,
            )
            _cleanup_corrupted_model_cache(
                model_name_or_path, found_local_snapshot_dir, error_msg
            )
            return None


if len(local_weight_files) > 0:
logger.info(
Expand Down
220 changes: 220 additions & 0 deletions python/sglang/srt/model_loader/weight_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import logging
import os
import re
import shutil
from typing import List, Optional, Tuple

import safetensors

logger = logging.getLogger(__name__)


def _validate_safetensors_file(file_path: str) -> bool:
    """
    Check whether a safetensors file is readable and not corrupted.

    Args:
        file_path: Path to the safetensors file on disk

    Returns:
        True when the file's header parses cleanly, False when the file is
        corrupted or truncated
    """
    is_valid = True
    try:
        # Opening the file parses the header; enumerating the tensor names
        # forces the header to be fully decoded, so a truncated or corrupted
        # file raises here.
        with safetensors.safe_open(file_path, framework="pt", device="cpu") as handle:
            for _tensor_name in handle.keys():
                pass
    except Exception as exc:
        logger.warning(
            "Corrupted safetensors file detected: %s - %s: %s",
            file_path,
            type(exc).__name__,
            str(exc),
        )
        is_valid = False
    return is_valid


def _validate_sharded_model(
snapshot_dir: str, weight_files: List[str]
) -> Tuple[bool, Optional[str], List[str]]:
"""
Validate that all model shards are present and not corrupted.

Args:
snapshot_dir: Path to the model snapshot directory
weight_files: List of weight file paths

Returns:
Tuple of (is_valid, error_message, corrupted_files)
- corrupted_files: List of file paths that are corrupted (for selective cleanup)
"""
# Pattern for sharded files: model-00001-of-00009.safetensors
shard_pattern = re.compile(r"(.*?)-(\d+)-of-(\d+)\.(safetensors|bin)")

# Group files by shard pattern (prefix-*-of-N)
shard_groups = {}
for f in weight_files:
base_name = os.path.basename(f)
match = shard_pattern.match(base_name)
if match:
prefix = match.group(1)
total_shards_str = match.group(3)
suffix = match.group(4)

group_key = f"{prefix}-of-{total_shards_str}.{suffix}"
if group_key not in shard_groups:
shard_groups[group_key] = {
"prefix": prefix,
"total": int(total_shards_str),
"suffix": suffix,
"found_shards": [],
"files": [],
}

shard_id = int(match.group(2))
shard_groups[group_key]["found_shards"].append(shard_id)
shard_groups[group_key]["files"].append(f)

# Track corrupted files for selective cleanup
corrupted_files = []

# Validate each shard group
for group_key, group_info in shard_groups.items():
total_shards = group_info["total"]
found_shards = set(group_info["found_shards"])
expected_shards = set(range(1, total_shards + 1))

# Check for missing shards
missing_shards = expected_shards - found_shards
if missing_shards:
return (
False,
f"Missing shards in {group_key}: {sorted(missing_shards)}",
[],
)

# Validate safetensors files for corruption
if group_info["suffix"] == "safetensors":
for f in group_info["files"]:
if not _validate_safetensors_file(f):
corrupted_files.append(f)

# Check for required index file for safetensors shards
if group_info["suffix"] == "safetensors":
index_file = os.path.join(
snapshot_dir, f"{group_info['prefix']}.safetensors.index.json"
)
if not os.path.exists(index_file):
return (
False,
f"Missing index file: {os.path.basename(index_file)}",
[],
)

if corrupted_files:
return (
False,
f"Corrupted shard files: {[os.path.basename(f) for f in corrupted_files]}",
corrupted_files,
)

return True, None, []


def _cleanup_corrupted_files_selective(
model_name_or_path: str, corrupted_files: List[str]
) -> int:
"""
Selectively remove corrupted files and their blobs to force re-download.

This is more efficient than removing the entire model cache as it only
re-downloads corrupted files rather than the entire model.

Args:
model_name_or_path: Model identifier
corrupted_files: List of corrupted file paths (symlinks in snapshot)

Returns:
Number of files successfully cleaned up
"""
cleaned_count = 0

for file_path in corrupted_files:
try:
# Resolve symlink to get blob path before deleting symlink
if os.path.islink(file_path):
blob_path = os.path.realpath(file_path)

# Delete the symlink
os.remove(file_path)
logger.info(
"Removed corrupted symlink: %s", os.path.basename(file_path)
)

# Delete the blob (the actual corrupted data)
if os.path.exists(blob_path):
os.remove(blob_path)
logger.info(
"Removed corrupted blob: %s", os.path.basename(blob_path)
)

cleaned_count += 1
elif os.path.exists(file_path):
# Not a symlink, just delete the file
os.remove(file_path)
logger.info("Removed corrupted file: %s", os.path.basename(file_path))
cleaned_count += 1

except Exception as e:
logger.error(
"Failed to remove corrupted file %s: %s",
os.path.basename(file_path),
e,
)

if cleaned_count > 0:
logger.warning(
"Removed %d corrupted file(s) for %s. "
"These will be re-downloaded on next load.",
cleaned_count,
model_name_or_path,
)

return cleaned_count


def _cleanup_corrupted_model_cache(
    model_name_or_path: str, snapshot_dir: str, reason: str
) -> None:
    """
    Remove entire corrupted model cache directory to force a clean re-download.

    This is used when we cannot selectively clean (e.g., missing shards, incomplete
    downloads with unknown affected files).

    Args:
        model_name_or_path: Model identifier
        snapshot_dir: Path to the snapshot directory
        reason: Reason for cleanup
    """
    # HF cache layout is <model_root>/snapshots/<hash>, so the repo root
    # sits two levels above the snapshot directory.
    repo_folder = os.path.abspath(os.path.join(snapshot_dir, os.pardir, os.pardir))

    try:
        logger.warning(
            "Removing entire cache for %s at %s. Reason: %s",
            model_name_or_path,
            repo_folder,
            reason,
        )
        shutil.rmtree(repo_folder)
    except Exception as exc:
        # Deletion failure is logged but not raised; the caller will simply
        # attempt a re-download against the stale cache.
        logger.error(
            "Failed to remove corrupted cache directory %s: %s. "
            "Manual cleanup may be required.",
            repo_folder,
            exc,
        )
    else:
        logger.info("Successfully removed corrupted cache directory")
15 changes: 0 additions & 15 deletions scripts/ci/prepare_runner.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,4 @@ echo ""
python3 "${SCRIPT_DIR}/cleanup_hf_cache.py"
Copy link
Copy Markdown
Contributor

@merrymercy merrymercy Nov 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we still need this? Can we remove the whole script?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's safer to keep this at CI start. cleanup_hf_cache.py cleans stale artifacts across all models in the cache at CI start. weight_utils.py only validates the specific model being loaded. cleanup_hf_cache.py removes .tmp and .lock files that weight_utils.py doesn't handle.

echo ""

# Validate model integrity for configured runners
echo "Validating model integrity..."

# Enable accelerated HuggingFace downloads (10x faster on high-bandwidth networks)
export HF_HUB_ENABLE_HF_TRANSFER=1

python3 "${SCRIPT_DIR}/validate_and_download_models.py"
VALIDATION_EXIT_CODE=$?

if [ $VALIDATION_EXIT_CODE -ne 0 ]; then
echo "Model validation failed with exit code: $VALIDATION_EXIT_CODE"
exit $VALIDATION_EXIT_CODE
fi

echo ""
echo "CI runner preparation complete!"
Loading
Loading