diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 1d08377ea0..712418d3e0 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -36,7 +36,8 @@ ("google/gemma-7b", 1, False, 109.70751574382221), ], "fp8": [ - ("tiiuae/falcon-180B", 52.85086442722326), + # ("tiiuae/falcon-180B", 52.85086442722326), + ("mistralai/Mistral-7B-Instruct-v0.2", 0), ], "deepspeed": [ ("bigscience/bloomz", 36.77314954096159), @@ -81,6 +82,16 @@ } +MISTRAL_FP8_CONFIG = { + "mistralai/Mistral-7B-Instruct-v0.2": [ + ("896", "128", "128", 13310.566520719813), + ("120", "128", "2048", 7757.383448024244), + ("120", "2048", "128", 1352.070452897798), + ("44", "2048", "2048", 3101.5205518843136), + ], +} + + def _test_text_generation( model_name: str, baseline: float, @@ -135,6 +146,12 @@ def _test_text_generation( "--trim_logits", ] + if "Mistral" in model_name: + command += [ + "--attn_softmax_bf16", + ] + command.remove("--max_new_tokens 100") + with TemporaryDirectory() as tmp_dir: command.append(f"--output_dir {tmp_dir}") print(f"\n\nCommand to test: {' '.join(command)}\n") @@ -154,6 +171,34 @@ def _test_text_generation( ) command.insert(-2, "--fp8") + if "Mistral" in model_name: + command.insert(-2, "--limit_hpu_graphs") + command.insert(-2, "--max_input_tokens 1") + command.insert(-2, "--max_new_tokens 1") + command = [x for y in command for x in re.split(pattern, y) if x] + for model_config in MISTRAL_FP8_CONFIG[model_name]: + command[command.index("--batch_size") + 1] = model_config[0] + command[command.index("--max_input_tokens") + 1] = model_config[1] + command[command.index("--max_new_tokens") + 1] = model_config[2] + baseline = model_config[3] + proc = subprocess.run(command, env=env_variables) + + # Ensure the run finished without any issue + # Use try-except to avoid logging the token if used + try: + assert proc.returncode == 0 + except AssertionError as e: + if "'--token', 'hf_" in e.args[0]: + e.args = (f"The following command failed:\n{' '.join(command[:-2])}",) + raise + + with open(Path(tmp_dir) / "results.json") as fp: + results = json.load(fp) + + # Ensure performance requirements (throughput) are met + assert results["throughput"] >= (2 - TIME_PERF_FACTOR) * baseline + return + proc = subprocess.run(command, env=env_variables) # Ensure the run finished without any issue