Commit f2657c1

[None][fix] Eagle: Attention DP (#7939)
Signed-off-by: Izzy Putterman <[email protected]>
Parent: 3492391

2 files changed, +47 -33 lines

tensorrt_llm/_torch/models/modeling_speculative.py (20 additions, 9 deletions)
@@ -48,14 +48,16 @@ def __init__(
         )
 
         tp_size = model_config.mapping.tp_size
+        if model_config.mapping.enable_attention_dp:
+            tp_size = 1
         # Override the QKV projection. The number of input features
         # is twice as big for EAGLE3 draft models.
         self.qkv_proj = Linear(
             2 * self.hidden_size,
             tp_size * self.q_size + 2 * tp_size * self.kv_size,
             bias=config.attention_bias,
             dtype=config.torch_dtype,
-            mapping=model_config.mapping,
+            mapping=self.qkv_proj.mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             weights_loading_config=WeightsLoadingConfig(
                 weight_mode=WeightMode.FUSED_QKV_LINEAR),
@@ -89,6 +91,8 @@ def __init__(
             bias=getattr(config, "mlp_bias", False),
             dtype=config.torch_dtype,
             config=model_config,
+            overridden_tp_size=1
+            if model_config.mapping.enable_attention_dp else None,
         )
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
@@ -182,14 +186,21 @@ def __init__(
                                     requires_grad=False)
 
         if self.hidden_size_in != config.hidden_size:
-            self.embed_tokens = Embedding(
-                config.vocab_size,
-                config.hidden_size,
-                dtype=config.torch_dtype,
-                mapping=model_config.mapping,
-                tensor_parallel_mode=TensorParallelMode.COLUMN,
-                gather_output=True,
-            )
+            if model_config.mapping.enable_attention_dp:
+                self.embed_tokens = Embedding(
+                    config.vocab_size,
+                    config.hidden_size,
+                    dtype=config.torch_dtype,
+                )
+            else:
+                self.embed_tokens = Embedding(
+                    config.vocab_size,
+                    config.hidden_size,
+                    dtype=config.torch_dtype,
+                    mapping=model_config.mapping,
+                    tensor_parallel_mode=TensorParallelMode.COLUMN,
+                    gather_output=True,
+                )
         else:
             # Shared with target model.
             self.embed_tokens = None
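
The change follows one rule in all three places: when attention data parallelism is enabled, the EAGLE3 draft layers are built unsharded. tp_size is forced to 1 for the fused QKV projection, the draft MLP receives overridden_tp_size=1, and the hidden-size-adapting embedding is created without a TP mapping. A minimal sketch of that sizing rule, using a stand-in Mapping dataclass rather than the real tensorrt_llm classes:

    # Minimal sketch, not the TensorRT-LLM API: Mapping is a stand-in dataclass.
    # It only captures the rule the commit applies when building the EAGLE3
    # draft layers: under attention DP every rank keeps full (unsharded)
    # weights, so the layers are sized as if tp_size == 1.
    from dataclasses import dataclass

    @dataclass
    class Mapping:
        tp_size: int = 1
        enable_attention_dp: bool = False

    def effective_tp_size(mapping: Mapping) -> int:
        """TP degree used to size the draft QKV projection, MLP, and embedding."""
        return 1 if mapping.enable_attention_dp else mapping.tp_size

    # Plain tensor parallelism: weights are sharded across 4 ranks.
    assert effective_tp_size(Mapping(tp_size=4)) == 4
    # Attention DP on the same mapping: each rank builds full-size layers.
    assert effective_tp_size(Mapping(tp_size=4, enable_attention_dp=True)) == 1

The embedding branch drops mapping, tensor_parallel_mode, and gather_output for the same reason: with attention DP there is no column-sharded output to gather across ranks.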

tests/unittest/_torch/speculative/test_eagle3.py (27 additions, 24 deletions)
@@ -24,38 +24,40 @@ def enforce_single_worker(monkeypatch):
 
 
 @pytest.mark.parametrize(
-    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch",
+    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp",
     [
-        [True, "TRTLLM", True, False, False, False, True, False],
-        [True, "TRTLLM", True, False, False, False, False, False],
-        [False, "TRTLLM", True, False, False, False, True, False],
-        [False, "TRTLLM", True, False, False, False, False, False],
-        [True, "FLASHINFER", True, False, False, False, True, False],
-        [False, "FLASHINFER", True, False, False, False, True, False],
-        [False, "TRTLLM", False, True, True, False, True, False],
-        [True, "TRTLLM", False, True, True, False, True, False],
-        [True, "TRTLLM", True, False, True, True, True, False],
-        [True, "TRTLLM", True, False, True, False, True, False],
+        [True, "TRTLLM", True, False, False, False, True, False, False],
+        [True, "TRTLLM", True, False, False, False, False, False, False],
+        [False, "TRTLLM", True, False, False, False, True, False, False],
+        [False, "TRTLLM", True, False, False, False, False, False, False],
+        [True, "FLASHINFER", True, False, False, False, True, False, False],
+        [False, "FLASHINFER", True, False, False, False, True, False, False],
+        [False, "TRTLLM", False, True, True, False, True, False, False],
+        [True, "TRTLLM", False, True, True, False, True, False, False],
+        [True, "TRTLLM", True, False, True, True, True, False, False],
+        [True, "TRTLLM", True, False, True, False, True, False, False],
         # TODO: nvbugs/5461761
         # [True, "TRTLLM", True, False, False, True, True, False],
-        [True, "TRTLLM", False, False, False, False, True, False],
-        [False, "TRTLLM", False, False, False, False, True, False],
-        [True, "TRTLLM", False, False, False, False, False, True],
-        [False, "TRTLLM", False, False, False, False, False, True],
-        [True, "TRTLLM", False, False, False, False, True, True],
-        [False, "TRTLLM", False, False, False, False, True, True],
-        [True, "TRTLLM", False, False, False, False, False, False],
-        [False, "TRTLLM", False, False, False, False, False, False],
-        [True, "TRTLLM", False, False, False, True, True, False],
-        [True, "TRTLLM", False, False, False, True, False, False],
-        [True, "FLASHINFER", False, False, False, False, True, False],
-        [False, "FLASHINFER", False, False, False, False, True, False],
+        [True, "TRTLLM", False, False, False, False, True, False, False],
+        [False, "TRTLLM", False, False, False, False, True, False, False],
+        [True, "TRTLLM", False, False, False, False, False, True, False],
+        [True, "TRTLLM", False, False, False, False, False, True, True],
+        [False, "TRTLLM", False, False, False, False, False, True, False],
+        [True, "TRTLLM", False, False, False, False, True, True, False],
+        [False, "TRTLLM", False, False, False, False, True, True, False],
+        [True, "TRTLLM", False, False, False, False, False, False, False],
+        [False, "TRTLLM", False, False, False, False, False, False, False],
+        [True, "TRTLLM", False, False, False, True, True, False, False],
+        [True, "TRTLLM", False, False, False, True, False, False, False],
+        [True, "FLASHINFER", False, False, False, False, True, False, False],
+        [False, "FLASHINFER", False, False, False, False, True, False, False],
     ])
 @pytest.mark.high_cuda_memory
 def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
                       disable_overlap_scheduler: bool, enable_block_reuse: bool,
                       use_one_model: bool, enable_chunked_prefill: bool,
-                      use_chain_drafter: bool, multi_batch: bool, request):
+                      use_chain_drafter: bool, multi_batch: bool,
+                      attention_dp: bool, request):
     # Use enforce_single_worker fixture only when use_chain_drafter is False.
     # Otherwise, we can't modify the returned value of _get_allow_chain_drafter in multiprocessing.
     if not use_chain_drafter:
@@ -98,6 +100,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
         cuda_graph_config=cuda_graph_config,
         max_batch_size=max_batch_size,
         kv_cache_config=kv_cache_config,
+        enable_attention_dp=attention_dp,
         # This max_seq_len is larger than the one specified
         # in the llama 3 8B eagle's config. We want to make sure
         # that the draft model won't go above its max in warmup