Skip to content

Commit 1f65f9b

Browse files
committed
Additional small changes
Signed-off-by: Izzy Putterman <[email protected]>
1 parent 5888f12 commit 1f65f9b

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1049,6 +1049,7 @@ def forward(
10491049
position_ids: Optional[torch.IntTensor] = None,
10501050
inputs_embeds: Optional[torch.FloatTensor] = None,
10511051
spec_metadata: Optional[SpecMetadata] = None,
1052+
**kwargs,
10521053
) -> torch.Tensor:
10531054
if (input_ids is None) ^ (inputs_embeds is not None):
10541055
raise ValueError(
@@ -1145,6 +1146,7 @@ def forward(
11451146
)
11461147

11471148
if spec_metadata and spec_metadata.spec_dec_mode.is_mtp():
1149+
# TODO Merge API with EagleWorker in modeling_speculative.py
11481150
# get logits
11491151
logits = self.logits_processor.forward(
11501152
hidden_states[spec_metadata.gather_ids],

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -186,6 +186,8 @@ def get_hidden_states(self):
186186
class Eagle3OneModelSpecMetadata(SpecMetadata):
187187
# The hidden states
188188
hidden_states: Optional[torch.Tensor] = None
189+
# The number of layers to be captured
190+
num_capture_layers: int = 3
189191
# The layers to be captured
190192
layers_to_capture: Tuple[int, ...] = field(init=False)
191193
# The hidden size of the hidden states
@@ -198,8 +200,8 @@ class Eagle3OneModelSpecMetadata(SpecMetadata):
198200
batch_indices_cuda: Optional[torch.Tensor] = None
199201

200202
def __post_init__(self):
201-
if self.num_layers == 1:
202-
self.layers_to_capture = (1, )
203+
if self.num_layers == 1 or self.num_capture_layers == 1:
204+
self.layers_to_capture = (self.num_layers - 1, )
203205
else:
204206
if self.num_layers <= 5:
205207
raise ValueError("Not enough hidden layers for EAGLE")

0 commit comments

Comments
 (0)