Commit f37e5eb

debug h100 disagg perf test
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent e5c4865

2 files changed: +102 −20 lines changed


tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 5 additions & 12 deletions
@@ -640,18 +640,18 @@ def __init__(
     def get_scores(logits, e_score_correction_bias):
         scores = F.sigmoid(logits)
         scores_with_bias = scores + e_score_correction_bias
+        return scores, scores_with_bias
+
+    def noaux_tc(self, logits, e_score_correction_bias):
+        n_group = self.n_group
+
         if enable_llm_debug():
             has_nan = torch.isnan(scores_with_bias).any()
             if has_nan:
                 warnings.warn(
                     "Detected NAN in the tensor scores_with_bias. Please check if it matches the expectation."
                 )
 
-        return scores, scores_with_bias
-
-    def noaux_tc(self, logits, e_score_correction_bias):
-        n_group = self.n_group
-
         _, num_experts = logits.shape
         if self.n_group > 1:
             if self.top_k > 8 or (num_experts / n_group) > 32 or (

@@ -672,13 +672,6 @@ def noaux_tc(self, logits, e_score_correction_bias):
         if not self.is_fused:
             scores, scores_with_bias = Deepseekv3RoutingImpl.get_scores(
                 logits, e_score_correction_bias)
-            if enable_llm_debug():
-                has_nan = torch.isnan(scores_with_bias).any()
-                if has_nan:
-                    warnings.warn(
-                        "Detected NAN in the tensor scores_with_bias. Please check if it matches the expectation."
-                    )
-
             scores_shape = list(scores_with_bias.shape)
             group_scores = torch.sum(torch.topk(
                 scores_with_bias.view(scores_shape[:-1] +
tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 97 additions & 8 deletions
@@ -25,6 +25,69 @@
 from .accuracy_core import (GSM8K, MMLU, JsonModeEval,
                             LlmapiAccuracyTestHarness, get_accuracy_task)
 
+MAX_PERF_METRICS_REQUESTS = 100
+
+
+def get_worker_env_vars(kv_cache_perf_dir: str = None):
+    env = os.environ.copy()
+    if kv_cache_perf_dir:
+        env["TRTLLM_KVCACHE_TIME_OUTPUT_PATH"] = kv_cache_perf_dir
+    return env
+
+
+def show_debug_perf(thread_pool: ThreadPoolExecutor,
+                    kv_cache_perf_dir: str = None,
+                    perf_metrics_url: str = None):
+
+    def wait_for_all_tasks_to_complete():
+        #thread_pool.shutdown(wait=True)
+        try:
+            print("Waiting for all tasks to complete")
+            for future in getattr(thread_pool, "futures", []):
+                try:
+                    future.result(timeout=300)
+                except concurrent.futures.TimeoutError:
+                    print("Timeout waiting for a future to complete.")
+                except Exception as e:
+                    print(f"Future completed with error: {e}")
+        except Exception as e:
+            print(f"Error while waiting for futures: {e}")
+
+    def show_kvcache_time(kv_cache_perf_dir, max_lines=100):
+        for file in os.listdir(kv_cache_perf_dir):
+            print(f"{'-'*25} {file}:{max_lines} {'-'*25}")
+            with open(os.path.join(kv_cache_perf_dir, file), "r") as f:
+                for line in f.readlines()[-max_lines:]:
+                    print(line.strip())
+
+    def show_perf_metrics(url):
+        perf_url = f"{url}/perf_metrics"
+        try:
+            print(f"Fetching perf metrics from {perf_url}")
+            resp = requests.get(perf_url, timeout=10)
+            if resp.status_code == 200:
+                try:
+                    print("perf_metrics JSON:")
+                    metrics = resp.json()
+                    print(json.dumps(metrics, indent=2, ensure_ascii=False))
+                    print("-" * 100)
+                except ValueError:
+                    print("perf_metrics returned non-JSON response:", resp.text)
+            else:
+                print(
+                    f"perf_metrics returned status {resp.status_code}: {resp.text}"
+                )
+        except requests.exceptions.RequestException as e:
+            print(f"Error fetching {perf_url}: {e}")
+
+    wait_for_all_tasks_to_complete()
+    if kv_cache_perf_dir:
+        show_kvcache_time(kv_cache_perf_dir)
+    if perf_metrics_url:
+        show_perf_metrics(perf_metrics_url)
+    # force failure to see the logs
+    assert False
+
 
 class Result(GenerationResultBase):
 
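Two things stand out in this helper. First, it ends with `assert False` ("force failure to see the logs"), so any `debug_perf` run fails by design and the captured output is preserved in CI. Second, `wait_for_all_tasks_to_complete` iterates `getattr(thread_pool, "futures", [])`, but `concurrent.futures.ThreadPoolExecutor` has no public `futures` attribute; unless the harness attaches that list itself at submission time, the loop is a silent no-op. A minimal sketch of that assumed convention (names are illustrative):

import time
from concurrent.futures import ThreadPoolExecutor

thread_pool = ThreadPoolExecutor(max_workers=16)
thread_pool.futures = []  # assumed convention, not a stdlib API

def job():
    time.sleep(1)  # stand-in for one generation request

# Track every submission so show_debug_perf can drain them later.
thread_pool.futures.append(thread_pool.submit(job))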
@@ -76,15 +139,29 @@ def launch_disaggregated_llm(
         ctx_model: str = None,
         gen_model: str = None,
         server_waiting_timeout: int = DEFAULT_SERVER_WAITING_TIMEOUT,
-        max_workers: int = 16):
+        max_workers: int = 16,
+        debug_perf: bool = False):
     temp_dir = tempfile.TemporaryDirectory()
     disaggregated_serving_config_path = os.path.join(
         temp_dir.name, "disaggregated_serving_config.yaml")
-
+    if debug_perf:
+        kv_cache_perf_dir = os.path.join(temp_dir.name, "kvcache_perf")
+        os.makedirs(kv_cache_perf_dir, exist_ok=True)
+    else:
+        kv_cache_perf_dir = None
     if tensor_parallel_size > 1:
         print(
             f"Using unified tp parameter for testing is not recommended. Please use server configs instead."
         )
+    if debug_perf:
+        disaggregated_server_config[
+            "perf_metrics_max_requests"] = MAX_PERF_METRICS_REQUESTS
+        ctx_server_config["return_perf_metrics"] = True
+        ctx_server_config[
+            "perf_metrics_max_requests"] = MAX_PERF_METRICS_REQUESTS
+        gen_server_config["return_perf_metrics"] = True
+        gen_server_config[
+            "perf_metrics_max_requests"] = MAX_PERF_METRICS_REQUESTS
 
     with open(disaggregated_serving_config_path, "w") as f:
         yaml.dump(disaggregated_server_config, f)
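With `debug_perf=True`, all three configs gain perf-metrics knobs before being dumped to YAML: metrics are returned per request and the server keeps up to `MAX_PERF_METRICS_REQUESTS` of them. For illustration, a config that starts empty would serialize as below (only the two added keys come from this diff; a real config carries the usual server options alongside them):

import yaml

MAX_PERF_METRICS_REQUESTS = 100

# Illustrative only: starting from an empty dict.
ctx_server_config = {}
ctx_server_config["return_perf_metrics"] = True
ctx_server_config["perf_metrics_max_requests"] = MAX_PERF_METRICS_REQUESTS

print(yaml.dump(ctx_server_config))
# perf_metrics_max_requests: 100
# return_perf_metrics: true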
@@ -144,7 +221,7 @@ def launch_disaggregated_llm(
     current_gpu_offset = 0
 
     for i, port in enumerate(ctx_ports):
-        env_ctx = os.environ.copy()
+        env_ctx = get_worker_env_vars(kv_cache_perf_dir=kv_cache_perf_dir)
         env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
         gpu_range = range(current_gpu_offset,
                           current_gpu_offset + ctx_total_gpus)
@@ -165,7 +242,7 @@ def launch_disaggregated_llm(
     gen_servers = []
 
     for i, port in enumerate(gen_ports):
-        env_gen = os.environ.copy()
+        env_gen = get_worker_env_vars(kv_cache_perf_dir=kv_cache_perf_dir)
         env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
         gpu_range = range(current_gpu_offset,
                           current_gpu_offset + gen_total_gpus)
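Both worker launch paths now build their environment through `get_worker_env_vars`, so every ctx and gen server sees `TRTLLM_KVCACHE_TIME_OUTPUT_PATH` when a perf directory exists and writes its KV-cache timing file into the shared directory that `show_kvcache_time` later tails. A self-contained sketch of the resulting environment (the /tmp path is illustrative; how the env dict reaches the worker subprocess is assumed from the surrounding launch loops):

import os

def get_worker_env_vars(kv_cache_perf_dir: str = None):
    # Same helper the diff adds: copy the parent environment and point
    # KV-cache timing output at the shared per-test directory.
    env = os.environ.copy()
    if kv_cache_perf_dir:
        env["TRTLLM_KVCACHE_TIME_OUTPUT_PATH"] = kv_cache_perf_dir
    return env

env_ctx = get_worker_env_vars(kv_cache_perf_dir="/tmp/kvcache_perf")
env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
# env_ctx is then passed as the subprocess environment for a ctx worker.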
@@ -289,6 +366,14 @@ def generate_async(prompt: str,
     tokenizer = load_hf_tokenizer(model_name)
     yield DuckLLM(args, tokenizer, generate_async)
 
+    if debug_perf:
+        show_debug_perf(
+            thread_pool,
+            kv_cache_perf_dir=kv_cache_perf_dir,
+            perf_metrics_url=f"http://localhost:8000"
+            if debug_perf else None,
+        )
+
 
 def run_parallel_test(model_name: str,
                       model_path: str,
@@ -357,7 +442,7 @@ def run_parallel_test(model_name: str,
         task.evaluate(llm)
 
 
-@pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
+@pytest.mark.timeout(DEFAULT_TEST_TIMEOUT * 5)
 class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
     MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"
@@ -510,9 +595,13 @@ def test_eagle3(self, overlap_scheduler, eagle3_one_model):
                 "urls": ["localhost:8002"]
             }
         }
-        with launch_disaggregated_llm(disaggregated_server_config,
-                                      ctx_server_config, gen_server_config,
-                                      self.MODEL_PATH) as llm:
+        with launch_disaggregated_llm(
+                disaggregated_server_config,
+                ctx_server_config,
+                gen_server_config,
+                self.MODEL_PATH,
+                debug_perf=True,
+        ) as llm:
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)

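`test_eagle3` is the only call site that sets `debug_perf=True`, and the class timeout is multiplied by five to leave room for the extra collection and the deliberate failure at the end. Two small oddities in the hunks above: the `perf_metrics_url=... if debug_perf else None` conditional is redundant, since the call already sits inside `if debug_perf:`, and the `f"http://localhost:8000"` f-string has no placeholders. Once the disagg server is up on that port, the same data `show_perf_metrics` prints can be fetched by hand; a minimal sketch, assuming the endpoint and port the test itself uses:

import json

import requests

# /perf_metrics and port 8000 are taken from the diff above.
resp = requests.get("http://localhost:8000/perf_metrics", timeout=10)
resp.raise_for_status()
print(json.dumps(resp.json(), indent=2, ensure_ascii=False))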