Skip to content

Commit f9a5215

Browse files
mikeiovine authored and dominicshanshan committed
[None][fix] Fix MTP 2-model (NVIDIA#8115)
Signed-off-by: Mike Iovine <[email protected]> Signed-off-by: Mike Iovine <[email protected]>
1 parent ec5e2c2 commit f9a5215

File tree

4 files changed

+22
-1
lines changed

4 files changed

+22
-1
lines changed

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ def __init__(self, config: "EagleDecodingConfig", dtype: torch.dtype,
5454
# whether the next draft forward is the first
5555
self.is_first_draft = True
5656
self.spec_tree_manager = None
57-
if config.eagle_choices is not None:
57+
58+
if isinstance(config,
59+
EagleDecodingConfig) and config.eagle_choices is not None:
5860
self.spec_tree_manager = SpecTreeManager(
5961
max_num_requests=self.max_num_requests,
6062
use_dynamic_tree=config.use_dynamic_tree,

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ def needs_kv_cache_rewind(self):
6767
) or self.is_ngram()
6868

6969
def support_overlap_scheduler(self):
70+
# TODO: fix accuracy issue
71+
if self.is_mtp_eagle():
72+
return False
73+
7074
return self.is_mtp_one_model() or self.is_eagle3_one_model(
7175
) or self.has_draft_model()
7276

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,20 @@ def get_spec_metadata(spec_config,
2828
max_num_requests=max_num_requests,
2929
mtp_hidden_states_manager=spec_resource_manager,
3030
)
31+
if spec_config.spec_dec_mode.is_mtp_eagle():
32+
return Eagle3SpecMetadata(
33+
max_draft_len=spec_config.max_draft_len,
34+
spec_dec_mode=spec_config.spec_dec_mode,
35+
max_num_requests=max_num_requests,
36+
num_layers=model_config.num_hidden_layers,
37+
hidden_size=model_config.hidden_size,
38+
max_num_tokens=max_num_tokens,
39+
dtype=model_config.torch_dtype,
40+
is_draft_model=is_draft_model,
41+
eagle3_resource_manager=spec_resource_manager,
42+
layers_to_capture=None,
43+
is_mtp_eagle=True,
44+
)
3145
if spec_config.spec_dec_mode.is_eagle3():
3246
return Eagle3SpecMetadata(
3347
max_draft_len=spec_config.max_draft_len,

tests/integration/test_lists/test-db/l0_b200.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ l0_b200:
5555
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
5656
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8]
5757
- test_e2e.py::test_ptp_quickstart_advanced_mtp[DeepSeek-V3-Lite-BF16-DeepSeek-V3-Lite/bf16]
58+
- test_e2e.py::test_ptp_quickstart_advanced_mtp_eagle[DeepSeek-V3-Lite-BF16-DeepSeek-V3-Lite/bf16]
5859
- test_e2e.py::test_ptp_quickstart_advanced_mixed_precision
5960
- test_e2e.py::test_ptp_quickstart_advanced_eagle3[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B]
6061
- test_e2e.py::test_ptp_quickstart_advanced_ngram[Llama-3.1-8B-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct]

0 commit comments

Comments (0)