Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions .github/workflows/nightly-test-nvidia.yml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,25 @@ jobs:
cd test
python3 run_suite.py --hw cuda --suite nightly-8-gpu-h200 --nightly --continue-on-error

- name: Run MiniMax-M2 nightly performance test
timeout-minutes: 180
env:
TRACE_BASE_URL: https://raw.githubusercontent.com/sglang-bot/sglang-ci-data/main/traces/${{ github.run_id }}
PERFETTO_RELAY_URL: ${{ vars.PERFETTO_RELAY_URL }}
GPU_CONFIG: "8-gpu-h200"
run: |
rm -rf test/performance_profiles_minimax_m2/
cd test
python3 nightly/test_minimax_m2_perf.py

- name: Publish MiniMax-M2 traces to storage repo
env:
GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI_DATA }}
GITHUB_RUN_ID: ${{ github.run_id }}
GITHUB_RUN_NUMBER: ${{ github.run_number }}
run: |
python3 scripts/ci/publish_traces.py --traces-dir test/performance_profiles_minimax_m2

- name: Run Qwen3-235B nightly performance test
timeout-minutes: 180
env:
Expand Down Expand Up @@ -172,25 +191,6 @@ jobs:
run: |
python3 scripts/ci/publish_traces.py --traces-dir test/performance_profiles_glm_4_6

- name: Run MiniMax-M2 nightly performance test
timeout-minutes: 180
env:
TRACE_BASE_URL: https://raw.githubusercontent.com/sglang-bot/sglang-ci-data/main/traces/${{ github.run_id }}
PERFETTO_RELAY_URL: ${{ vars.PERFETTO_RELAY_URL }}
GPU_CONFIG: "8-gpu-h200"
run: |
rm -rf test/performance_profiles_minimax_m2/
cd test
python3 nightly/test_minimax_m2_perf.py

- name: Publish MiniMax-M2 traces to storage repo
env:
GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI_DATA }}
GITHUB_RUN_ID: ${{ github.run_id }}
GITHUB_RUN_NUMBER: ${{ github.run_number }}
run: |
python3 scripts/ci/publish_traces.py --traces-dir test/performance_profiles_minimax_m2

# General tests - 8 GPU H20
nightly-test-general-8-gpu-h20:
if: github.repository == 'sgl-project/sglang' && (inputs.job_filter == '' || inputs.job_filter == 'all' || inputs.job_filter == 'nightly-test-general-8-gpu-h20')
Expand Down
13 changes: 8 additions & 5 deletions python/sglang/srt/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,14 +393,17 @@ def _find_local_hf_snapshot_dir_unlocked(
)
return None
else:
# Cannot selectively clean (e.g., missing shards) - remove entire cache
# Missing shards (not corruption) - let snapshot_download handle it.
# IMPORTANT: Do NOT delete the entire cache here, as other processes
# (TP/EP ranks) may already be loading weights from these files.
# Deleting the cache while other processes are using it causes
# FileNotFoundError race conditions. Instead, just return None
# to trigger a download - snapshot_download will only fetch
# missing files without disturbing existing ones.
log_info_on_rank0(
logger,
f"Validation failed for {model_name_or_path}: {error_msg}. "
"Will remove entire cache and re-download.",
)
_cleanup_corrupted_model_cache(
model_name_or_path, found_local_snapshot_dir, error_msg
"Will attempt to download missing files.",
)
return None

Expand Down
27 changes: 27 additions & 0 deletions python/sglang/srt/model_loader/weight_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,26 @@ def _check_index_files_exist(snapshot_dir: str) -> Tuple[bool, Optional[str]]:

for index_file in index_files:
index_path = os.path.join(snapshot_dir, index_file)

# Check if index file is a broken symlink (exists in listing but blob missing)
if os.path.islink(index_path) and not os.path.exists(index_path):
# Broken symlink - clean it up so download can proceed
try:
blob_path = os.path.realpath(index_path)
os.remove(index_path)
logger.warning(
"Removed broken index symlink: %s (blob missing)", index_file
)
# Also try to remove dangling blob reference if it somehow exists
if os.path.exists(blob_path):
os.remove(blob_path)
except Exception as e:
logger.error("Failed to remove broken symlink %s: %s", index_file, e)
return (
False,
f"Broken index file symlink: {index_file} (cleaned up, will re-download)",
)

try:
with open(index_path) as f:
index_data = json.load(f)
Expand All @@ -85,6 +105,13 @@ def _check_index_files_exist(snapshot_dir: str) -> Tuple[bool, Optional[str]]:
f"Missing {len(missing_files)} file(s) from index {index_file}: {missing_files[:3]}{'...' if len(missing_files) > 3 else ''}",
)

except FileNotFoundError as e:
# Index file was listed but can't be read - could be race condition or broken state
logger.warning("Failed to read index file %s: %s", index_file, e)
return (
False,
f"Index file {index_file} unreadable (will re-download)",
)
except Exception as e:
logger.warning("Failed to read index file %s: %s", index_file, e)
continue
Expand Down
186 changes: 186 additions & 0 deletions test/manual/test_weight_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
"""
Comment thread
alisonshao marked this conversation as resolved.
Unit tests for weight validation and cache cleanup logic.

Tests the fix for issue #14754 - ensuring that missing shards do not trigger
entire cache deletion, which can cause race conditions in multi-process scenarios.
"""

import json
import os
import struct
import tempfile
import unittest

from sglang.srt.model_loader.weight_validation import (
_check_index_files_exist,
_validate_sharded_model,
)


