
Commit 4c788d1

ameynaik-hub authored and Kefeng-Duan committed
Mtp optimizations round1 (NVIDIA#5689)
Signed-off-by: Amey Naik <[email protected]>
Co-authored-by: Kefeng-Duan <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
1 parent 07844ad commit 4c788d1

4 files changed: 191 additions & 56 deletions


tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 66 additions & 19 deletions
@@ -131,28 +131,38 @@ class DeepseekV3MTPHead(nn.Module):
     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
         super().__init__()
         config = model_config.pretrained_config
+        self.model_config = model_config

         self.norm = RMSNorm(hidden_size=config.hidden_size,
                             eps=config.rms_norm_eps,
                             dtype=config.torch_dtype)

+    @torch.compile(options={"max-autotune": True})
+    def get_last_token_states(self, hidden_states, attn_metadata):
+        last_tokens = torch.cumsum(
+            attn_metadata.seq_lens_cuda,
+            dim=0,
+            dtype=torch.long,
+        ) - 1
+        return hidden_states[last_tokens]
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 lm_head: Linear,
                 attn_metadata: AttentionMetadata,
                 return_context_logits: bool = False) -> torch.Tensor:
         if not return_context_logits:
             if attn_metadata is not None:
-                last_tokens = torch.cumsum(
-                    attn_metadata.seq_lens_cuda,
-                    dim=0,
-                    dtype=torch.long,
-                ) - 1
-                hidden_states = hidden_states[last_tokens]
+                hidden_states = self.get_last_token_states(
+                    hidden_states, attn_metadata)
             else:
                 hidden_states = hidden_states[-1].unsqueeze(0)

+        if not (self.model_config.mapping.enable_attention_dp):
+            lm_head.gather_output = False
         logits = lm_head(hidden_states)
+        if not (self.model_config.mapping.enable_attention_dp):
+            lm_head.gather_output = True
         return logits
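The new `get_last_token_states` helper gathers the final hidden state of each packed sequence: `torch.cumsum(seq_lens) - 1` yields the flat index of every sequence's last token, and decorating the helper with `torch.compile(options={"max-autotune": True})` asks the compiler to fuse and autotune this small cumsum-plus-gather instead of running it eagerly. Below is a minimal sketch of the indexing math only, with made-up sequence lengths and CPU tensors standing in for `attn_metadata.seq_lens_cuda` and the real hidden states:

    import torch

    # Toy packed batch: three sequences of lengths 2, 3 and 1 -> 6 tokens, hidden size 4.
    seq_lens = torch.tensor([2, 3, 1])
    hidden_states = torch.arange(6 * 4, dtype=torch.float32).reshape(6, 4)

    # Flat index of each sequence's last token in the packed token dimension.
    last_tokens = torch.cumsum(seq_lens, dim=0, dtype=torch.long) - 1  # tensor([1, 4, 5])

    # Only these rows reach lm_head when return_context_logits is False.
    last_hidden = hidden_states[last_tokens]
    print(last_tokens.tolist(), last_hidden.shape)  # [1, 4, 5] torch.Size([3, 4])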

@@ -903,22 +913,40 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         self.num_shared_experts = config.n_shared_experts
         self.top_k = config.num_experts_per_tok

+        self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
+        self.event_dict = {
+            key: torch.cuda.Event()
+            for key in [EventType.Main, EventType.MoeShared]
+        }
+
         self.enorm = RMSNorm(hidden_size=config.hidden_size,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)

         self.hnorm = RMSNorm(hidden_size=config.hidden_size,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)
-
-        self.eh_proj = Linear(
-            config.hidden_size * 2,
-            config.hidden_size,
-            bias=False,
-            dtype=config.torch_dtype,
-            skip_create_weights_in_init=model_config.
-            skip_create_weights_in_init,
-        )
+        if model_config.mapping.enable_attention_dp:
+            self.eh_proj = Linear(
+                config.hidden_size * 2,
+                config.hidden_size,
+                bias=False,
+                dtype=config.torch_dtype,
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+            )
+        else:
+            self.eh_proj = Linear(
+                config.hidden_size * 2,
+                config.hidden_size,
+                bias=False,
+                dtype=config.torch_dtype,
+                tensor_parallel_mode=TensorParallelMode.ROW,
+                mapping=model_config.mapping,
+                reduce_output=True,
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+            )

         self.shared_head = DeepseekV3MTPHead(model_config)
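When attention data parallelism is disabled, `eh_proj` is now built as a row-parallel `Linear` (`tensor_parallel_mode=TensorParallelMode.ROW`, `reduce_output=True`), and the forward hunk below feeds it only the local column chunk of the concatenated `[embeds, hidden]` activation. The single-process sketch below checks the underlying identity, splitting the activation columns and the weight rows per rank and summing the partial products; the names and shapes are illustrative, not the TensorRT-LLM `Linear` implementation:

    import torch

    torch.manual_seed(0)
    tp_size, hidden = 2, 8

    w_full = torch.randn(2 * hidden, hidden)   # eh_proj weight: 2*hidden -> hidden
    x_full = torch.randn(3, 2 * hidden)        # 3 tokens of concatenated [embeds, hidden]

    reference = x_full @ w_full                # what a single-GPU eh_proj would compute

    # Row-parallel layout: rank r owns a row block of the weight and the matching
    # column chunk of the activation (torch.chunk(x, tp_size, dim=-1)[r]).
    partials = [
        torch.chunk(x_full, tp_size, dim=-1)[r] @ torch.chunk(w_full, tp_size, dim=0)[r]
        for r in range(tp_size)
    ]

    # reduce_output=True corresponds to all-reducing (summing) the partial products.
    assert torch.allclose(reference, torch.stack(partials).sum(dim=0), atol=1e-5)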

@@ -934,9 +962,26 @@ def forward(
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:

-        inputs_embeds = self.enorm(embed_tokens(input_ids))
-        hidden_states = self.hnorm(hidden_states)
+        def norm_embeds():
+            return self.enorm(embed_tokens(input_ids))  # embedding
+
+        def norm_hidden():
+            return self.hnorm(hidden_states)
+
+        inputs_embeds, hidden_states = maybe_execute_in_parallel(
+            norm_embeds,
+            norm_hidden,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.MoeShared],
+            self.aux_stream,
+        )
         hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1)
+        # Split hidden_states columnwise based on TP
+        tp_size = self.model_config.mapping.tp_size
+        tp_rank = self.model_config.mapping.tp_rank
+
+        if tp_size > 1 and not (self.model_config.mapping.enable_attention_dp):
+            hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank]
         hidden_states = self.eh_proj(hidden_states)

         # Input layer norm
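`enorm(embed_tokens(input_ids))` and `hnorm(hidden_states)` have no data dependence on each other, so the forward now wraps them in closures and hands them to `maybe_execute_in_parallel` together with the `MoeShared` aux stream and the two CUDA events created in `__init__`, letting the two RMSNorms overlap instead of running back to back. The helper's internals are not shown in this diff; the sketch below is only an assumption of how such a two-stream overlap is typically wired (hypothetical `run_in_parallel`, not the project's `maybe_execute_in_parallel`):

    import torch

    def run_in_parallel(fn_main, fn_aux, event_main, event_aux, aux_stream):
        # Make the aux stream wait for work already queued on the current stream,
        # run fn_aux there while fn_main runs on the current stream, then make the
        # current stream wait for the aux result before anyone consumes it.
        event_main.record()
        with torch.cuda.stream(aux_stream):
            event_main.wait()
            out_aux = fn_aux()
            event_aux.record()
        out_main = fn_main()
        event_aux.wait()
        return out_main, out_aux

    if torch.cuda.is_available():
        x = torch.randn(4096, 4096, device="cuda")
        outs = run_in_parallel(lambda: x.relu(), lambda: x.tanh(),
                               torch.cuda.Event(), torch.cuda.Event(),
                               torch.cuda.Stream())
        torch.cuda.synchronize()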
@@ -1074,7 +1119,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                                       self.model.aux_stream_dict)
             self.model.layers.append(mtp_layer)
             self.epilogue.append(mtp_layer)
-            self.mtp_worker = MTPEagleWorker(model_config.spec_config)
+            self.mtp_worker = MTPEagleWorker(model_config.spec_config,
+                                             model_config)
         else:
             mtp_layers = nn.ModuleList([
                 DeepseekV3MTP(model_config,
@@ -1084,7 +1130,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
             ])
             self.model.layers.extend(mtp_layers)
             self.epilogue.extend(mtp_layers)
-            self.mtp_worker = MTPWorker(model_config.spec_config)
+            self.mtp_worker = MTPWorker(model_config.spec_config,
+                                        model_config)
             # modify the QuantConfig to support duplicated mtp layers
             if model_config.quant_config.exclude_modules is not None:
                 extend_exclude_modules = []

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 1 addition & 0 deletions
@@ -359,6 +359,7 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):

         self.draft_model = get_draft_model(model_config, draft_config)
         self.spec_worker = get_spec_worker(model_config.spec_config,
+                                           model_config,
                                            model_config.mapping)

     def forward(
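This change and the worker-construction hunks above are the same plumbing: `MTPWorker`, `MTPEagleWorker`, and the `get_spec_worker` factory now receive the full `model_config` in addition to `spec_config`, so a worker can consult the parallel mapping (for example `enable_attention_dp` or `tp_size`) rather than only the speculative-decoding settings. A minimal, purely hypothetical illustration of that constructor shape, with stand-in classes rather than the real workers:

    from dataclasses import dataclass, field

    @dataclass
    class Mapping:                       # stand-in for the parallel mapping
        tp_size: int = 1
        enable_attention_dp: bool = False

    @dataclass
    class SpecConfig:                    # stand-in for the speculative-decoding config
        num_draft_layers: int = 1

    @dataclass
    class ModelConfigSketch:             # stand-in for ModelConfig
        spec_config: SpecConfig = field(default_factory=SpecConfig)
        mapping: Mapping = field(default_factory=Mapping)

    class MTPWorkerSketch:
        # Keeping model_config around lets the worker branch on mapping flags
        # (as the MTP head above does) without threading extra arguments later.
        def __init__(self, spec_config: SpecConfig, model_config: ModelConfigSketch):
            self.spec_config = spec_config
            self.model_config = model_config

    cfg = ModelConfigSketch()
    worker = MTPWorkerSketch(cfg.spec_config, cfg)
    print(worker.model_config.mapping.enable_attention_dp)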
