Skip to content

Commit 32e26e3

Browse files
IzzyPuttermanWong4j
authored andcommitted
[None][feat] Eagle, use last hidden post norm (NVIDIA#7546)
Signed-off-by: Izzy Putterman <[email protected]>
1 parent de4006e commit 32e26e3

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,7 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
410410
assert key in ('attn_layers', 'mla_layers')
411411
assert key in model_config.extra_attrs
412412
model_config.extra_attrs[key].update(value)
413+
self.layer_idx = -1
413414

414415
def forward(
415416
self,
@@ -430,6 +431,10 @@ def forward(
430431
**kwargs,
431432
)
432433

434+
if spec_metadata is not None and spec_metadata.is_layer_capture(
435+
self.layer_idx):
436+
spec_metadata.maybe_capture_hidden_states(self.layer_idx,
437+
hidden_states)
433438
if attn_metadata.padded_num_tokens is not None:
434439
hidden_states = hidden_states[:attn_metadata.num_tokens]
435440

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,10 @@ class Eagle3SpecMetadata(SpecMetadata):
9494
eagle3_resource_manager: Optional[Eagle3ResourceManager] = None
9595

9696
def __post_init__(self):
97-
if self.layers_to_capture is None:
98-
if self.is_draft_model or self.num_layers == 1:
97+
if self.is_draft_model:
98+
self.layers_to_capture = (self.num_layers - 1, )
99+
elif self.layers_to_capture is None:
100+
if self.num_layers == 1:
99101
self.layers_to_capture = (self.num_layers - 1, )
100102
else:
101103
if self.num_layers <= 5:

0 commit comments

Comments
 (0)