Merged

24 commits
d1c8724
first draft: working PARD implementation
benchislett Jan 22, 2026
f329963
PTD EAGLE support
benchislett Jan 24, 2026
da7f947
port bugfix from PTD branch
benchislett Jan 24, 2026
32fc2a0
typo
benchislett Jan 24, 2026
c597662
typo
benchislett Jan 24, 2026
543e011
avoid syncs
benchislett Jan 26, 2026
9514ff5
Merge branch 'main' into bchislett/unified-parallel-drafting
benchislett Jan 26, 2026
24aba2c
patch
benchislett Jan 27, 2026
e91ae70
Merge branch 'main' into bchislett/unified-parallel-drafting
benchislett Jan 30, 2026
aca2169
cleanup for PR
benchislett Jan 30, 2026
a17bc22
simplify testing
benchislett Jan 30, 2026
6a9058e
warn/err when parallel drafting misconfigured
benchislett Jan 30, 2026
23ab49e
fix lint issues in flashinfer cg support
benchislett Jan 30, 2026
ccaf85f
rename var for clarity
benchislett Jan 30, 2026
5ad5f78
rename some vars
benchislett Jan 30, 2026
bd91808
add targeted tests for parallel input preparation
benchislett Jan 31, 2026
9b1e305
Merge branch 'main' into bchislett/unified-parallel-drafting
benchislett Jan 31, 2026
733bb81
fix lint
benchislett Jan 31, 2026
fef4f5b
comments in flashinfer diff
benchislett Feb 4, 2026
0130080
Merge branch 'main' into bchislett/unified-parallel-drafting
benchislett Feb 4, 2026
972ea76
remove is_draft_model checks in eagle.py
benchislett Feb 4, 2026
70cee26
fix bug in test setup
benchislett Feb 4, 2026
213ff8e
refactors for tests
benchislett Feb 4, 2026
cb564b9
Rename last_token_indices to token_indices_to_sample
benchislett Feb 4, 2026
3 changes: 3 additions & 0 deletions examples/offline_inference/spec_decode.py
@@ -75,6 +75,7 @@ def parse_args():
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
parser.add_argument("--disable-padded-drafter-batch", action="store_true")
parser.add_argument("--max-num-seqs", type=int, default=None)
parser.add_argument("--parallel-drafting", action="store_true")
parser.add_argument("--allowed-local-media-path", type=str, default="")
return parser.parse_args()

@@ -121,6 +122,7 @@ def main(args):
"model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens,
"disable_padded_drafter_batch": args.disable_padded_drafter_batch,
"parallel_drafting": args.parallel_drafting,
}
elif args.method == "ngram":
speculative_config = {
@@ -137,6 +139,7 @@
"num_speculative_tokens": args.num_spec_tokens,
"enforce_eager": args.enforce_eager,
"max_model_len": args.max_model_len,
"parallel_drafting": args.parallel_drafting,
}
elif args.method == "mtp":
speculative_config = {
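For context, the new CLI flag simply threads `parallel_drafting` into `speculative_config`. A minimal offline sketch of the resulting setup (model names are borrowed from the tests below; the exact `LLM` keyword surface is assumed, not taken from this diff):

from vllm import LLM, SamplingParams

# Sketch: draft-model speculative decoding with parallel drafting enabled.
# The target/draft pair mirrors test_draft_model_parallel_drafting below.
llm = LLM(
    model="Qwen/Qwen3-1.7B",
    speculative_config={
        "method": "draft_model",
        "model": "amd/PARD-Qwen3-0.6B",
        "num_speculative_tokens": 3,
        "parallel_drafting": True,  # flag added in this PR
    },
)
outputs = llm.generate(["Hello"], SamplingParams(temperature=0.0, max_tokens=32))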
95 changes: 32 additions & 63 deletions tests/v1/e2e/test_spec_decode.py
@@ -13,15 +13,12 @@
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.benchmarks.datasets import InstructCoderDataset
from vllm.config.vllm import VllmConfig
from vllm.config import VllmConfig
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.v1.metrics.reader import Metric
from vllm.v1.spec_decode.draft_model import (
create_vllm_config_for_draft_model,
merge_toks_kernel,
)
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model

MTP_SIMILARITY_RATE = 0.8

@@ -625,6 +622,8 @@ class ArgsTest:
expected_acceptance_rate: float
expected_acceptance_len: float
# Defaults
enforce_eager: bool = True
parallel_drafting: bool = False
target_tensor_parallel_size: int = 1
draft_tensor_parallel_size: int = 1
max_model_len: int = 1024
@@ -658,7 +657,8 @@ class ArgsTest:
@pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
assert_draft_model_correctness(args, enforce_eager)
args.enforce_eager = enforce_eager
assert_draft_model_correctness(args)


def test_draft_model_realistic_example():
@@ -668,11 +668,28 @@ def test_draft_model_realistic_example():
dataset="likaixin/InstructCoder",
num_speculative_tokens=3,
sampling_config=greedy_sampling(),
enforce_eager=False,
# values below are not derived, but just prevent a regression
expected_acceptance_len=2.8,
expected_acceptance_rate=0.55,
)
assert_draft_model_correctness(args, enforce_eager=False)
assert_draft_model_correctness(args)


def test_draft_model_parallel_drafting():
args = ArgsTest(
target_model="Qwen/Qwen3-1.7B",
draft_model="amd/PARD-Qwen3-0.6B",
dataset="likaixin/InstructCoder",
num_speculative_tokens=3,
sampling_config=greedy_sampling(),
parallel_drafting=True,
enforce_eager=False,
# values below are collected from a stable run, with ~5% tolerance
expected_acceptance_len=2.375,
expected_acceptance_rate=0.45,
)
assert_draft_model_correctness(args)


@pytest.mark.parametrize(
@@ -691,8 +708,9 @@ def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
target_model=tgt_model,
draft_model=draft_model,
**some_high_acceptance_metrics(),
enforce_eager=enforce_eager,
)
assert_draft_model_correctness(sd_case, enforce_eager)
assert_draft_model_correctness(sd_case)


def test_draft_model_tensor_parallelism():
@@ -704,8 +722,9 @@
draft_model="Qwen/Qwen3-0.6B",
draft_tensor_parallel_size=2,
**some_high_acceptance_metrics(),
enforce_eager=False,
)
assert_draft_model_correctness(sd_case, enforce_eager=False)
assert_draft_model_correctness(sd_case)


def test_draft_model_engine_args_tensor_parallelism():
@@ -750,7 +769,7 @@ def test_draft_model_engine_args_rejects_invalid_tp_argname():
engine_args.create_engine_config()


def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
def assert_draft_model_correctness(args: ArgsTest):
"""Compare the outputs using and not using speculative decoding.
In the greedy decoding case, the outputs must match EXACTLY."""
test_prompts: list[Messages] = get_messages(
@@ -764,14 +783,15 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
"method": "draft_model",
"num_speculative_tokens": args.num_speculative_tokens,
"max_model_len": args.max_model_len,
"enforce_eager": enforce_eager,
"enforce_eager": args.enforce_eager,
"draft_tensor_parallel_size": args.draft_tensor_parallel_size,
"parallel_drafting": args.parallel_drafting,
},
max_num_seqs=100, # limit cudagraph capture runtime
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
tensor_parallel_size=args.target_tensor_parallel_size,
enforce_eager=enforce_eager,
enforce_eager=args.enforce_eager,
disable_log_stats=False, # enables get_metrics()
)
# we don't check the outputs, only check the metrics
@@ -813,57 +833,6 @@ def some_high_acceptance_metrics() -> dict:
}


def test_merge_toks_kernel():
device = "cuda"
merged_len = 5 + 2 # len(target_toks) = 5, batch_size = 2
merged = torch.full((merged_len,), -100, device=device) # -100 is arbitrary
is_rejected_tok = torch.full((merged_len,), True, device=device)
grid = (2,)
merge_toks_kernel[grid](
target_toks_ptr=torch.tensor([0, 1, 2, 0, 1], device=device),
next_toks_ptr=torch.tensor([3, 2], device=device),
query_start_locs_ptr=torch.tensor([0, 3], device=device),
query_end_locs_ptr=torch.tensor([2, 4], device=device),
out_ptr_merged_toks=merged,
out_ptr_is_rejected_tok=is_rejected_tok,
target_toks_size=5,
rejected_tok_fill=-1,
)
expected_merged = torch.tensor([0, 1, 2, 3, 0, 1, 2], device=device)
assert torch.allclose(merged, expected_merged)

expected_rejected_toks = torch.tensor([False] * merged_len, device=device)
assert torch.allclose(is_rejected_tok, expected_rejected_toks)


def test_merge_toks_kernel_with_rejected_tokens():
device = "cuda"
merged_size = 9 + 2 # len(target_toks) = 9, batch_size = 2
merged = torch.full((merged_size,), -100, device=device)
is_rejected_tok = torch.full((merged_size,), True, device=device)
grid = (2,)
merge_toks_kernel[grid](
# rejected tokens
# ↓ ↓ ↓ ↓
target_toks_ptr=torch.tensor([0, 1, 2, 13, 14, 15, 0, 1, 22], device=device),
next_toks_ptr=torch.tensor([3, 2], device=device),
query_start_locs_ptr=torch.tensor([0, 6], device=device),
query_end_locs_ptr=torch.tensor([2, 7], device=device),
out_ptr_merged_toks=merged,
out_ptr_is_rejected_tok=is_rejected_tok,
target_toks_size=9,
rejected_tok_fill=-1,
)
expected_merged = torch.tensor([0, 1, 2, 3, -1, -1, -1, 0, 1, 2, -1], device=device)
assert torch.allclose(merged, expected_merged)

expected_rejected_toks = torch.tensor(
[False, False, False, False, True, True, True, False, False, False, True],
device=device,
)
assert torch.allclose(is_rejected_tok, expected_rejected_toks)


def compute_acceptance_rate(metrics: list[Metric]) -> float:
name2metric = {metric.name: metric for metric in metrics}
n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value # type: ignore
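The body of `compute_acceptance_rate` is truncated above. As a hedged reconstruction: acceptance rate is conventionally the number of accepted draft tokens divided by the number of proposed draft tokens. The sketch below assumes that convention, and the accepted-token counter name is an assumption based on vLLM's metric naming, not something shown in this diff:

def compute_acceptance_rate_sketch(metrics: list[Metric]) -> float:
    # Assumed reconstruction, not the PR's verbatim helper.
    name2metric = {metric.name: metric for metric in metrics}
    n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value  # type: ignore
    # Counter name below is assumed.
    n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value  # type: ignore
    return n_accepted_toks / n_draft_toks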