Skip to content

Commit 034a68e

Browse files
fzyzcjyxwu-intel
authored andcommitted
Add simple utility to dump tensors for debugging (sgl-project#6815)
1 parent e581330 commit 034a68e

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

python/sglang/srt/debug_utils.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import os
2+
import time
3+
from pathlib import Path
4+
5+
import torch
6+
7+
from sglang.srt.utils import get_bool_env_var
8+
9+
10+
class _Dumper:
11+
"""Utility to dump tensors, which can be useful when comparison checking models.
12+
13+
Example usage:
14+
debug_utils.dumper.dump("layer_start_hidden_states", hidden_states, layer_id=self.layer_id)
15+
"""
16+
17+
def __init__(self):
18+
self._enable = get_bool_env_var("SGLANG_DUMPER_ENABLE", "true")
19+
self._base_dir = Path(os.environ.get("SGLANG_DUMPER_DIR", "/tmp"))
20+
self._enable_write_file = get_bool_env_var("SGLANG_DUMPER_WRITE_FILE", "1")
21+
self._partial_name = str(time.time())
22+
self.forward_pass_id = None
23+
24+
def dump(self, name, value, **kwargs):
25+
if not self._enable:
26+
return
27+
28+
from sglang.srt.distributed import get_tensor_model_parallel_rank
29+
30+
rank = get_tensor_model_parallel_rank()
31+
full_kwargs = dict(
32+
forward_pass_id=self.forward_pass_id,
33+
name=name,
34+
**kwargs,
35+
)
36+
full_filename = "___".join(f"{k}={v}" for k, v in full_kwargs.items()) + ".pt"
37+
path = (
38+
self._base_dir / f"sglang_dump_{self._partial_name}_{rank}" / full_filename
39+
)
40+
41+
sample_value = self._get_sample_value(name, value)
42+
43+
print(
44+
f"[{rank}, {time.time()}] {path} "
45+
f"type={type(value)} "
46+
f"shape={value.shape if isinstance(value, torch.Tensor) else None} "
47+
f"dtype={value.dtype if isinstance(value, torch.Tensor) else None} "
48+
f"sample_value={sample_value}"
49+
)
50+
51+
if self._enable_write_file:
52+
path.parent.mkdir(parents=True, exist_ok=True)
53+
torch.save(value, str(path))
54+
55+
def _get_sample_value(self, name, value):
56+
if value is None:
57+
return None
58+
59+
if isinstance(value, tuple):
60+
return [self._get_sample_value(name, x) for x in value]
61+
62+
if not isinstance(value, torch.Tensor):
63+
return None
64+
65+
if value.numel() < 200:
66+
return value
67+
68+
slices = [
69+
slice(0, 5) if dim_size > 200 else slice(None) for dim_size in value.shape
70+
]
71+
return value[tuple(slices)]
72+
73+
74+
dumper = _Dumper()

0 commit comments

Comments
 (0)