diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index c4ea4b675649..49316eb4f607 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -329,6 +329,8 @@ steps: - python3 offline_inference/basic/classify.py - python3 offline_inference/basic/embed.py - python3 offline_inference/basic/score.py + - python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 + - python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 - label: Platform Tests (CUDA) # 4min timeout_in_minutes: 15 diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 004e75b20464..ce078bce0b75 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -49,6 +49,7 @@ def get_custom_mm_prompts(num_prompts): def parse_args(): parser = FlexibleArgumentParser() add_dataset_parser(parser) + parser.add_argument("--test", action="store_true") parser.add_argument( "--method", type=str, @@ -60,6 +61,7 @@ def parse_args(): parser.add_argument("--tp", type=int, default=1) parser.add_argument("--enforce-eager", action="store_true") parser.add_argument("--enable-chunked-prefill", action="store_true") + parser.add_argument("--max-model-len", type=int, default=16384) parser.add_argument("--temp", type=float, default=0) parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-k", type=int, default=-1) @@ -71,8 +73,7 @@ def parse_args(): return parser.parse_args() -def main(): - args = parse_args() +def main(args): args.endpoint_type = "openai-chat" model_dir = args.model_dir @@ -134,7 +135,7 @@ def main(): gpu_memory_utilization=0.8, speculative_config=speculative_config, disable_log_stats=False, - max_model_len=16384, + max_model_len=args.max_model_len, limit_mm_per_prompt={"image": 5}, disable_chunked_mm_input=True, ) @@ -198,6 +199,39 @@ def main(): acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 print(f"acceptance at token {i}: {acceptance_rate:.2f}") + return acceptance_length + if __name__ == "__main__": - main() + args = parse_args() + acceptance_length = main(args) + + if args.test: + # takes ~30s to run on 1xH100 + assert args.method in ["eagle", "eagle3"] + assert args.tp == 1 + assert args.num_spec_tokens == 3 + assert args.dataset_name == "hf" + assert args.dataset_path == "philschmid/mt-bench" + assert args.num_prompts == 80 + assert args.temp == 0 + assert args.top_p == 1.0 + assert args.top_k == -1 + assert args.enable_chunked_prefill + + # check acceptance length is within 2% of expected value + rtol = 0.02 + expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811 + + assert ( + acceptance_length <= (1 + rtol) * expected_acceptance_length + and acceptance_length >= (1 - rtol) * expected_acceptance_length + ), ( + f"acceptance_length {acceptance_length} is not " + f"within {rtol * 100}% of {expected_acceptance_length}" + ) + + print( + f"Test passed! Expected AL: " + f"{expected_acceptance_length}, got {acceptance_length}" + )