7 changes: 7 additions & 0 deletions examples/offline_inference/spec_decode.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
@@ -69,6 +71,10 @@ def parse_args():
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument(
"--enable-draft-probs", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument("--request-id-prefix", type=str, default="")
return parser.parse_args()


@@ -110,6 +116,7 @@ def main():
"method": args.method,
"model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens,
"enable_draft_probs": args.enable_draft_probs,
}
elif args.method == "ngram":
speculative_config = {
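Because the new flag uses argparse.BooleanOptionalAction, argparse also generates a matching --no-enable-draft-probs switch automatically. A minimal standalone sketch of that behavior (not part of the diff; the parser here is a stand-in):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-draft-probs", action=argparse.BooleanOptionalAction, default=True
)
# Defaults to True; the auto-generated negation flag turns it off.
assert parser.parse_args([]).enable_draft_probs is True
assert parser.parse_args(["--no-enable-draft-probs"]).enable_draft_probs is False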
285 changes: 255 additions & 30 deletions tests/v1/spec_decode/test_eagle.py
@@ -28,6 +28,7 @@ def _create_proposer(
method: str,
num_speculative_tokens: int,
speculative_token_tree: Optional[list[tuple[int]]] = None,
enable_probs: bool = True,
) -> EagleProposer:
model_config = ModelConfig(model=model_dir,
runner="generate",
@@ -48,6 +49,7 @@ def _create_proposer(
method=method,
num_speculative_tokens=num_speculative_tokens,
speculative_token_tree=spec_token_tree_str,
enable_draft_probs=enable_probs,
)

vllm_config = VllmConfig(
@@ -228,7 +230,9 @@ class _TargetModelStub(LlamaForCausalLM):
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
@pytest.mark.parametrize("enable_probs", [True, False])
def test_propose_deterministic(method, attn_backend, num_speculative_tokens,
enable_probs, monkeypatch):

monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

@@ -256,7 +260,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
seq_lens = [seq_len_1, seq_len_2]

# Create proposer first so we can use its actual hidden_size
proposer = _create_proposer("eagle", num_speculative_tokens)
proposer = _create_proposer("eagle",
num_speculative_tokens,
enable_probs=enable_probs)
# Get the hidden_size from the proposer to ensure consistency
hidden_size = proposer.hidden_size

@@ -341,6 +347,9 @@ def create_deterministic_logits(token_ids):
dtype=torch.int32,
device=device)
sampling_metadata = mock.MagicMock()
# Simulate mixed greedy and non-greedy requests
sampling_metadata.all_greedy = False
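# Greedy sampling for seq 1, standard sampling for seq 2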
sampling_metadata.temperature = torch.tensor([-1, 0.7], device=device)

if attn_backend == "FLASH_ATTN_VLLM_V1":
attn_metadata_builder_cls, _ = get_attention_backend(
@@ -366,33 +375,247 @@ def create_deterministic_logits(token_ids):
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder

result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
result, result_probs = proposer.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
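# propose() returns draft token ids plus per-token draft probabilities
# (None when draft probs are disabled)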

# Create expected tokens based on our token pattern
if num_speculative_tokens == 1:
    # Example for num_speculative_tokens=1:
    # [[42], [60]]
    expected_tokens = torch.tensor(
        [[base_token_ids[0]], [base_token_ids[1]]], device=device)
else:
    # Example for num_speculative_tokens=3:
    # [[42, 43, 44], [60, 61, 62]]
    expected_tokens = torch.zeros((batch_size, num_speculative_tokens),
                                  dtype=torch.int64,
                                  device=device)
    for i in range(batch_size):
        for j in range(num_speculative_tokens):
            expected_tokens[i, j] = base_token_ids[i] + j
# Example for num_speculative_tokens=1:
# [[42], [60]]
# Example for num_speculative_tokens=3:
# [[42, 43, 44], [60, 61, 62]]
expected_tokens = torch.zeros((batch_size, num_speculative_tokens),
                              dtype=torch.int64,
                              device=device)
expected_probs = torch.zeros(
    (batch_size, num_speculative_tokens, vocab_size), device=device)
for i in range(batch_size):
    for j in range(num_speculative_tokens):
        expected_tokens[i, j] = base_token_ids[i] + j
        expected_probs[i, j, base_token_ids[i] + j] = 1.0

# Verify all tokens match our expectations
assert result.shape == (batch_size, num_speculative_tokens)
assert torch.equal(result, expected_tokens)
if enable_probs:
    assert result_probs is not None
    assert result_probs.shape == (batch_size, num_speculative_tokens,
                                  vocab_size)
    torch.testing.assert_close(result_probs, expected_probs)
else:
    assert result_probs is None

@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
@pytest.mark.parametrize("enable_probs", [True, False])
def test_propose_random(method, attn_backend, num_speculative_tokens,
enable_probs, monkeypatch):

monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

if (attn_backend == "TRITON_ATTN_VLLM_V1"
and not current_platform.is_rocm()):
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
"multi-token eagle spec decode on current platform")

if (attn_backend == "TREE_ATTN"):
pytest.skip("TREE_ATTN is tested separately in test_propose_tree"
"because it requires special input mocking.")

if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# Use GPU device
device = torch.device(current_platform.device_type)

# Setup test parameters
batch_size = 2
seq_len_1 = 5
seq_len_2 = 3
total_tokens = seq_len_1 + seq_len_2
vocab_size = 3
seq_lens = [seq_len_1, seq_len_2]

# Create proposer first so we can use its actual hidden_size
proposer = _create_proposer("eagle",
num_speculative_tokens,
enable_probs=enable_probs)
# Get the hidden_size from the proposer to ensure consistency
hidden_size = proposer.hidden_size

# We mock a model that returns constant logits
# Sequence 1: [P(0) = 0.5, P(1) = 0.3, P(2) = 0.2] * num_speculative_tokens
# Sequence 2: [P(0) = 0.2, P(1) = 0.4, P(2) = 0.4] * num_speculative_tokens
token_probs = torch.tensor([
[0.5, 0.3, 0.2],
[0.2, 0.4, 0.4],
],
device=device)

def sample_once():
# Skip loading the model and replace it with a mock directly
# Create the mock model with deterministic outputs
model_mock = mock.MagicMock()

# Setup for model forward calls
forward_returns = []
for i in range(num_speculative_tokens):
if i == 0:
# First call uses all tokens
h_logits = torch.zeros(total_tokens,
hidden_size,
device=device)
h_states = torch.zeros(total_tokens,
hidden_size,
device=device)
else:
# Subsequent calls use batch_size tokens
h_logits = torch.zeros(batch_size, hidden_size, device=device)
h_states = torch.zeros(batch_size, hidden_size, device=device)
forward_returns.append((h_logits, h_states))

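# With an iterable side_effect, each forward call to the mock returns
# the next (hidden_for_logits, hidden_states) pair in order.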
model_mock.side_effect = forward_returns

# Setup for compute_logits calls
logits_returns = []
for i in range(num_speculative_tokens):
# Subtracting a per-row constant shifts the logits but leaves the
# softmax distribution unchanged
logits = torch.log(token_probs) - torch.randn(
(batch_size, 1), device=device)
logits_returns.append(logits)

model_mock.compute_logits.side_effect = logits_returns

# Assign the mock to the proposer
proposer.model = model_mock

# Assign draft attn_layer_names since load_model is not invoked
proposer.attn_layer_names = ["layer.0"]

# Create input tensors
batch_spec = BatchSpec(
seq_lens=seq_lens,
query_lens=seq_lens,
)

common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
device=device,
)

target_token_ids = torch.randint(0,
vocab_size, (total_tokens, ),
device=device)
target_positions = torch.cat([
torch.arange(seq_len_1, device=device),
torch.arange(seq_len_2, device=device)
])
target_hidden_states = torch.randn(total_tokens,
hidden_size,
device=device)
next_token_ids = torch.randint(0,
vocab_size, (batch_size, ),
dtype=torch.int32,
device=device)
sampling_metadata = mock.MagicMock()
# Simulate mixed greedy and non-greedy requests
sampling_metadata.all_greedy = False
# Greedy sampling for seq 1, standard sampling for seq 2
sampling_metadata.temperature = torch.tensor([-1, 0.7], device=device)

if attn_backend == "FLASH_ATTN_VLLM_V1":
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.FLASH_ATTN_VLLM_V1)
elif attn_backend == "TRITON_ATTN_VLLM_V1":
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.TRITON_ATTN_VLLM_V1)
elif attn_backend == "TREE_ATTN":
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.TREE_ATTN)
else:
raise ValueError(f"Unsupported attention backend: {attn_backend}")

attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,
vllm_config=proposer.vllm_config,
device=device,
)

# Mock runner for attention metadata building
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][
0].metadata_builder = attn_metadata_builder

result, result_prob = proposer.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)

return result, result_prob

results = []
result_probs = []

# Run N times and check distribution
N = 1000
for _ in range(N):
result, result_prob = sample_once()
results.append(result)
result_probs.append(result_prob)

# Count the number of times each token appears
counts = torch.zeros((batch_size, num_speculative_tokens, vocab_size),
device=device,
dtype=torch.int64)
for result in results:
assert result.shape == (batch_size, num_speculative_tokens)
counts.scatter_add_(2, result.unsqueeze(-1),
torch.ones_like(result.unsqueeze(-1)))
sample_dist = counts / len(results)

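# Expected distribution for seq 2 (temperature 0.7):
#   softmax(log([0.2, 0.4, 0.4]) / 0.7) ≈ [0.1567, 0.4217, 0.4217]
# Seq 1 uses temperature -1 (greedy), so its distribution collapses to a
# one-hot on the argmax of [0.5, 0.3, 0.2], i.e. token 0.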
token_probs_after_temp = torch.tensor(
[
[1, 0, 0],
[0.1567, 0.4217, 0.4217] if enable_probs else
[0, 1, 0], # argmax tie-breaks on first occurrence
],
device=device)

# Verify that the observed distribution is within 4 standard deviations
std = torch.sqrt(token_probs_after_temp * (1 - token_probs_after_temp) / N)
assert torch.all(std <= 0.02), f"Bounds {std=} are too loose, increase N"
lower_bound = token_probs_after_temp - 4 * std
upper_bound = token_probs_after_temp + 4 * std
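# P(|Z| > 4) ≈ 6e-5 per entry, so a spurious assertion failure is unlikely.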
assert torch.all(sample_dist >= lower_bound.unsqueeze(1)), (
f"Sampled too many unlikely tokens: {sample_dist} < {lower_bound}")
assert torch.all(sample_dist <= upper_bound.unsqueeze(1)), (
f"Sampled too few likely tokens: {sample_dist} > {upper_bound}")

if enable_probs:
for result_prob in result_probs:
assert result_prob is not None
assert result_prob.shape == (batch_size, num_speculative_tokens,
vocab_size)
# Only check sequence 2; sequence 1 is greedy, so its probs are
# allowed to be anything.
assert torch.allclose(result_prob[1],
token_probs_after_temp[1].unsqueeze(0),
atol=1e-3)
else:
assert all(result_prob is None for result_prob in result_probs)

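A standalone sketch (not part of the diff) of the histogram check above: scatter_add_ turns repeated draws into per-position counts, which are then compared against binomial confidence bounds. Shapes and probabilities here are illustrative.

import torch

# Illustrative shapes: 2 sequences, 4 speculative positions, vocab of 3.
draws = [torch.randint(0, 3, (2, 4)) for _ in range(1000)]
counts = torch.zeros((2, 4, 3), dtype=torch.int64)
for draw in draws:
    # Add 1 at [seq, pos, token] for each sampled token id.
    counts.scatter_add_(2, draw.unsqueeze(-1),
                        torch.ones_like(draw.unsqueeze(-1)))
sample_dist = counts / len(draws)  # empirical per-position distribution

# For a true probability p, the empirical frequency has std sqrt(p(1-p)/N);
# a 4-sigma band around p should contain sample_dist almost surely.
p = torch.full((2, 4, 3), 1.0 / 3)  # draws above are uniform over 3 tokens
std = torch.sqrt(p * (1 - p) / len(draws))
assert torch.all((sample_dist - p).abs() <= 4 * std)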

@pytest.mark.parametrize(
@@ -517,13 +740,15 @@ def create_deterministic_logits(token_ids, k: int):
sampling_metadata = mock.MagicMock()

# Propose draft tokens.
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
result, draft_probs = proposer.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
assert result.shape == (batch_size, num_speculative_tokens)
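# The tree-attention path yields no draft probabilities here.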
assert draft_probs is None

# The tokens are expected to be consecutive integers starting
# from the base token IDs.