Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ steps:
- python3 offline_inference/basic/embed.py
- python3 offline_inference/basic/score.py
- VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
- python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill

- label: Platform Tests (CUDA) # 4min
timeout_in_minutes: 15
Expand Down
34 changes: 31 additions & 3 deletions examples/offline_inference/spec_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def get_custom_mm_prompts(num_prompts):
def parse_args():
parser = FlexibleArgumentParser()
add_dataset_parser(parser)
parser.add_argument("--test", action="store_true")
parser.add_argument(
"--method",
type=str,
Expand All @@ -72,8 +73,7 @@ def parse_args():
return parser.parse_args()


def main():
args = parse_args()
def main(args):
args.endpoint_type = "openai-chat"

model_dir = args.model_dir
Expand Down Expand Up @@ -194,6 +194,34 @@ def main():
acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0
print(f"acceptance at token {i}: {acceptance_rate:.2f}")

return acceptance_length


if __name__ == "__main__":
main()
args = parse_args()
acceptance_length = main(args)

if args.test:
# takes ~30s to run on 1xH100
assert args.method == "eagle"
assert args.tp == 1
assert args.num_spec_tokens == 3
assert args.dataset_path == "philschmid/mt-bench"
assert args.num_prompts == 80
assert args.temp == 0
assert args.top_p == 1.0
assert args.top_k == -1
assert args.enable_chunked_prefill

# check acceptance length is within 1% of expected value
rtol = 0.01
expected_acceptance_length = 2.29
assert (
acceptance_length <= (1 + rtol) * expected_acceptance_length
and acceptance_length >= (1 - rtol) * expected_acceptance_length
), (
f"acceptance_length {acceptance_length} is not \
within {rtol * 100}% of {expected_acceptance_length}"
)

print("Test passed!")
Loading