Commit f9fc02b

mtp round1 optimizations without AR+norm fusion in mtp head
Signed-off-by: Amey Naik <[email protected]>
1 parent: 6304866

3 files changed: +214 −55 lines

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 81 additions & 20 deletions
@@ -136,23 +136,31 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                             eps=config.rms_norm_eps,
                             dtype=config.torch_dtype)
 
+    @torch.compile(mode="max-autotune-no-cudagraphs")
+    def get_last_token_states(self, hidden_states, attn_metadata):
+        last_tokens = torch.cumsum(
+            attn_metadata.seq_lens_cuda,
+            dim=0,
+            dtype=torch.long,
+        ) - 1
+        return hidden_states[last_tokens]
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 lm_head: Linear,
                 attn_metadata: AttentionMetadata,
                 return_context_logits: bool = False) -> torch.Tensor:
         if not return_context_logits:
             if attn_metadata is not None:
-                last_tokens = torch.cumsum(
-                    attn_metadata.seq_lens_cuda,
-                    dim=0,
-                    dtype=torch.long,
-                ) - 1
-                hidden_states = hidden_states[last_tokens]
+                hidden_states = self.get_last_token_states(
+                    hidden_states, attn_metadata)
             else:
                 hidden_states = hidden_states[-1].unsqueeze(0)
 
+        lm_head.gather_output = False
         logits = lm_head(hidden_states)
+        # print("AMEYN: inside DeepseekV3MTPHead lm_head logits.shape:", logits.shape)
+        lm_head.gather_output = True
         return logits
 
 
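The refactor above moves the last-token gather into its own method so that @torch.compile can fuse the cumsum index computation and the gather into a single compiled region instead of re-tracing them inside forward. The index math: for packed sequences, the cumulative sum of the sequence lengths gives each sequence's exclusive end offset, and subtracting one lands on its final token. A minimal standalone sketch of that computation (gather_last_token_states and seq_lens are illustrative stand-ins, not names from the commit):

    import torch

    def gather_last_token_states(hidden_states: torch.Tensor,
                                 seq_lens: torch.Tensor) -> torch.Tensor:
        # seq_lens = [3, 2, 4] -> cumsum = [3, 5, 9] -> last rows = [2, 4, 8]
        last_tokens = torch.cumsum(seq_lens, dim=0, dtype=torch.long) - 1
        return hidden_states[last_tokens]

    hidden = torch.arange(9, dtype=torch.float32).unsqueeze(-1)  # [9 tokens, 1]
    lens = torch.tensor([3, 2, 4])
    print(gather_last_token_states(hidden, lens).squeeze(-1))  # tensor([2., 4., 8.])

The surrounding lm_head.gather_output toggle presumably keeps each rank's vocabulary shard of the logits local instead of all-gathering them; that reading is inferred from the diff, not stated in the commit.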
@@ -913,22 +921,46 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         self.num_shared_experts = config.n_shared_experts
         self.top_k = config.num_experts_per_tok
 
+        self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
+        self.event_dict = {
+            key: torch.cuda.Event()
+            for key in [EventType.Main, EventType.MoeShared]
+        }
+
         self.enorm = RMSNorm(hidden_size=config.hidden_size,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)
 
         self.hnorm = RMSNorm(hidden_size=config.hidden_size,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)
+        self.fuse_norm_ar = False  # FIXME: AMEYN
+        if self.fuse_norm_ar:
+            self.eh_proj = Linear(
+                config.hidden_size * 2,
+                config.hidden_size,
+                bias=False,
+                dtype=config.torch_dtype,
+                tensor_parallel_mode=TensorParallelMode.ROW,
+                mapping=model_config.mapping,
+                reduce_output=False,
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+            )
+        else:
+            self.eh_proj = Linear(
+                config.hidden_size * 2,
+                config.hidden_size,
+                bias=False,
+                dtype=config.torch_dtype,
+                tensor_parallel_mode=TensorParallelMode.ROW,
+                mapping=model_config.mapping,
+                reduce_output=True,
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+            )
 
-        self.eh_proj = Linear(
-            config.hidden_size * 2,
-            config.hidden_size,
-            bias=False,
-            dtype=config.torch_dtype,
-            skip_create_weights_in_init=model_config.
-            skip_create_weights_in_init,
-        )
+        # Print shared head initialization message only for rank 0
 
         self.shared_head = DeepseekV3MTPHead(model_config)
 
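Both branches now construct eh_proj as a row-tensor-parallel Linear; they differ only in who performs the all-reduce. With reduce_output=True the layer reduces its own output; with reduce_output=False it leaves a per-rank partial sum so the reduction can be fused into the following RMSNorm. Row parallelism splits the contraction (input) dimension, which is why the forward pass in the next hunk chunks the concatenated hidden states column-wise per rank. A self-contained sketch of that arithmetic, simulating two ranks with plain matmuls (no TensorRT-LLM APIs):

    import torch

    torch.manual_seed(0)
    hidden = 8
    x = torch.randn(4, hidden * 2)       # eh_proj input: [tokens, 2 * hidden]
    w = torch.randn(hidden * 2, hidden)  # full eh_proj weight (bias=False)
    full = x @ w                         # serial reference

    tp_size = 2
    x_shards = torch.chunk(x, tp_size, dim=-1)  # per-rank input slice
    w_shards = torch.chunk(w, tp_size, dim=0)   # matching weight shard

    partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]
    reduced = sum(partials)              # the all-reduce, explicit or fused
    torch.testing.assert_close(reduced, full)

Summing the per-rank partials reproduces the serial product exactly; that is the invariant both reduce_output settings rely on.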
@@ -944,14 +976,41 @@ def forward(
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
-        inputs_embeds = self.enorm(embed_tokens(input_ids))
-        hidden_states = self.hnorm(hidden_states)
+        def norm_embeds():
+            return self.enorm(embed_tokens(input_ids))  # embedding
+
+        def norm_hidden():
+            return self.hnorm(hidden_states)
+
+        inputs_embeds, hidden_states = maybe_execute_in_parallel(
+            norm_embeds,
+            norm_hidden,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.MoeShared],
+            self.aux_stream,
+        )
         hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1)
+        # Split hidden_states column-wise based on TP
+        tp_size = self.model_config.mapping.tp_size
+        tp_rank = self.model_config.mapping.tp_rank
+        if tp_size > 1:
+            hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank]
         hidden_states = self.eh_proj(hidden_states)
 
         # Input layer norm
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if self.fuse_norm_ar:
+            hidden_states, residual = self.allreduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                    residual=torch.zeros_like(hidden_states),
+                    norm_weight=self.input_layernorm.weight,
+                    eps=self.input_layernorm.variance_epsilon,
+                ),
+            )
+        else:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
         hidden_states = self.self_attn(
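norm_embeds and norm_hidden have no data dependency, so maybe_execute_in_parallel can overlap the two RMSNorms on the main and aux CUDA streams, synchronized through the two recorded events. A sketch of the event choreography such a helper plausibly performs (illustrative only; the actual TensorRT-LLM utility may differ):

    import torch

    def run_in_parallel(fn_main, fn_aux, event_main: torch.cuda.Event,
                        event_aux: torch.cuda.Event,
                        aux_stream: torch.cuda.Stream):
        event_main.record()              # fence: work queued so far on main stream
        with torch.cuda.stream(aux_stream):
            event_main.wait()            # aux stream starts after that fence
            out_aux = fn_aux()
            event_aux.record()
        out_main = fn_main()             # overlaps with fn_aux on the main stream
        event_aux.wait()                 # main stream may now consume out_aux
        return out_main, out_aux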
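On the fused path (still disabled by the fuse_norm_ar = False FIXME), the all-reduce of the partial eh_proj outputs, the residual add, and the input RMSNorm collapse into a single AllReduceFusionOp.RESIDUAL_RMS_NORM kernel. Passing torch.zeros_like(hidden_states) as the residual makes the fused add a no-op here, so the op returns the normed states plus the pre-norm sum as the new residual, matching the unfused else branch. An unfused PyTorch reference for the math that fusion op computes (a sketch of the semantics, not the kernel):

    import torch

    def residual_rms_norm_reference(partials, residual, weight, eps):
        reduced = sum(partials)          # the all-reduce over TP ranks
        pre_norm = reduced + residual    # residual add (zeros in this commit)
        rms = torch.rsqrt(pre_norm.pow(2).mean(-1, keepdim=True) + eps)
        return pre_norm * rms * weight, pre_norm  # (normed states, new residual)

    h = 8
    parts = [torch.randn(4, h) for _ in range(2)]  # per-rank eh_proj partials
    normed, residual = residual_rms_norm_reference(
        parts, torch.zeros(4, h), torch.ones(h), eps=1e-6)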
@@ -1084,7 +1143,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                                           self.model.aux_stream_dict)
             self.model.layers.append(mtp_layer)
             self.epilogue.append(mtp_layer)
-            self.mtp_worker = MTPEagleWorker(model_config.spec_config)
+            self.mtp_worker = MTPEagleWorker(model_config.spec_config,
+                                             model_config)
         else:
             mtp_layers = nn.ModuleList([
                 DeepseekV3MTP(model_config,
@@ -1094,7 +1154,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
             ])
             self.model.layers.extend(mtp_layers)
             self.epilogue.extend(mtp_layers)
-            self.mtp_worker = MTPWorker(model_config.spec_config)
+            self.mtp_worker = MTPWorker(model_config.spec_config,
+                                        model_config)
             # modify the QuantConfig to support duplicated mtp layers
             if model_config.quant_config.exclude_modules is not None:
                 extend_exclude_modules = []
