Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 1 addition & 67 deletions test/registered/quant/test_deepseek_v3_fp4_4gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
write_github_step_summary,
)

# NOTE(review): the two register_cuda_ci lines below differ only in est_time
# (1500 vs 1200). In this diff rendering they look like the before/after sides
# of a single changed line — confirm that only one registration is intended.
register_cuda_ci(est_time=1500, suite="stage-c-test-4-gpu-b200")
register_cuda_ci(est_time=1200, suite="stage-c-test-4-gpu-b200")

# Model identifier that the test classes hand to popen_launch_server.
FULL_DEEPSEEK_V3_FP4_MODEL_PATH = "nvidia/DeepSeek-V3-0324-FP4"
# Passed as the `timeout` argument of popen_launch_server
# (presumably seconds — TODO confirm against the helper).
SERVER_LAUNCH_TIMEOUT = 1200
Expand Down Expand Up @@ -72,72 +72,6 @@ def test_a_gsm8k(

self.assertGreater(metrics["accuracy"], 0.93)

def test_bs_1_speed(self):
    """Send a single prompt to the running server and check its speed.

    The port is the last ``:``-separated component of ``self.base_url``.
    The measured speed is printed, written to the GitHub step summary
    (labelled ``token/s``) when ``is_in_ci()`` is true, and must be
    greater than 75 for the test to pass.
    """
    args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048)
    # Only the second element of the result is used; the first is discarded
    # explicitly instead of being bound to an unused local.
    _, speed = send_one_prompt(args)

    # The local must stay named `speed`: the `=` specifier echoes the
    # expression text into the printed and summarized output.
    print(f"{speed=:.2f}")

    if is_in_ci():
        write_github_step_summary(
            f"### test_bs_1_speed (deepseek-v3-fp4)\n" f"{speed=:.2f} token/s\n"
        )

    self.assertGreater(speed, 75)


class TestDeepseekV3FP4PiecewiseCudaGraph(CustomTestCase):
@classmethod
def setUpClass(cls):
    """Launch the server process shared by the tests of this class.

    Records the model path and base URL on the class, then stores whatever
    ``popen_launch_server`` returns as ``cls.process`` so that
    ``tearDownClass`` can terminate it.
    """
    cls.model = FULL_DEEPSEEK_V3_FP4_MODEL_PATH
    cls.base_url = DEFAULT_URL_FOR_TEST

    # Each command-line flag is kept next to its value; the pairs are
    # flattened into the flat argument list handed to the launcher.
    flag_pairs = (
        ("--tp", "4"),
        ("--attention-backend", "trtllm_mla"),
        ("--moe-runner-backend", "flashinfer_trtllm"),
        ("--quantization", "modelopt_fp4"),
        ("--kv-cache-dtype", "fp8_e4m3"),
        (
            "--model-loader-extra-config",
            '{"enable_multithread_load": true,"num_threads": 64}',
        ),
    )
    launch_flags = [token for pair in flag_pairs for token in pair]

    cls.process = popen_launch_server(
        cls.model,
        cls.base_url,
        timeout=SERVER_LAUNCH_TIMEOUT,
        other_args=launch_flags,
    )

@classmethod
def tearDownClass(cls):
    """Terminate the process tree of the server started in ``setUpClass``."""
    server_pid = cls.process.pid
    kill_process_tree(server_pid)

def test_a_gsm8k(self):
    """Run the few-shot GSM8K evaluation against the server and check accuracy.

    The evaluation arguments target ``http://127.0.0.1`` on the port taken
    from ``self.base_url``. The returned metrics are printed, the accuracy
    is written to the GitHub step summary when ``is_in_ci()`` is true, and
    the accuracy must be greater than 0.93.
    """
    server_port = int(self.base_url.rsplit(":", 1)[-1])
    eval_args = SimpleNamespace(
        num_shots=8,
        data_path=None,
        num_questions=1319,
        parallel=1319,
        max_new_tokens=512,
        host="http://127.0.0.1",
        port=server_port,
    )

    # `metrics` keeps its name on purpose: the `=` f-string specifier below
    # echoes the expression text into the printed/summarized output.
    metrics = run_eval_few_shot_gsm8k(eval_args)
    print(f"{metrics=}")

    if is_in_ci():
        summary = f"### test_gsm8k (deepseek-v3-fp4)\n" f'{metrics["accuracy"]=:.3f}\n'
        write_github_step_summary(summary)

    accuracy = metrics["accuracy"]
    self.assertGreater(accuracy, 0.93)

def test_bs_1_speed(self):
args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048)
_, speed = send_one_prompt(args)
Expand Down
Loading