Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions examples/offline_inference/spec_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ def parse_args():
default="eagle",
choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
)
parser.add_argument(
"--parallel-draft",
action="store_true",
help="Generate all draft tokens in a single forward pass. "
"Requires a draft model trained for parallel drafting.",
)
parser.add_argument("--num-spec-tokens", type=int, default=2)
parser.add_argument("--prompt-lookup-max", type=int, default=5)
parser.add_argument("--prompt-lookup-min", type=int, default=2)
Expand Down Expand Up @@ -104,18 +110,28 @@ def main(args):
else:
prompts = get_custom_mm_prompts(args.num_prompts)

if args.method == "eagle" or args.method == "eagle3":
if args.method in ("eagle", "eagle3"):
eagle_dir = args.eagle_dir
if args.method == "eagle" and eagle_dir is None:
if args.parallel_draft:
raise ValueError(
"--eagle-dir is required when using --parallel-draft. "
"No public parallel draft model is available yet."
)
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"

elif args.method == "eagle3" and eagle_dir is None:
if args.parallel_draft:
raise ValueError(
"--eagle-dir is required when using --parallel-draft. "
"No public parallel draft model is available yet."
)
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
speculative_config = {
"method": args.method,
"model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens,
"disable_padded_drafter_batch": args.disable_padded_drafter_batch,
"parallel_draft": args.parallel_draft,
}
elif args.method == "ngram":
speculative_config = {
Expand Down
92 changes: 92 additions & 0 deletions tests/v1/e2e/test_spec_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,98 @@ def test_eagle_correctness(
cleanup_dist_env_and_memory()


@pytest.mark.parametrize(
    ["model_setup", "mm_enabled"],
    [
        pytest.param(
            (
                "eagle3",
                "openai/gpt-oss-120b",
                "PATH_TO_PARALLEL_DRAFT_MODEL",
                1,
            ),
            False,
            marks=pytest.mark.skip(
                reason="Parallel draft model not publicly available yet"
            ),
        ),
    ],
    ids=["gpt_oss_eagle3_ptd"],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_ptd_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    attn_backend: str,
):
    """Check parallel-draft speculative decoding against a reference LLM.

    Runs the same chat prompts through a plain model and through the same
    model with a parallel-drafting speculative config (K draft tokens
    produced in one forward pass via mask tokens), then requires that a
    majority of the generated texts match exactly.

    model_setup: (method, model_name, draft_model_name, tp_size)
    """
    # Tree attention cannot be combined with parallel drafting yet.
    if attn_backend == "TREE_ATTN":
        pytest.skip("TREE_ATTN not yet supported with parallel drafting")

    prompts = get_test_prompts(mm_enabled)
    attn_cfg = {"backend": attn_backend}

    # TRITON_ATTN is only exercised on ROCm in this test.
    if not current_platform.is_rocm() and attn_backend == "TRITON_ATTN":
        pytest.skip("TRITON_ATTN not supported on current platform")

    with monkeypatch.context() as ctx:
        ctx.setenv("VLLM_MLA_DISABLE", "1")

        method, target_model, draft_model, tp = model_setup
        _skip_if_insufficient_gpus_for_tp(tp)

        context_len = 2048

        # Reference run: the bare model, no speculation.
        baseline = LLM(
            model=target_model,
            max_model_len=context_len,
            tensor_parallel_size=tp,
            attention_config=attn_cfg,
        )
        baseline_outputs = baseline.chat(prompts, sampling_config)
        # Free the reference engine before standing up the speculative one.
        del baseline
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

        # Speculative run: same target model plus a parallel-draft model.
        drafted = LLM(
            model=target_model,
            trust_remote_code=True,
            tensor_parallel_size=tp,
            speculative_config={
                "method": method,
                "model": draft_model,
                "num_speculative_tokens": 6,
                "max_model_len": context_len,
                "parallel_draft": True,
            },
            max_model_len=context_len,
            attention_config=attn_cfg,
        )
        drafted_outputs = drafted.chat(prompts, sampling_config)

        exact = 0
        for base, spec in zip(baseline_outputs, drafted_outputs):
            base_text = base.outputs[0].text
            spec_text = spec.outputs[0].text
            if base_text == spec_text:
                exact += 1
            else:
                # Log mismatches to aid debugging when the heuristic fails.
                print(f"ref_output: {base_text}")
                print(f"spec_output: {spec_text}")

        # Heuristic: expect at least 60% of the prompts to match exactly
        assert exact > int(0.6 * len(baseline_outputs))
        del drafted
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()


@pytest.mark.parametrize(
["model_setup", "mm_enabled"],
[
Expand Down
Loading
Loading