diff --git a/test/srt/test_llama31_fp4.py b/test/srt/test_llama31_fp4.py index 1be9671842a1..36ae3697114f 100644 --- a/test/srt/test_llama31_fp4.py +++ b/test/srt/test_llama31_fp4.py @@ -14,15 +14,13 @@ @unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher") -class TestLlama31FP4B200(unittest.TestCase): +class TestLlama31FP4(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = MODEL_PATH cls.base_url = DEFAULT_URL_FOR_TEST other_args = [ "--trust-remote-code", - "--mem-fraction-static", - "0.8", "--quantization", "modelopt_fp4", ] @@ -40,18 +38,18 @@ def tearDownClass(cls): def test_gsm8k(self): parsed_url = urlparse(self.base_url) args = SimpleNamespace( - num_shots=4, + num_shots=5, data_path=None, - num_questions=100, + num_questions=1319, max_new_tokens=512, - parallel=128, + parallel=200, host=f"{parsed_url.scheme}://{parsed_url.hostname}", port=parsed_url.port, ) metrics = run_eval_few_shot_gsm8k(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.61) + self.assertGreater(metrics["accuracy"], 0.54) if __name__ == "__main__":