 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -329,7 +330,7 @@ def prepare_inputs(
 
     def load_model(self, target_model: nn.Module) -> None:
         draft_model_config = \
-            self.vllm_config.speculative_config.draft_model_config
+            self.speculative_config.draft_model_config
         target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, Attention).keys())
 
@@ -371,7 +372,7 @@ def load_model(self, target_model: nn.Module) -> None:
         # share lm_head with the target model if needed
         # some model definition do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
-        if self.vllm_config.speculative_config.method != "eagle3" and \
+        if self.speculative_config.method != "eagle3" and \
                 hasattr(target_language_model, "lm_head"):
             logger.info("Loading EAGLE LM head weights from the target model.")
             self.model.lm_head = target_language_model.lm_head
@@ -383,11 +384,18 @@ def dummy_run(
     ) -> None:
         with set_forward_context(None, self.vllm_config,
                                  num_tokens=num_tokens):
-            self.model(
+            ret_hidden_states = self.model(
                 self.input_ids[:num_tokens],
                 self.positions[:num_tokens],
                 self.hidden_states[:num_tokens],
             )
+            if self.method == "deepseek_mtp":
+                last_hidden_states = ret_hidden_states
+            else:
+                last_hidden_states, hidden_states = ret_hidden_states
+        logits = self.model.compute_logits(last_hidden_states, None)
+        temperature = torch.ones(num_tokens, device=logits.device)
+        _mixed_sample(logits, temperature)
 
     def validate_same_kv_cache_group(self,
                                      kv_cache_config: KVCacheConfig) -> None:
@@ -409,21 +417,14 @@ def validate_same_kv_cache_group(self,
         ) == 1, "All eagle layers should belong to the same kv cache group"
 
 
-# FIXME(woosuk): The logic here is duplicated with the main sampling code.
-# We should refactor this to reuse the same sampling implementation.
-def compute_probs_and_sample_next_token(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    if sampling_metadata.all_greedy:
-        # For greedy requests, draft_probs is not used in rejection sampling.
-        # Therefore, we can just return the logits.
-        probs = logits
-        next_token_ids = logits.argmax(dim=-1)
-        return next_token_ids, probs
-
-    is_greedy = sampling_metadata.temperature == -1
-    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
+@torch.compile(dynamic=True,
+               backend=current_platform.simple_compile_backend,
+               mode="max-autotune-no-cudagraphs")
+def _mixed_sample(
+        logits: torch.Tensor,
+        temperature: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    is_greedy = temperature == -1
+    temperature = torch.where(is_greedy, 1.0, temperature)
     logits.div_(temperature.view(-1, 1))
     probs = logits.softmax(dim=-1, dtype=torch.float32)
 
@@ -435,14 +436,23 @@ def compute_probs_and_sample_next_token(
     # TODO(woosuk): Consider seeds.
     q = torch.empty_like(probs)
     q.exponential_()
+    q[is_greedy, :] = 1.0
     # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
     # will be used later for rejection sampling.
     next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
-    if not sampling_metadata.all_random:
-        greedy_token_ids = probs.argmax(dim=-1)
-        next_token_ids = torch.where(
-            is_greedy,
-            greedy_token_ids,
-            next_token_ids,
-        )
     return next_token_ids, probs
+
+
+# FIXME(woosuk): The logic here is duplicated with the main sampling code.
+# We should refactor this to reuse the same sampling implementation.
+def compute_probs_and_sample_next_token(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    if sampling_metadata.all_greedy:
+        # For greedy requests, draft_probs is not used in rejection sampling.
+        # Therefore, we can just return the logits.
+        probs = logits
+        next_token_ids = logits.argmax(dim=-1)
+        return next_token_ids, probs
+    return _mixed_sample(logits, sampling_metadata.temperature)
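
Below is a small standalone sketch, not part of the patch, illustrating why the `probs.div(q).argmax(dim=-1)` step in `_mixed_sample` draws tokens from the distribution given by `probs`: with `q` filled by i.i.d. Exp(1) noise, argmax(p_i / q_i) = argmax(log p_i - log q_i), and -log q_i is Gumbel-distributed, so this is the Gumbel-max trick; rows whose noise is reset to 1.0 (the greedy rows) collapse to a plain argmax. The helper name and the check below are hypothetical.

import torch


def sample_with_exponential_noise(probs: torch.Tensor) -> torch.Tensor:
    # probs: [num_reqs, vocab_size], each row sums to 1.
    q = torch.empty_like(probs)
    q.exponential_()  # i.i.d. Exp(1) noise, as in _mixed_sample
    # argmax over probs / q picks token i of row r with probability probs[r, i].
    return probs.div(q).argmax(dim=-1)


if __name__ == "__main__":
    # Empirical check: sampled frequencies should approach [0.1, 0.2, 0.7].
    probs = torch.tensor([0.1, 0.2, 0.7]).repeat(100_000, 1)
    counts = torch.bincount(sample_with_exponential_noise(probs), minlength=3)
    print(counts / counts.sum())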