
Commit 2bcb3ac (1 parent: 1191555)

mtp round1 optimizations without AR+norm fusion in mtp head

Signed-off-by: Amey Naik <[email protected]>

File tree: 3 files changed (+214, -55 lines)


tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 81 additions & 20 deletions
@@ -136,23 +136,31 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                               eps=config.rms_norm_eps,
                               dtype=config.torch_dtype)
 
+    @torch.compile(mode="max-autotune-no-cudagraphs")
+    def get_last_token_states(self, hidden_states, attn_metadata):
+        last_tokens = torch.cumsum(
+            attn_metadata.seq_lens_cuda,
+            dim=0,
+            dtype=torch.long,
+        ) - 1
+        return hidden_states[last_tokens]
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 lm_head: Linear,
                 attn_metadata: AttentionMetadata,
                 return_context_logits: bool = False) -> torch.Tensor:
         if not return_context_logits:
             if attn_metadata is not None:
-                last_tokens = torch.cumsum(
-                    attn_metadata.seq_lens_cuda,
-                    dim=0,
-                    dtype=torch.long,
-                ) - 1
-                hidden_states = hidden_states[last_tokens]
+                hidden_states = self.get_last_token_states(
+                    hidden_states, attn_metadata)
             else:
                 hidden_states = hidden_states[-1].unsqueeze(0)
 
+        lm_head.gather_output = False
         logits = lm_head(hidden_states)
+        # print("AMEYN: inside DeepseekV3MTPHead lm_head logits.shape:", logits.shape)
+        lm_head.gather_output = True
         return logits
 
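For context on get_last_token_states: the runtime packs every request's tokens into one dimension, so cumsum(seq_lens) - 1 indexes each request's final token; wrapping the gather in torch.compile with mode="max-autotune-no-cudagraphs" autotunes it without capturing CUDA graphs, presumably so it composes with the runtime's own graph capture. A minimal CPU sketch of the indexing, with illustrative names only:

    import torch

    hidden_states = torch.arange(9.0).unsqueeze(-1)  # 9 packed tokens from 3 requests
    seq_lens = torch.tensor([4, 2, 3])               # per-request lengths, back to back
    last_tokens = torch.cumsum(seq_lens, dim=0, dtype=torch.long) - 1
    print(last_tokens)                               # tensor([3, 5, 8])
    print(hidden_states[last_tokens].squeeze(-1))    # tensor([3., 5., 8.])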

@@ -911,22 +919,46 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         self.num_shared_experts = config.n_shared_experts
         self.top_k = config.num_experts_per_tok
 
+        self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
+        self.event_dict = {
+            key: torch.cuda.Event()
+            for key in [EventType.Main, EventType.MoeShared]
+        }
+
         self.enorm = RMSNorm(hidden_size=config.hidden_size,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)
 
         self.hnorm = RMSNorm(hidden_size=config.hidden_size,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)
+        self.fuse_norm_ar = False  # FIXME: AMEYN
+        if self.fuse_norm_ar:
+            self.eh_proj = Linear(
+                config.hidden_size * 2,
+                config.hidden_size,
+                bias=False,
+                dtype=config.torch_dtype,
+                tensor_parallel_mode=TensorParallelMode.ROW,
+                mapping=model_config.mapping,
+                reduce_output=False,
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+            )
+        else:
+            self.eh_proj = Linear(
+                config.hidden_size * 2,
+                config.hidden_size,
+                bias=False,
+                dtype=config.torch_dtype,
+                tensor_parallel_mode=TensorParallelMode.ROW,
+                mapping=model_config.mapping,
+                reduce_output=True,
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+            )
 
-        self.eh_proj = Linear(
-            config.hidden_size * 2,
-            config.hidden_size,
-            bias=False,
-            dtype=config.torch_dtype,
-            skip_create_weights_in_init=model_config.
-            skip_create_weights_in_init,
-        )
+        # Print shared head initialization message only for rank 0
 
         self.shared_head = DeepseekV3MTPHead(model_config)
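The reworked eh_proj is row-parallel (TensorParallelMode.ROW): each TP rank multiplies a column slice of the activation by a row slice of the weight and gets a partial product, and summing the partials across ranks (the all-reduce that reduce_output=True performs, and that reduce_output=False defers into a fused norm) recovers the full matmul. A single-device sketch of that identity, with illustrative shapes:

    import torch

    torch.manual_seed(0)
    x = torch.randn(5, 8)               # [tokens, 2 * hidden], e.g. concat(embeds, hidden)
    w = torch.randn(8, 4)               # [2 * hidden, hidden] projection weight
    x0, x1 = torch.chunk(x, 2, dim=-1)  # each "rank" takes a column slice of x
    w0, w1 = torch.chunk(w, 2, dim=0)   # ... and the matching row slice of w
    partials = [x0 @ w0, x1 @ w1]       # one partial output per rank
    assert torch.allclose(sum(partials), x @ w, atol=1e-5)  # all-reduce restores x @ w

This is also why the forward pass below chunks the concatenated hidden states column-wise by tp_rank before calling eh_proj.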

@@ -942,14 +974,41 @@ def forward(
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
-        inputs_embeds = self.enorm(embed_tokens(input_ids))
-        hidden_states = self.hnorm(hidden_states)
+        def norm_embeds():
+            return self.enorm(embed_tokens(input_ids))  # embedding
+
+        def norm_hidden():
+            return self.hnorm(hidden_states)
+
+        inputs_embeds, hidden_states = maybe_execute_in_parallel(
+            norm_embeds,
+            norm_hidden,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.MoeShared],
+            self.aux_stream,
+        )
         hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1)
+        # Split hidden_states columnwise based on TP
+        tp_size = self.model_config.mapping.tp_size
+        tp_rank = self.model_config.mapping.tp_rank
+        if tp_size > 1:
+            hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank]
         hidden_states = self.eh_proj(hidden_states)
 
         # Input layer norm
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if self.fuse_norm_ar:
+            hidden_states, residual = self.allreduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                    residual=torch.zeros_like(hidden_states),
+                    norm_weight=self.input_layernorm.weight,
+                    eps=self.input_layernorm.variance_epsilon,
+                ),
+            )
+        else:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
         hidden_states = self.self_attn(
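maybe_execute_in_parallel pairs the aux stream and the two events created in __init__ to overlap enorm and hnorm, two small kernels that would otherwise serialize. A hedged sketch of the pattern as I read it (assumed semantics, not the TRT-LLM helper itself; needs a CUDA device):

    import torch

    def run_pair_overlapped(fn_main, fn_aux, event_main, event_aux, aux_stream):
        event_main.record()                    # mark prior work on the main stream
        with torch.cuda.stream(aux_stream):
            aux_stream.wait_event(event_main)  # aux must see earlier main-stream writes
            result_aux = fn_aux()              # e.g. hnorm(hidden_states)
            event_aux.record()
        result_main = fn_main()                # e.g. enorm(embed_tokens(...)), main stream
        torch.cuda.current_stream().wait_event(event_aux)  # rejoin before using result_aux
        return result_main, result_aux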
@@ -1082,7 +1141,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                                           self.model.aux_stream_dict)
             self.model.layers.append(mtp_layer)
             self.epilogue.append(mtp_layer)
-            self.mtp_worker = MTPEagleWorker(model_config.spec_config)
+            self.mtp_worker = MTPEagleWorker(model_config.spec_config,
+                                             model_config)
         else:
             mtp_layers = nn.ModuleList([
                 DeepseekV3MTP(model_config,
@@ -1092,7 +1152,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
             ])
             self.model.layers.extend(mtp_layers)
             self.epilogue.extend(mtp_layers)
-            self.mtp_worker = MTPWorker(model_config.spec_config)
+            self.mtp_worker = MTPWorker(model_config.spec_config,
+                                        model_config)
             # modify the QuantConfig to support duplicated mtp layers
             if model_config.quant_config.exclude_modules is not None:
                 extend_exclude_modules = []
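Per the commit title, AR+norm fusion stays disabled in the MTP head (fuse_norm_ar = False above), so eh_proj keeps reduce_output=True and input_layernorm runs unfused. For reference, a single-process sketch of what the RESIDUAL_RMS_NORM fused path would compute, with a per-rank sum standing in for the all-reduce (my reading of the parameters, not the actual kernel):

    import torch

    def residual_rms_norm_reference(partials, norm_weight, eps):
        reduced = torch.stack(partials).sum(dim=0)  # stands in for the TP all-reduce
        residual = reduced                          # residual input is zeros_like in the diff
        variance = reduced.float().pow(2).mean(-1, keepdim=True)
        normed = (reduced.float() * torch.rsqrt(variance + eps)).to(reduced.dtype)
        return normed * norm_weight, residual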
