34 changes: 28 additions & 6 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -65,7 +65,7 @@
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..peft.lora.layer import LoraLayer
-from ..speculative import MTPSpecMetadata, SpecMetadata
+from ..speculative import SpecMetadata
from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import (DecoderModel, EagerFusionConfig, filter_weights,
@@ -230,7 +230,7 @@ def __init__(
aux_stream: Optional[torch.cuda.Stream] = None,
):
config = model_config.pretrained_config
-        predicted_tokens_per_seq = model_config.spec_config.num_nextn_predict_layers + 1 if model_config.spec_config is not None else 1
+        predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1
super().__init__(hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
@@ -750,6 +750,7 @@ def forward(
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: torch.Tensor,
+        spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
if residual is None:
@@ -765,23 +766,32 @@
**kwargs,
)
if isinstance(self.mlp, Deepseekv3MoE):
+            if spec_metadata is not None and spec_metadata.is_layer_capture(
+                    self.layer_idx):
+                self.fusion_config.POST_MOE_FUSION = False
return self.forward_MoE(
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
+                spec_metadata=spec_metadata,
)
else:
+            if spec_metadata is not None and spec_metadata.is_layer_capture(
+                    self.layer_idx):
+                self.fusion_config.POST_MLP_FUSION = False
assert isinstance(self.mlp, GatedMLP)
return self.forward_mlp(
hidden_states=hidden_states,
residual=residual,
+                spec_metadata=spec_metadata,
)

def forward_MoE(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: torch.Tensor,
+        spec_metadata: Optional[SpecMetadata] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:

def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
@@ -856,6 +866,10 @@ def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
hidden_states, residual = self.moe_allreduce(
fc2_output, all_reduce_params=moe_all_reduce_params)
else:
+            if spec_metadata is not None and spec_metadata.is_layer_capture(
+                    self.layer_idx):
+                spec_metadata.maybe_capture_hidden_states(
+                    self.layer_idx, hidden_states, residual)
if self.next_layer_layernorm is not None:
hidden_states, residual = self.next_layer_layernorm(
hidden_states, residual)
@@ -866,6 +880,7 @@ def forward_mlp(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
+        spec_metadata: Optional[SpecMetadata] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:

if self.fusion_config.PRE_MLP_FUSION:
@@ -903,6 +918,10 @@
),
)
else:
+            if spec_metadata is not None and spec_metadata.is_layer_capture(
+                    self.layer_idx):
+                spec_metadata.maybe_capture_hidden_states(
+                    self.layer_idx, hidden_states, residual)
if self.next_layer_layernorm is not None:
hidden_states, residual = self.next_layer_layernorm(
hidden_states, residual)
@@ -1105,6 +1124,7 @@ def forward(
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
+                spec_metadata=spec_metadata,
)

return hidden_states
@@ -1132,7 +1152,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
model_config=model_config)

self.model_nextn = 0
-        if model_config.spec_config is not None:
+        if model_config.spec_config is not None and model_config.spec_config.spec_dec_mode.is_mtp(
+        ):
model_nextn = model_config.spec_config.num_nextn_predict_layers
ckpt_nextn = self.config.num_nextn_predict_layers
self.num_hidden_layers = self.config.num_hidden_layers
@@ -1167,11 +1188,10 @@ def forward(
input_ids: torch.IntTensor = None,
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
-        spec_metadata: Optional[MTPSpecMetadata] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
return_context_logits: bool = False,
**kwargs,
) -> torch.Tensor:
-        attn_metadata.num_generations_per_batch = self.model_nextn + 1
return super().forward(attn_metadata=attn_metadata,
input_ids=input_ids,
position_ids=position_ids,
@@ -1313,7 +1333,9 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,

for name, module in tqdm(all_named_modules.items(),
desc="Loading weights"):
-            if len(module._parameters) > 0:
+            if len(module._parameters) <= 0 or name.startswith("draft_model"):
+                continue
+            else:
names = name.split('.')
parent_module_name = '.'.join(names[:-1])
if "model.layers" in name and int(
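The pattern threaded through the DeepSeek-V3 decoder layer above is worth spelling out: when a layer is flagged for capture, the post-MoE/MLP allreduce fusion is disabled so the unfused hidden states and residual remain visible, and both tensors are handed to the speculative metadata before the next layer's layernorm consumes them. A minimal sketch of that flow; `SimpleDecoderLayer` and its fields are illustrative stand-ins, not code from this PR:

# Sketch of the capture flow above (illustrative, not from this PR).
from typing import Optional

import torch


class SimpleDecoderLayer(torch.nn.Module):

    def __init__(self, layer_idx: int, hidden_size: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.mlp = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self,
                hidden_states: torch.Tensor,
                residual: torch.Tensor,
                spec_metadata=None):
        hidden_states = self.mlp(hidden_states)
        # If this layer is flagged for capture, record (hidden_states,
        # residual) before the next layer's fused layernorm consumes them.
        if spec_metadata is not None and spec_metadata.is_layer_capture(
                self.layer_idx):
            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                      hidden_states, residual)
        return hidden_states, residual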
10 changes: 6 additions & 4 deletions tensorrt_llm/_torch/models/modeling_speculative.py
@@ -155,10 +155,12 @@ def __init__(
else:
self.hidden_size_in = config.hidden_size

-        self.fc = Linear(self.hidden_size_in * 3,
-                         config.hidden_size,
-                         bias=getattr(config, "bias", False),
-                         dtype=config.torch_dtype)
+        if self.spec_config.num_capture_layers > 1:
+            self.fc = Linear(self.hidden_size_in *
+                             self.spec_config.num_capture_layers,
+                             config.hidden_size,
+                             bias=getattr(config, "bias", False),
+                             dtype=config.torch_dtype)

self.midlayer = Eagle3DecoderLayer(model_config, start_layer_idx)

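The draft head's input projection has to agree with however many target layers are captured: the captured hidden states are concatenated along the feature dimension, so `fc` maps `num_capture_layers * hidden_size` down to `hidden_size`, and is skipped entirely when only one layer is captured. A quick dimensional check (a sketch; the sizes are illustrative):

# Dimensional check for the draft head input (sketch; sizes illustrative).
import torch

hidden_size = 4096
num_capture_layers = 3  # the default when no capture set is configured
captured = [torch.randn(8, hidden_size) for _ in range(num_capture_layers)]
# Captured states are concatenated on the feature dim, then fc projects
# (num_capture_layers * hidden_size) back down to hidden_size.
fc_input = torch.cat(captured, dim=-1)
assert fc_input.shape == (8, num_capture_layers * hidden_size)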
7 changes: 5 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -23,7 +23,7 @@
get_spec_resource_manager)
from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
create_py_executor_instance, instantiate_sampler, is_mla)
-from .config import PyTorchConfig
+from .config import LoadFormat, PyTorchConfig
from .config_utils import is_mla
from .guided_decoder import GuidedDecoder
from .model_engine import PyTorchModelEngine
@@ -252,13 +252,16 @@ def create_py_executor(
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.MODEL_ENGINE_DRAFT):
draft_spec_config = copy.copy(spec_config)
+            draft_pytorch_backend_config = copy.copy(pytorch_backend_config)
+            if spec_config.load_format == "dummy":
+                draft_pytorch_backend_config.load_format = LoadFormat.DUMMY
# The draft model won't have any draft tokens attached to
# generation requests when we invoke it autoregressively
draft_spec_config.max_draft_len = 0

draft_model_engine = PyTorchModelEngine(
model_path=spec_config.speculative_model_dir,
-                pytorch_backend_config=pytorch_backend_config,
+                pytorch_backend_config=draft_pytorch_backend_config,
batch_size=executor_config.max_batch_size,
max_beam_width=executor_config.max_beam_width,
max_num_tokens=executor_config.max_num_tokens,
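With the draft engine getting its own copy of the backend config, a dummy-weight draft model can now be paired with a fully loaded target model. A hypothetical configuration, assuming the `EagleDecodingConfig` fields referenced in this diff (the checkpoint path is a placeholder):

# Hypothetical configuration: dummy draft weights, real target weights.
spec_config = EagleDecodingConfig(
    speculative_model_dir="/path/to/eagle3_draft",  # placeholder path
    max_draft_len=3,
    load_format="dummy",  # only the draft engine switches to LoadFormat.DUMMY
)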
52 changes: 31 additions & 21 deletions tensorrt_llm/_torch/speculative/eagle3.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
-from typing import List, Optional, Tuple
+from typing import List, Optional, Set

import torch
from torch import nn
@@ -35,9 +35,10 @@ def __init__(self, config: "EagleDecodingConfig", dtype: torch.dtype,
# empty hidden states tensor
max_num_tokens = min(max_num_tokens,
max_num_requests * self.max_seq_len)
-        self.hidden_states = torch.empty((max_num_tokens, self.hidden_size * 3),
-                                         dtype=self.dtype,
-                                         device='cuda')
+        self.hidden_states = torch.empty(
+            (max_num_tokens, self.hidden_size * config.num_capture_layers),
+            dtype=self.dtype,
+            device='cuda')
# sequence length, only used for metadata preparation
self.seq_lens = {i: 0 for i in range(max_num_requests)}
# start indices of each slot
@@ -79,8 +80,7 @@ def get_needed_resource_to_completion(self, request: LlmRequest):
@dataclass
class Eagle3SpecMetadata(SpecMetadata):
hidden_states: List[torch.Tensor] = field(default_factory=list)
-    num_capture_layers: int = 3
-    layers_to_capture: Tuple[int, ...] = field(init=False)
+    layers_to_capture: Optional[Set[int]] = None
target_model_embed_tokens: Optional[torch.nn.Module] = None
hidden_size: int = 0
max_num_tokens: int = 0
@@ -90,14 +90,19 @@ class Eagle3SpecMetadata(SpecMetadata):
eagle3_resource_manager: Optional[Eagle3ResourceManager] = None

def __post_init__(self):
-        if self.num_layers == 1:
-            self.layers_to_capture = (0, )
-        else:
-            if self.num_layers <= 5:
-                raise ValueError("Not enough hidden layers for EAGLE")
-            self.layers_to_capture = (1, self.num_layers // 2 - 1,
-                                      self.num_layers - 4)
+        if self.layers_to_capture is None:
+            if self.num_layers == 1:
+                self.layers_to_capture = (self.num_layers - 1, )
+            else:
+                if self.num_layers <= 5:
+                    raise ValueError(
+                        "Not enough hidden layers for default EAGLE3 capture")
+
+                self.layers_to_capture = (1, self.num_layers // 2 - 1,
+                                          self.num_layers - 4)
+        else:
+            self.layers_to_capture = sorted(list(self.layers_to_capture))
+        self.num_capture_layers = len(self.layers_to_capture)

# Initialize to 0 to avoid reading uninitialized memory during warmup
self.hidden_states_read_indices = torch.zeros([self.max_num_tokens],
@@ -186,7 +191,7 @@ class Eagle3OneModelSpecMetadata(SpecMetadata):
# The hidden states
hidden_states: Optional[torch.Tensor] = None
# The layers to be captured
-    layers_to_capture: Tuple[int, ...] = field(init=False)
+    layers_to_capture: Optional[Set[int]] = None
# The hidden size of the hidden states
hidden_size: int = 0
# The max number of tokens
@@ -197,14 +202,19 @@
batch_indices_cuda: Optional[torch.Tensor] = None

def __post_init__(self):
-        if self.num_layers == 1:
-            self.layers_to_capture = (1, )
-        else:
-            if self.num_layers <= 5:
-                raise ValueError("Not enough hidden layers for EAGLE")
-            self.layers_to_capture = (1, self.num_layers // 2 - 1,
-                                      self.num_layers - 4)
+        if self.layers_to_capture is None:
+            if self.num_layers == 1:
+                self.layers_to_capture = (self.num_layers - 1, )
+            else:
+                if self.num_layers <= 5:
+                    raise ValueError(
+                        "Not enough hidden layers for default EAGLE3 capture")
+
+                self.layers_to_capture = (1, self.num_layers // 2 - 1,
+                                          self.num_layers - 4)
+        else:
+            self.layers_to_capture = sorted(list(self.layers_to_capture))
+        self.num_capture_layers = len(self.layers_to_capture)
self.hidden_states = torch.empty(
(self.max_num_tokens,
self.hidden_size * len(self.layers_to_capture)),
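Both metadata classes now share the same default rule when no capture set is supplied. Factored out as a standalone function for clarity (a sketch, not code from this PR):

# Default capture-set rule from both __post_init__ methods (sketch).
def default_layers_to_capture(num_layers: int):
    if num_layers == 1:
        return (num_layers - 1, )
    if num_layers <= 5:
        raise ValueError("Not enough hidden layers for default EAGLE3 capture")
    # One early, one middle, and one late target layer.
    return (1, num_layers // 2 - 1, num_layers - 4)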
7 changes: 7 additions & 0 deletions tensorrt_llm/_torch/speculative/interface.py
@@ -185,6 +185,13 @@ def create_cuda_graph_metadata(self, max_batch_size: int):
cuda_graph_metadata.__post_init__()
return cuda_graph_metadata

+    def is_layer_capture(self, layer_id: int):
+        """
+        Whether the given layer's hidden states should be captured
+        (e.g. for Eagle3). Returns False by default.
+        """
+        return False

def maybe_capture_hidden_states(self, layer_id: int,
hidden_states: torch.Tensor,
residual: torch.Tensor) -> None:
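Subclasses that do capture are expected to override this hook with a membership test against their capture set. A simplified sketch of that override (an assumed shape; the real Eagle3SpecMetadata implementation may differ):

# Assumed shape of the override in a capturing subclass (sketch).
class CapturingSpecMetadata:

    def __init__(self, layers_to_capture):
        self.layers_to_capture = set(layers_to_capture)

    def is_layer_capture(self, layer_id: int) -> bool:
        return layer_id in self.layers_to_capture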
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/speculative/utils.py
@@ -38,6 +38,7 @@ def get_spec_metadata(spec_config,
dtype=model_config.torch_dtype,
is_draft_model=is_draft_model,
eagle3_resource_manager=spec_resource_manager,
+            layers_to_capture=spec_config.eagle3_layers_to_capture,
)
if spec_config.spec_dec_mode.is_eagle3_one_model():
return Eagle3OneModelSpecMetadata(
Expand All @@ -47,6 +48,7 @@ def get_spec_metadata(spec_config,
num_layers=model_config.num_hidden_layers,
hidden_size=model_config.hidden_size,
max_num_tokens=max_num_tokens,
+            layers_to_capture=spec_config.eagle3_layers_to_capture,
)
if spec_config.spec_dec_mode.is_draft_target() or \
spec_config.spec_dec_mode.is_ngram() or \
14 changes: 12 additions & 2 deletions tensorrt_llm/llmapi/llm_args.py
@@ -9,7 +9,7 @@
from enum import Enum, EnumMeta
from pathlib import Path
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
-                    Type, TypeAlias, TypeVar, Union, get_args, get_origin)
+                    Set, Type, TypeAlias, TypeVar, Union, get_args, get_origin)

import torch
import yaml
@@ -352,6 +352,7 @@ class DecodingBaseConfig(StrictBaseModel):
# When specified, speculation will be disabled at batch sizes above
# this value. Otherwise, speculation will always be on.
max_concurrency: Optional[int] = None
+    load_format: Optional[str] = None

@classmethod
def from_dict(cls, data: dict):
@@ -424,6 +425,7 @@ class EagleDecodingConfig(DecodingBaseConfig):
num_eagle_layers: Optional[int] = None
max_non_leaves_per_layer: Optional[int] = None
eagle3_one_model: Optional[bool] = True
+    eagle3_layers_to_capture: Optional[Set[int]] = None

@classmethod
def from_dict(cls, data: dict):
@@ -443,6 +445,12 @@ def spec_dec_mode(self):
return TorchSpeculativeDecodingMode.EAGLE3_ONE_MODEL
return TorchSpeculativeDecodingMode.EAGLE3

+    @functools.cached_property
+    def num_capture_layers(self):
+        if self.eagle3_layers_to_capture is not None:
+            return len(self.eagle3_layers_to_capture)
+        return 3


class UserProvidedDecodingConfig(DecodingBaseConfig):
# Cannot use real type annotations due to circular imports
@@ -523,7 +531,9 @@ class MTPDecodingConfig(DecodingBaseConfig):

@classmethod
def from_dict(cls, data: dict):
-        return cls(**data)
+        out = cls(**data)
+        out.max_draft_len = out.num_nextn_predict_layers
+        return out

decoding_type: ClassVar[str] = "MTP"

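Putting the pieces together, a user can now choose which target layers feed the Eagle3 draft head. A usage sketch, assuming the fields added in this diff (the checkpoint path and layer indices are placeholders):

# Usage sketch: custom capture layers for a two-model Eagle3 run.
config = EagleDecodingConfig(
    speculative_model_dir="/path/to/eagle3_draft",  # placeholder path
    eagle3_one_model=False,
    eagle3_layers_to_capture={1, 15, 28},
)
assert config.num_capture_layers == 3  # falls back to 3 when the set is unset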