diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index 8eb9acfada2..c9b9fa979fe 100644
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -65,7 +65,7 @@ from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
 from ..peft.lora.layer import LoraLayer
-from ..speculative import MTPSpecMetadata, SpecMetadata
+from ..speculative import SpecMetadata
 from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .modeling_speculative import SpecDecOneEngineForCausalLM
 from .modeling_utils import (DecoderModel, EagerFusionConfig, filter_weights,
@@ -230,7 +230,7 @@ def __init__(
         aux_stream: Optional[torch.cuda.Stream] = None,
     ):
         config = model_config.pretrained_config
-        predicted_tokens_per_seq = model_config.spec_config.num_nextn_predict_layers + 1 if model_config.spec_config is not None else 1
+        predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1
         super().__init__(hidden_size=config.hidden_size,
                          num_attention_heads=config.num_attention_heads,
                          num_key_value_heads=config.num_key_value_heads,
@@ -750,6 +750,7 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: torch.Tensor,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if residual is None:
@@ -765,16 +766,24 @@ def forward(
             **kwargs,
         )
         if isinstance(self.mlp, Deepseekv3MoE):
+            if spec_metadata is not None and spec_metadata.is_layer_capture(
+                    self.layer_idx):
+                self.fusion_config.POST_MOE_FUSION = False
             return self.forward_MoE(
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 residual=residual,
+                spec_metadata=spec_metadata,
             )
         else:
+            if spec_metadata is not None and spec_metadata.is_layer_capture(
+                    self.layer_idx):
+                self.fusion_config.POST_MLP_FUSION = False
             assert isinstance(self.mlp, GatedMLP)
             return self.forward_mlp(
                 hidden_states=hidden_states,
                 residual=residual,
+                spec_metadata=spec_metadata,
             )
@@ -782,6 +791,7 @@ def forward_MoE(
         self,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: torch.Tensor,
+        spec_metadata: Optional[SpecMetadata] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
@@ -856,6 +866,10 @@ def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
                 hidden_states, residual = self.moe_allreduce(
                     fc2_output, all_reduce_params=moe_all_reduce_params)
         else:
+            if spec_metadata is not None and spec_metadata.is_layer_capture(
+                    self.layer_idx):
+                spec_metadata.maybe_capture_hidden_states(
+                    self.layer_idx, hidden_states, residual)
             if self.next_layer_layernorm is not None:
                 hidden_states, residual = self.next_layer_layernorm(
                     hidden_states, residual)
@@ -866,6 +880,7 @@ def forward_mlp(
         self,
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
+        spec_metadata: Optional[SpecMetadata] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         if self.fusion_config.PRE_MLP_FUSION:
@@ -903,6 +918,10 @@ def forward_mlp(
                 ),
             )
         else:
+            if spec_metadata is not None and spec_metadata.is_layer_capture(
+                    self.layer_idx):
+                spec_metadata.maybe_capture_hidden_states(
+                    self.layer_idx, hidden_states, residual)
             if self.next_layer_layernorm is not None:
                 hidden_states, residual = self.next_layer_layernorm(
                     hidden_states, residual)
@@ -1105,6 +1124,7 @@ def forward(
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 residual=residual,
+                spec_metadata=spec_metadata,
             )
 
         return hidden_states
@@ -1132,7 +1152,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                                      model_config=model_config)
 
         self.model_nextn = 0
-        if model_config.spec_config is not None:
+        if model_config.spec_config is not None and model_config.spec_config.spec_dec_mode.is_mtp(
+        ):
             model_nextn = model_config.spec_config.num_nextn_predict_layers
             ckpt_nextn = self.config.num_nextn_predict_layers
             self.num_hidden_layers = self.config.num_hidden_layers
@@ -1167,11 +1188,10 @@ def forward(
         input_ids: torch.IntTensor = None,
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        spec_metadata: Optional[MTPSpecMetadata] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         return_context_logits: bool = False,
         **kwargs,
     ) -> torch.Tensor:
-        attn_metadata.num_generations_per_batch = self.model_nextn + 1
         return super().forward(attn_metadata=attn_metadata,
                                input_ids=input_ids,
                                position_ids=position_ids,
@@ -1313,7 +1333,9 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
 
         for name, module in tqdm(all_named_modules.items(),
                                  desc="Loading weights"):
-            if len(module._parameters) > 0:
+            if len(module._parameters) <= 0 or name.startswith("draft_model"):
+                continue
+            else:
                 names = name.split('.')
                 parent_module_name = '.'.join(names[:-1])
                 if "model.layers" in name and int(
diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py
index f82c3b4de06..56a489c9635 100644
--- a/tensorrt_llm/_torch/models/modeling_speculative.py
+++ b/tensorrt_llm/_torch/models/modeling_speculative.py
@@ -155,10 +155,12 @@ def __init__(
         else:
             self.hidden_size_in = config.hidden_size
 
-        self.fc = Linear(self.hidden_size_in * 3,
-                         config.hidden_size,
-                         bias=getattr(config, "bias", False),
-                         dtype=config.torch_dtype)
+        if self.spec_config.num_capture_layers > 1:
+            self.fc = Linear(self.hidden_size_in *
+                             self.spec_config.num_capture_layers,
+                             config.hidden_size,
+                             bias=getattr(config, "bias", False),
+                             dtype=config.torch_dtype)
 
         self.midlayer = Eagle3DecoderLayer(model_config, start_layer_idx)
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index 12686728cda..ac3bb7a9f53 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -23,7 +23,7 @@
                      get_spec_resource_manager)
 from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
                     create_py_executor_instance, instantiate_sampler, is_mla)
-from .config import PyTorchConfig
+from .config import LoadFormat, PyTorchConfig
 from .config_utils import is_mla
 from .guided_decoder import GuidedDecoder
 from .model_engine import PyTorchModelEngine
@@ -252,13 +252,16 @@ def create_py_executor(
         with mem_monitor.observe_creation_stage(
                 _ExecutorCreationStage.MODEL_ENGINE_DRAFT):
             draft_spec_config = copy.copy(spec_config)
+            draft_pytorch_backend_config = copy.copy(pytorch_backend_config)
+            if spec_config.load_format == "dummy":
+                draft_pytorch_backend_config.load_format = LoadFormat.DUMMY
             # The draft model won't have any draft tokens attached to
             # generation requests when we invoke it autoregressively
             draft_spec_config.max_draft_len = 0
 
             draft_model_engine = PyTorchModelEngine(
                 model_path=spec_config.speculative_model_dir,
-                pytorch_backend_config=pytorch_backend_config,
+                pytorch_backend_config=draft_pytorch_backend_config,
                 batch_size=executor_config.max_batch_size,
                 max_beam_width=executor_config.max_beam_width,
                 max_num_tokens=executor_config.max_num_tokens,
diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py
index 417becf12f3..2d4225641b5 100644
--- a/tensorrt_llm/_torch/speculative/eagle3.py
+++ b/tensorrt_llm/_torch/speculative/eagle3.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set
 
 import torch
 from torch import nn
@@ -35,9 +35,10 @@ def __init__(self, config: "EagleDecodingConfig", dtype: torch.dtype,
         # empty hidden states tensor
         max_num_tokens = min(max_num_tokens,
                              max_num_requests * self.max_seq_len)
-        self.hidden_states = torch.empty((max_num_tokens, self.hidden_size * 3),
-                                         dtype=self.dtype,
-                                         device='cuda')
+        self.hidden_states = torch.empty(
+            (max_num_tokens, self.hidden_size * config.num_capture_layers),
+            dtype=self.dtype,
+            device='cuda')
         # sequence length, only used for metadata preparation
         self.seq_lens = {i: 0 for i in range(max_num_requests)}
         # start indices of each slot
@@ -79,8 +80,7 @@ def get_needed_resource_to_completion(self, request: LlmRequest):
 @dataclass
 class Eagle3SpecMetadata(SpecMetadata):
     hidden_states: List[torch.Tensor] = field(default_factory=list)
-    num_capture_layers: int = 3
-    layers_to_capture: Tuple[int, ...] = field(init=False)
+    layers_to_capture: Optional[Set[int]] = None
     target_model_embed_tokens: Optional[torch.nn.Module] = None
     hidden_size: int = 0
     max_num_tokens: int = 0
@@ -90,14 +90,19 @@ class Eagle3SpecMetadata(SpecMetadata):
     eagle3_resource_manager: Optional[Eagle3ResourceManager] = None
 
     def __post_init__(self):
-        if self.num_layers == 1:
-            self.layers_to_capture = (0, )
-        else:
-            if self.num_layers <= 5:
-                raise ValueError("Not enough hidden layers for EAGLE")
+        if self.layers_to_capture is None:
+            if self.num_layers == 1:
+                self.layers_to_capture = (self.num_layers - 1, )
+            else:
+                if self.num_layers <= 5:
+                    raise ValueError(
+                        "Not enough hidden layers for default EAGLE3 capture")
 
-            self.layers_to_capture = (1, self.num_layers // 2 - 1,
-                                      self.num_layers - 4)
+                self.layers_to_capture = (1, self.num_layers // 2 - 1,
+                                          self.num_layers - 4)
+        else:
+            self.layers_to_capture = sorted(list(self.layers_to_capture))
+        self.num_capture_layers = len(self.layers_to_capture)
 
         # Initialize to 0 to avoid reading uninitialized memory during warmup
         self.hidden_states_read_indices = torch.zeros([self.max_num_tokens],
@@ -186,7 +191,7 @@ class Eagle3OneModelSpecMetadata(SpecMetadata):
     # The hidden states
     hidden_states: Optional[torch.Tensor] = None
     # The layers to be captured
-    layers_to_capture: Tuple[int, ...] = field(init=False)
+    layers_to_capture: Optional[Set[int]] = None
     # The hidden size of the hidden states
     hidden_size: int = 0
     # The max number of tokens
@@ -197,14 +202,19 @@ class Eagle3OneModelSpecMetadata(SpecMetadata):
     batch_indices_cuda: Optional[torch.Tensor] = None
 
     def __post_init__(self):
-        if self.num_layers == 1:
-            self.layers_to_capture = (1, )
-        else:
-            if self.num_layers <= 5:
-                raise ValueError("Not enough hidden layers for EAGLE")
+        if self.layers_to_capture is None:
+            if self.num_layers == 1:
+                self.layers_to_capture = (self.num_layers - 1, )
+            else:
+                if self.num_layers <= 5:
+                    raise ValueError(
+                        "Not enough hidden layers for default EAGLE3 capture")
 
-            self.layers_to_capture = (1, self.num_layers // 2 - 1,
-                                      self.num_layers - 4)
+                self.layers_to_capture = (1, self.num_layers // 2 - 1,
+                                          self.num_layers - 4)
+        else:
+            self.layers_to_capture = sorted(list(self.layers_to_capture))
+        self.num_capture_layers = len(self.layers_to_capture)
         self.hidden_states = torch.empty(
             (self.max_num_tokens,
              self.hidden_size * len(self.layers_to_capture)),
diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py
index f7cdd92a561..1d306b90291 100644
--- a/tensorrt_llm/_torch/speculative/interface.py
+++ b/tensorrt_llm/_torch/speculative/interface.py
@@ -185,6 +185,13 @@ def create_cuda_graph_metadata(self, max_batch_size: int):
             cuda_graph_metadata.__post_init__()
         return cuda_graph_metadata
 
+    def is_layer_capture(self, layer_id: int):
+        """
+        Whether the layer should be captured (e.g. for Eagle3).
+        By default, returns False.
+        """
+        return False
+
     def maybe_capture_hidden_states(self, layer_id: int,
                                     hidden_states: torch.Tensor,
                                     residual: torch.Tensor) -> None:
diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py
index c4a4ccf7e3c..16fef4862b3 100644
--- a/tensorrt_llm/_torch/speculative/utils.py
+++ b/tensorrt_llm/_torch/speculative/utils.py
@@ -38,6 +38,7 @@ def get_spec_metadata(spec_config,
             dtype=model_config.torch_dtype,
             is_draft_model=is_draft_model,
             eagle3_resource_manager=spec_resource_manager,
+            layers_to_capture=spec_config.eagle3_layers_to_capture,
         )
     if spec_config.spec_dec_mode.is_eagle3_one_model():
         return Eagle3OneModelSpecMetadata(
@@ -47,6 +48,7 @@ def get_spec_metadata(spec_config,
             num_layers=model_config.num_hidden_layers,
             hidden_size=model_config.hidden_size,
             max_num_tokens=max_num_tokens,
+            layers_to_capture=spec_config.eagle3_layers_to_capture,
         )
     if spec_config.spec_dec_mode.is_draft_target() or \
        spec_config.spec_dec_mode.is_ngram() or \
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index da5071e3b0a..6ed4dea76c7 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -9,7 +9,7 @@
 from enum import Enum, EnumMeta
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
-                    Type, TypeAlias, TypeVar, Union, get_args, get_origin)
+                    Set, Type, TypeAlias, TypeVar, Union, get_args, get_origin)
 
 import torch
 import yaml
@@ -352,6 +352,7 @@ class DecodingBaseConfig(StrictBaseModel):
     # When specified, speculation will be disabled at batch sizes above
     # this value. Otherwise, speculation will always be on.
     max_concurrency: Optional[int] = None
+    load_format: Optional[str] = None
 
     @classmethod
     def from_dict(cls, data: dict):
@@ -424,6 +425,7 @@ class EagleDecodingConfig(DecodingBaseConfig):
     num_eagle_layers: Optional[int] = None
     max_non_leaves_per_layer: Optional[int] = None
     eagle3_one_model: Optional[bool] = True
+    eagle3_layers_to_capture: Optional[Set[int]] = None
 
     @classmethod
     def from_dict(cls, data: dict):
@@ -443,6 +445,17 @@ def spec_dec_mode(self):
             return TorchSpeculativeDecodingMode.EAGLE3_ONE_MODEL
         return TorchSpeculativeDecodingMode.EAGLE3
 
+    @functools.cached_property
+    def num_capture_layers(self):
+        """
+        Returns the number of layers to capture from the target model.
+        If eagle3_layers_to_capture is not None, return the size of that set.
+        Otherwise, assume the default Eagle3 capture scheme and return 3.
+        """
+        if self.eagle3_layers_to_capture is not None:
+            return len(self.eagle3_layers_to_capture)
+        return 3
+
 
 class UserProvidedDecodingConfig(DecodingBaseConfig):
     # Cannot use real type annotations due to circular imports
@@ -523,7 +536,9 @@ class MTPDecodingConfig(DecodingBaseConfig):
 
     @classmethod
     def from_dict(cls, data: dict):
-        return cls(**data)
+        out = cls(**data)
+        out.max_draft_len = out.num_nextn_predict_layers
+        return out
 
     decoding_type: ClassVar[str] = "MTP"
diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py
index ffb8e33766a..f26fa244f1f 100644
--- a/tests/unittest/_torch/speculative/test_eagle3.py
+++ b/tests/unittest/_torch/speculative/test_eagle3.py
@@ -1,6 +1,9 @@
+import json
 import os
 import sys
+import tempfile
 import unittest
+from pathlib import Path
 
 import pytest
 import torch
@@ -120,5 +123,107 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     assert text_spec == text_ref
 
 
+def test_deepseek_eagle3():
+    use_cuda_graph = True
+    attn_backend = "TRTLLM"
+    disable_overlap_scheduler = False
+    enable_block_reuse = False
+    use_one_model = False
+    enable_chunked_prefill = False
+
+    # Eagle3 one model works with overlap scheduler and block reuse.
+    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+    if total_mem_gb < 150:
+        pytest.skip("Not enough memory to load target + draft model")
+
+    models_path = llm_models_root()
+    eagle_config = {
+        'architectures': ['LlamaForCausalLMEagle3'],
+        'attention_bias': False,
+        'attention_dropout': 0.0,
+        'bos_token_id': 128000,
+        'eos_token_id': [128001, 128008, 128009],
+        'eagle_config': {
+            'use_aux_hidden_state': False,
+            'use_input_layernorm_in_first_layer': True,
+            'use_last_layernorm': True,
+            'use_mtp_layernorm': False
+        },
+        'head_dim': 128,
+        'hidden_act': 'silu',
+        'hidden_size': 2560,
+        'initializer_range': 0.02,
+        'intermediate_size': 16384,
+        'max_position_embeddings': 4096,
+        'mlp_bias': False,
+        'model_type': 'llama',
+        'num_attention_heads': 32,
+        'num_eagle_features': 1,
+        'num_hidden_layers': 1,
+        'num_key_value_heads': 8,
+        'pretraining_tp': 1,
+        'rms_norm_eps': 1e-05,
+        'rope_scaling': {
+            'factor': 8.0,
+            'high_freq_factor': 4.0,
+            'low_freq_factor': 1.0,
+            'original_max_position_embeddings': 8192,
+            'rope_type': 'llama3'
+        },
+        'rope_theta': 500000.0,
+        'tie_word_embeddings': False,
+        'torch_dtype': 'bfloat16',
+        'transformers_version': '4.52.4',
+        'use_cache': True,
+        'vocab_size': 129280,
+        'draft_vocab_size': 129280
+    }
+    with tempfile.TemporaryDirectory() as temp_dir:
+        eagle_model_dir = Path(temp_dir)
+        config_path = eagle_model_dir / "config.json"
+        with config_path.open("w") as f:
+            json.dump(eagle_config, f, indent=2)
+        target_model_dir = f"{models_path}/DeepSeek-V3-Lite/nvfp4_moe_only"
+
+        # bs > 1 gives non-deterministic results when doing IFB. There is a slight
+        # chance that ref and spec outputs do not match 100%.
+        max_batch_size = 16
+        max_draft_len = 3
+        kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
+                                        free_gpu_memory_fraction=0.5)
+        cuda_graph_config = CudaGraphConfig(
+            batch_sizes=[1]) if use_cuda_graph else None
+
+        llm_common_config = dict(
+            model=target_model_dir,
+            attn_backend=attn_backend,
+            disable_overlap_scheduler=disable_overlap_scheduler,
+            cuda_graph_config=cuda_graph_config,
+            max_batch_size=max_batch_size,
+            max_num_tokens=4096,
+            max_seq_len=4096,
+            kv_cache_config=kv_cache_config,
+            enable_chunked_prefill=enable_chunked_prefill,
+        )
+
+        spec_config = EagleDecodingConfig(
+            max_draft_len=max_draft_len,
+            speculative_model_dir=eagle_model_dir,
+            # Llama 3 does not support one model eagle.
+            eagle3_one_model=use_one_model,
+            eagle3_layers_to_capture={29},
+            load_format="dummy")
+
+        llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
+
+        tok_ids = llm_spec.tokenizer.encode("The future of AI is")
+
+        sampling_params = SamplingParams(max_tokens=32, temperature=0)
+        for output in llm_spec.generate_async(tok_ids,
+                                              sampling_params,
+                                              streaming=True):
+            pass
+
+
 if __name__ == "__main__":
     unittest.main()
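
Usage note (an illustrative sketch, not part of the patch): the options wired up above, eagle3_layers_to_capture and load_format on the speculative config, can be driven from the LLM API roughly the way the new test does. Model paths below are placeholders and the KV-cache fraction is an assumption.

    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm.llmapi import EagleDecodingConfig, KvCacheConfig

    # Two-model EAGLE3: capture a single target-model hidden layer and load
    # randomly initialized ("dummy") draft weights, mirroring the test above.
    spec_config = EagleDecodingConfig(
        max_draft_len=3,
        speculative_model_dir="/path/to/eagle3_draft",  # placeholder path
        eagle3_one_model=False,
        eagle3_layers_to_capture={29},
        load_format="dummy")

    llm = LLM(model="/path/to/target_model",  # placeholder path
              speculative_config=spec_config,
              kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.5))
    output = llm.generate("The future of AI is",
                          SamplingParams(max_tokens=32, temperature=0))
    print(output.outputs[0].text)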