class TestWeightValidation(unittest.TestCase):
    """Tests for weight validation functions.

    Every test builds a throwaway snapshot directory that mimics a checkpoint
    split into three safetensors shards plus its index file, then checks how
    the validation helpers classify it.
    """

    # File that maps tensor names to the shard that stores them.
    INDEX_NAME = "model.safetensors.index.json"
    # All fixtures describe a checkpoint split into this many shards.
    NUM_SHARDS = 3
    # Minimal JSON header for a well-formed safetensors file; on disk it is
    # preceded by an 8-byte little-endian header size.
    MINIMAL_HEADER = b'{"__metadata__":{}}'

    @classmethod
    def _shard_name(cls, shard_no):
        """Return the file name of shard ``shard_no`` (1-based)."""
        return f"model-{shard_no:05d}-of-{cls.NUM_SHARDS:05d}.safetensors"

    @classmethod
    def _shard_path(cls, snapshot_dir, shard_no):
        """Return the full path of shard ``shard_no`` inside ``snapshot_dir``."""
        return os.path.join(snapshot_dir, cls._shard_name(shard_no))

    @classmethod
    def _write_valid_shard(cls, snapshot_dir, shard_no):
        """Write a minimal valid safetensors shard and return its path."""
        shard_path = cls._shard_path(snapshot_dir, shard_no)
        with open(shard_path, "wb") as f:
            # Header: 8 bytes for header size + JSON header.
            f.write(struct.pack("<Q", len(cls.MINIMAL_HEADER)))
            f.write(cls.MINIMAL_HEADER)
        return shard_path

    @classmethod
    def _write_index(cls, snapshot_dir):
        """Write an index file that references all ``NUM_SHARDS`` shards."""
        index_data = {
            "weight_map": {
                f"layer{shard_no}": cls._shard_name(shard_no)
                for shard_no in range(1, cls.NUM_SHARDS + 1)
            }
        }
        index_path = os.path.join(snapshot_dir, cls.INDEX_NAME)
        with open(index_path, "w", encoding="utf-8") as f:
            json.dump(index_data, f)

    def test_validate_sharded_model_missing_shard(self):
        """
        Test that missing shards are detected correctly.

        This is the core test for issue #14754 fix: when a shard is missing,
        the validation should return is_valid=False with an error message
        containing "Missing", but corrupted_files should be empty (indicating
        this is a missing shard issue, not a corruption issue).

        This distinction is critical because:
        - Missing shards: should NOT delete cache (other processes may be using it)
        - Corrupted files: should delete only the corrupted files selectively
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create partial shards as empty placeholder files (shard 3 is
            # deliberately absent).
            weight_files = []
            for shard_no in (1, 2):
                shard_path = self._shard_path(tmpdir, shard_no)
                with open(shard_path, "w", encoding="utf-8"):
                    pass
                weight_files.append(shard_path)

            # The index still references all three shards.
            self._write_index(tmpdir)

            is_valid, error_msg, corrupted_files = _validate_sharded_model(
                tmpdir, weight_files
            )

            self.assertFalse(is_valid)
            self.assertIn("Missing", error_msg)
            # CRITICAL: corrupted_files should be empty for missing shards
            # This is what prevents entire cache deletion
            self.assertEqual(corrupted_files, [])

    def test_validate_sharded_model_all_present(self):
        """Test that complete shards pass validation."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create all shards with a valid safetensors header.
            weight_files = [
                self._write_valid_shard(tmpdir, shard_no)
                for shard_no in range(1, self.NUM_SHARDS + 1)
            ]
            self._write_index(tmpdir)

            is_valid, error_msg, corrupted_files = _validate_sharded_model(
                tmpdir, weight_files
            )

            self.assertTrue(is_valid)
            self.assertIsNone(error_msg)
            self.assertEqual(corrupted_files, [])

    def test_validate_sharded_model_corrupted_shard(self):
        """
        Test that corrupted shards are detected and returned in corrupted_files.

        This tests the other branch: when a file exists but is corrupted
        (invalid safetensors format), it should be added to corrupted_files
        so that selective cleanup can remove just that file.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            # Shard 1 is valid.
            filepath1 = self._write_valid_shard(tmpdir, 1)

            # Shard 2 is corrupted (invalid header).
            filepath2 = self._shard_path(tmpdir, 2)
            with open(filepath2, "wb") as f:
                f.write(b"invalid data that is not a valid safetensors file")

            # Shard 3 is valid.
            filepath3 = self._write_valid_shard(tmpdir, 3)

            self._write_index(tmpdir)

            weight_files = [filepath1, filepath2, filepath3]

            is_valid, error_msg, corrupted_files = _validate_sharded_model(
                tmpdir, weight_files
            )

            self.assertFalse(is_valid)
            self.assertIn("Corrupt", error_msg)
            # The corrupted file should be identified
            self.assertEqual(len(corrupted_files), 1)
            self.assertIn(self._shard_name(2), corrupted_files[0])

    def test_broken_index_symlink_detected(self):
        """
        Test that broken index symlinks are detected and cause validation to fail.

        When an index file is a symlink pointing to a non-existent blob,
        validation should fail (to trigger re-download) rather than silently
        continuing and causing timeout during actual loading.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create a broken symlink for the index file
            index_path = os.path.join(tmpdir, self.INDEX_NAME)
            non_existent_blob = os.path.join(tmpdir, "blobs", "nonexistent_hash")
            os.symlink(non_existent_blob, index_path)

            # Verify it's a broken symlink
            self.assertTrue(os.path.islink(index_path))
            self.assertFalse(os.path.exists(index_path))

            # Check should fail for broken symlink
            is_valid, error_msg = _check_index_files_exist(tmpdir)

            self.assertFalse(is_valid)
            self.assertIn("Broken", error_msg)
            # The broken symlink should have been cleaned up
            self.assertFalse(os.path.exists(index_path))
            self.assertFalse(os.path.islink(index_path))


# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Loading