
Commit 20e43fd

torch compile sampling kernel
Signed-off-by: Andy Lo <[email protected]>
1 parent c2dd566 commit 20e43fd
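
For context: the change below pulls the mixed greedy/random draft-token sampling into a standalone function and wraps it with `torch.compile`. Here is a minimal, self-contained sketch of that pattern (hypothetical names, not vLLM's actual code); the decorator arguments mirror the diff except that `backend=current_platform.simple_compile_backend` is omitted so the sketch runs on its own.

```python
import torch


# Sketch of a compiled mixed sampler: temperature == -1 marks greedy rows;
# all other rows are sampled with exponential noise (Gumbel-max style) so the
# returned probs can still be reused for rejection sampling.
@torch.compile(dynamic=True, mode="max-autotune-no-cudagraphs")
def mixed_sample_sketch(
        logits: torch.Tensor,
        temperature: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    is_greedy = temperature == -1
    # Give greedy rows a dummy temperature of 1.0 so the softmax is well
    # defined; their randomness is neutralized below.
    temperature = torch.where(is_greedy, 1.0, temperature)
    logits = logits / temperature.view(-1, 1)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    # q ~ Exp(1); argmax(probs / q) samples token i with probability probs[i].
    q = torch.empty_like(probs).exponential_()
    # Setting q to a constant on greedy rows turns argmax(probs / q) into a
    # plain argmax there, so one compiled kernel handles both cases.
    q[is_greedy, :] = 1.0
    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
    return next_token_ids, probs


# Example: rows 0 and 2 greedy, rows 1 and 3 sampled with temperature.
token_ids, probs = mixed_sample_sketch(
    torch.randn(4, 128), torch.tensor([-1.0, 0.7, -1.0, 1.0]))
```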

File tree: 1 file changed (+35, -25 lines)

vllm/v1/spec_decode/eagle.py

Lines changed: 35 additions & 25 deletions
@@ -12,6 +12,7 @@
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -329,7 +330,7 @@ def prepare_inputs(
 
     def load_model(self, target_model: nn.Module) -> None:
         draft_model_config = \
-            self.vllm_config.speculative_config.draft_model_config
+            self.speculative_config.draft_model_config
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config, Attention).keys())
 
@@ -371,7 +372,7 @@ def load_model(self, target_model: nn.Module) -> None:
         # share lm_head with the target model if needed
         # some model definition do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
-        if self.vllm_config.speculative_config.method != "eagle3" and \
+        if self.speculative_config.method != "eagle3" and \
                 hasattr(target_language_model, "lm_head"):
            logger.info("Loading EAGLE LM head weights from the target model.")
            self.model.lm_head = target_language_model.lm_head
@@ -383,11 +384,18 @@ def dummy_run(
     ) -> None:
         with set_forward_context(None, self.vllm_config,
                                  num_tokens=num_tokens):
-            self.model(
+            ret_hidden_states = self.model(
                 self.input_ids[:num_tokens],
                 self.positions[:num_tokens],
                 self.hidden_states[:num_tokens],
             )
+            if self.method == "deepseek_mtp":
+                last_hidden_states = ret_hidden_states
+            else:
+                last_hidden_states, hidden_states = ret_hidden_states
+        logits = self.model.compute_logits(last_hidden_states, None)
+        temperature = torch.ones(num_tokens, device=logits.device)
+        _mixed_sample(logits, temperature)
 
     def validate_same_kv_cache_group(self,
                                      kv_cache_config: KVCacheConfig) -> None:
@@ -409,21 +417,14 @@ def validate_same_kv_cache_group(self,
         ) == 1, "All eagle layers should belong to the same kv cache group"
 
 
-# FIXME(woosuk): The logic here is duplicated with the main sampling code.
-# We should refactor this to reuse the same sampling implementation.
-def compute_probs_and_sample_next_token(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    if sampling_metadata.all_greedy:
-        # For greedy requests, draft_probs is not used in rejection sampling.
-        # Therefore, we can just return the logits.
-        probs = logits
-        next_token_ids = logits.argmax(dim=-1)
-        return next_token_ids, probs
-
-    is_greedy = sampling_metadata.temperature == -1
-    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
+@torch.compile(dynamic=True,
+               backend=current_platform.simple_compile_backend,
+               mode="max-autotune-no-cudagraphs")
+def _mixed_sample(
+        logits: torch.Tensor,
+        temperature: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    is_greedy = temperature == -1
+    temperature = torch.where(is_greedy, 1.0, temperature)
     logits.div_(temperature.view(-1, 1))
     probs = logits.softmax(dim=-1, dtype=torch.float32)
 
@@ -435,14 +436,23 @@ def compute_probs_and_sample_next_token(
     # TODO(woosuk): Consider seeds.
     q = torch.empty_like(probs)
     q.exponential_()
+    q[is_greedy, :] = 1.0
     # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
     # will be used later for rejection sampling.
     next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
-    if not sampling_metadata.all_random:
-        greedy_token_ids = probs.argmax(dim=-1)
-        next_token_ids = torch.where(
-            is_greedy,
-            greedy_token_ids,
-            next_token_ids,
-        )
     return next_token_ids, probs
+
+
+# FIXME(woosuk): The logic here is duplicated with the main sampling code.
+# We should refactor this to reuse the same sampling implementation.
+def compute_probs_and_sample_next_token(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    if sampling_metadata.all_greedy:
+        # For greedy requests, draft_probs is not used in rejection sampling.
+        # Therefore, we can just return the logits.
+        probs = logits
+        next_token_ids = logits.argmax(dim=-1)
+        return next_token_ids, probs
+    return _mixed_sample(logits, sampling_metadata.temperature)
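
A quick, hypothetical sanity check (not part of the commit) of the exponential-noise trick kept in `_mixed_sample` above: `argmax(probs / q)` with `q ~ Exp(1)` is the Gumbel-max trick in disguise, since `-log(q)` is Gumbel(0, 1) distributed, so it draws token `i` with probability `probs[i]` while leaving `probs` untouched for later rejection sampling. Setting `q` to a constant on greedy rows is what lets the old `torch.where(is_greedy, greedy_token_ids, ...)` branch be dropped: dividing by a constant does not change the argmax.

```python
import torch

# Empirically check that argmax(probs / q), q ~ Exp(1), samples from probs.
torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
counts = torch.zeros(4)
for _ in range(20_000):
    q = torch.empty_like(probs).exponential_()
    counts[probs.div(q).argmax()] += 1
print(counts / counts.sum())  # expected to be close to [0.1, 0.2, 0.3, 0.4]
```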
