From b4a468245b8b5ce16e36bd37d30cce47e46b5eea Mon Sep 17 00:00:00 2001 From: qgai Date: Mon, 18 Aug 2025 13:35:19 +0000 Subject: [PATCH 1/2] support mtp eagle with 2 models style Signed-off-by: qgai --- examples/llm-api/quickstart_advanced.py | 9 +- tensorrt_llm/_torch/models/modeling_auto.py | 2 + .../_torch/models/modeling_deepseekv3.py | 2265 +++++++++-------- .../_torch/models/modeling_speculative.py | 120 +- .../_torch/pyexecutor/py_executor_creator.py | 3 + tensorrt_llm/_torch/speculative/eagle3.py | 4 +- tensorrt_llm/_torch/speculative/interface.py | 28 +- .../_torch/speculative/model_drafter.py | 2 +- tensorrt_llm/_torch/speculative/mtp.py | 6 +- tensorrt_llm/_torch/speculative/utils.py | 42 +- tensorrt_llm/llmapi/llm_args.py | 13 +- tests/integration/defs/test_e2e.py | 30 + 12 files changed, 1362 insertions(+), 1162 deletions(-) diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index 61240b496de..14587f9e883 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -170,16 +170,15 @@ def setup_llm(args, **kwargs): if spec_decode_algo == 'MTP': if not args.use_one_model: - print( - "MTP only supports one model style spec decode; ignoring default use_one_model=False" - ) - + print("Running MTP eagle with two model style.") spec_config = MTPDecodingConfig( num_nextn_predict_layers=args.spec_decode_max_draft_len, use_relaxed_acceptance_for_thinking=args. use_relaxed_acceptance_for_thinking, relaxed_topk=args.relaxed_topk, - relaxed_delta=args.relaxed_delta) + relaxed_delta=args.relaxed_delta, + mtp_eagle_one_model=args.use_one_model, + speculative_model_dir=args.model_dir) elif spec_decode_algo == "EAGLE3": spec_config = EagleDecodingConfig( max_draft_len=args.spec_decode_max_draft_len, diff --git a/tensorrt_llm/_torch/models/modeling_auto.py b/tensorrt_llm/_torch/models/modeling_auto.py index cd73919ca2c..5788a9b2a5c 100644 --- a/tensorrt_llm/_torch/models/modeling_auto.py +++ b/tensorrt_llm/_torch/models/modeling_auto.py @@ -31,6 +31,8 @@ def from_config( "") # Strip the appended EAGLE3 if hasattr(config.pretrained_config, "draft_vocab_size"): model_arch = "EAGLE3" + model_arch + if model_arch == "DeepseekV3ForCausalLM" and config.spec_config is not None and config.spec_config.max_draft_len == 0: + model_arch = "MTPDraftModelForCausalLM" cls = MODEL_CLASS_MAPPING.get(model_arch) if cls is None: diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index e8c6106b4b2..ae008978437 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -136,638 +136,681 @@ def moe_reduce_add_shared_output(routed_output, shared_output): return shared_output + routed_output -class DeepseekV3MTPHead(nn.Module): +class DeepseekV3WeightLoader: - def __init__(self, model_config: ModelConfig[PretrainedConfig]): - super().__init__() - config = model_config.pretrained_config - self.model_config = model_config + def __init__(self, model, is_draft_model: bool = False): + self.model = model + self.config = model.config + self.model_config = model.model_config + self.is_draft_model = is_draft_model - self.norm = RMSNorm(hidden_size=config.hidden_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) - if self.model_config.mapping.enable_attention_dp and \ - getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False): - self.mapping_lm_head_tp = create_lm_head_tp_mapping( - 
self.model_config.mapping) - else: - self.mapping_lm_head_tp = self.model_config.mapping + def load_weights(self, weights: Dict): - @torch.compile(options={"max-autotune": True}) - def get_last_token_states(self, hidden_states, attn_metadata): - last_tokens = torch.cumsum( - attn_metadata.seq_lens_cuda, - dim=0, - dtype=torch.long, - ) - 1 - return hidden_states[last_tokens] + def rename_moe_weight(weights: Dict, rename_rules: Dict): + result = {} + for key, value in weights.items(): + new_key = key + for old, new in rename_rules.items(): + new_key = new_key.replace(old, new) + result[new_key] = value + return result - def forward(self, - hidden_states: torch.Tensor, - lm_head: Linear, - attn_metadata: AttentionMetadata, - return_context_logits: bool = False) -> torch.Tensor: - if not return_context_logits: - if attn_metadata is not None: - hidden_states = self.get_last_token_states( - hidden_states, attn_metadata) - else: - hidden_states = hidden_states[-1].unsqueeze(0) + ## Prepare weights for TP + def split(v, tp_size, idx, dim=0): + if tp_size == 1: + return v + if len(v.shape) == 1: + return torch.chunk(v, tp_size)[idx].contiguous() + return torch.chunk(v, tp_size, dim=dim)[idx].contiguous() - enable_attention_dp = self.model_config.mapping.enable_attention_dp - enable_lm_head_tp_in_adp = self.model_config.mapping.enable_lm_head_tp_in_adp + def split_matrix_tp(v, tensor_parallel, rank, dim): + return split(v, tensor_parallel, rank, dim=dim) - # Add pre-lm gather logic - if enable_lm_head_tp_in_adp: - # ADP + LM TP mode: perform All-Gather before LM_head - hidden_states = allgather(hidden_states, - self.mapping_lm_head_tp, - dim=0) + def load_kv_b_proj_and_k_b_proj_trans(module_name: str, + is_scale: bool) -> torch.Tensor: + weight_name = "weight" if not is_scale else "weight_scale_inv" + local_qk_nope_head_dim = qk_nope_head_dim if not is_scale else qk_nope_head_dim // 128 + local_v_head_dim = v_head_dim if not is_scale else v_head_dim // 128 + local_kv_lora_rank = kv_lora_rank if not is_scale else kv_lora_rank // 128 - # Temporarily disable gather_output when not in ADP mode or (in ADP mode and LM TP is enabled) - if not enable_attention_dp or enable_lm_head_tp_in_adp: - lm_head.gather_output = False - logits = lm_head(hidden_states, is_spec_decoding_head=True) - if not enable_attention_dp or enable_lm_head_tp_in_adp: - lm_head.gather_output = True - return logits + kv_b_proj = weights[f"{module_name}.{weight_name}"][:].unflatten( + 0, + [ + num_heads, + local_qk_nope_head_dim + local_v_head_dim, + ], + ) + if not self.model_config.mapping.enable_attention_dp: + kv_b_proj = split_matrix_tp(kv_b_proj, tp_size, tp_rank, 0) + k_nope_weight, v_weight = kv_b_proj.split( + [local_qk_nope_head_dim, local_v_head_dim], + dim=1, + ) + weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size + local_num_heads = num_heads // weight_divisor -class DeepseekV3Linear(Linear): - """ - A wrapper around Linear because we may optionally use min-latency kernels depending on input shapes. 
- """ + k_nope_weight_trans = k_nope_weight.transpose(2, 1).contiguous() - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - mapping: Optional[Mapping] = None, - tensor_parallel_mode: Optional[TensorParallelMode] = None, - gather_output: bool = False, # COLUMN parallel only - quant_config: Optional[QuantConfig] = None, - weights_loading_config: Optional[WeightsLoadingConfig] = None, - reduce_output: bool = True, # ROW parallel only - skip_create_weights_in_init: bool = False, - use_custom_cublas_mm: bool = False, - lora: Optional[LoraLayer] = None, - ): - super().__init__( - in_features, - out_features, - bias, - dtype, - mapping, - tensor_parallel_mode, - gather_output, - quant_config, - weights_loading_config, - reduce_output, - skip_create_weights_in_init, - use_custom_cublas_mm, - lora, - ) + kv_b_proj = torch.concat([ + k_nope_weight.reshape(local_num_heads * local_qk_nope_head_dim, + local_kv_lora_rank), + v_weight.reshape(local_num_heads * local_v_head_dim, + local_kv_lora_rank) + ], + dim=0) - def apply_linear(self, - input, - bias, - lora_params: Optional[dict] | None = None, - layer_idx: Optional[int] | None = None): - num_tokens = input.shape[0] - if (not self.has_any_quant and 1 <= num_tokens <= 16 - and get_sm_version() != 120): - output = torch.ops.trtllm.dsv3_fused_a_gemm_op( - input, self.weight.t(), bias, None) - else: - output = super().apply_linear(input, bias, lora_params, layer_idx) - return output + return kv_b_proj, k_nope_weight_trans + def load_kv_b_proj_and_k_b_proj_trans_dequant( + module_name: str) -> torch.Tensor: + weight_name = "weight" + local_qk_nope_head_dim = qk_nope_head_dim + local_v_head_dim = v_head_dim + local_kv_lora_rank = kv_lora_rank -class DeepseekV3Attention(MLA): + kv_b_proj = weights[f"{module_name}.{weight_name}"][:].cuda() - def __init__( - self, - model_config: ModelConfig[PretrainedConfig], - layer_idx: Optional[int] = None, - aux_stream: Optional[torch.cuda.Stream] = None, - ): - config = model_config.pretrained_config - predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1 - super().__init__(hidden_size=config.hidden_size, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=config.num_key_value_heads, - qk_rope_head_dim=config.qk_rope_head_dim, - qk_nope_head_dim=config.qk_nope_head_dim, - q_lora_rank=config.q_lora_rank, - kv_lora_rank=config.kv_lora_rank, - v_head_dim=config.v_head_dim, - predicted_tokens_per_seq=predicted_tokens_per_seq, - max_position_embeddings=config.max_position_embeddings, - bias=False, - pos_embd_params=PositionalEmbeddingParams( - type=PositionEmbeddingType.yarn, - rope=RopeParams.from_config(config), - is_neox=False, - ), - layer_idx=layer_idx, - dtype=config.torch_dtype, - config=model_config, - aux_stream=aux_stream) - self.kv_a_proj_with_mqa = DeepseekV3Linear( - config.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim + - (self.q_lora_rank if not self.is_lite else 0), - bias=False, - dtype=config.torch_dtype, - quant_config=model_config.get_quant_config(), - skip_create_weights_in_init=model_config. 
- skip_create_weights_in_init, - use_custom_cublas_mm=True) + weight_name = "weight_scale_inv" + kv_b_proj_scale = weights[f"{module_name}.{weight_name}"][:].cuda() + kv_b_proj = weight_dequant(kv_b_proj, kv_b_proj_scale) + kv_b_proj = kv_b_proj.unflatten( + 0, + [ + num_heads, + local_qk_nope_head_dim + local_v_head_dim, + ], + ) + if not self.model_config.mapping.enable_attention_dp: + kv_b_proj = split_matrix_tp(kv_b_proj, tp_size, tp_rank, 0) + k_nope_weight, v_weight = kv_b_proj.split( + [local_qk_nope_head_dim, local_v_head_dim], + dim=1, + ) + weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size + local_num_heads = num_heads // weight_divisor -class Deepseekv3RoutingImpl(): + k_nope_weight_trans = k_nope_weight.transpose(2, 1).contiguous() - def __init__( - self, - top_k: int, - n_group: int, - topk_group: int, - routed_scaling_factor: float, - is_fused: bool = True, - ): - super().__init__() - self.top_k = top_k - self.topk_group = topk_group - self.n_group = n_group - self.routed_scaling_factor = routed_scaling_factor - self.is_fused = is_fused + kv_b_proj = torch.concat([ + k_nope_weight.reshape(local_num_heads * local_qk_nope_head_dim, + local_kv_lora_rank), + v_weight.reshape(local_num_heads * local_v_head_dim, + local_kv_lora_rank) + ], + dim=0) - @torch.compile(options={"max-autotune": True}) - def get_scores(self, logits, e_score_correction_bias): - scores = F.sigmoid(logits) - scores_with_bias = scores + e_score_correction_bias - return scores, scores_with_bias + return kv_b_proj, k_nope_weight_trans - def noaux_tc(self, logits, e_score_correction_bias): - n_group = self.n_group - scores, scores_with_bias = self.get_scores(logits, - e_score_correction_bias) - scores_shape = list(scores_with_bias.shape) + def split_kv_b_proj(kv_b_proj: torch.Tensor, + is_scale: bool) -> torch.Tensor: + local_qk_nope_head_dim = qk_nope_head_dim if not is_scale else qk_nope_head_dim // 128 + local_v_head_dim = v_head_dim if not is_scale else v_head_dim // 128 - if enable_llm_debug(): - has_nan = torch.isnan(scores_with_bias).any() - if has_nan: - warnings.warn( - "Detected NAN in the tensor scores_with_bias. Please check if it matches the expectation." 
- ) + weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size + local_num_heads = num_heads // weight_divisor - if not self.is_fused: - group_scores = torch.sum(torch.topk( - scores_with_bias.view(scores_shape[:-1] + - [n_group, scores_shape[-1] // n_group]), - k=2, - dim=-1, - largest=True, - sorted=True)[0], - dim=-1) - _, group_idx = torch.topk(group_scores, - k=self.topk_group, - dim=-1, - largest=True, - sorted=True) - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(-1, group_idx, 1) - score_mask = group_mask.unsqueeze(-1).expand( - scores_shape[:-1] + - [n_group, scores_shape[-1] // n_group]).reshape(scores_shape) - scores_with_bias = scores_with_bias * score_mask - _, topk_idx = torch.topk(scores_with_bias, - k=self.top_k, - dim=-1, - largest=True, - sorted=True) - new_mask = torch.zeros_like(scores) - new_mask.scatter_(-1, topk_idx, 1) - scores = scores * new_mask - score_sum = torch.sum(scores, dim=-1, keepdim=True) + 1e-20 - scores = scores / score_sum * \ - self.routed_scaling_factor - topk_values, topk_indices = torch.topk(scores, - k=self.top_k, - dim=-1, - largest=True) - return topk_values, topk_indices - else: - topk_values, topk_indices = torch.ops.trtllm.noaux_tc_op( - scores, scores_with_bias, n_group, self.topk_group, self.top_k, - self.routed_scaling_factor) - return topk_values, topk_indices - - def apply( - self, logits: torch.Tensor, e_score_correction_bias: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - topk_values, topk_indices = self.noaux_tc(logits, - e_score_correction_bias) - return topk_indices.to(torch.int32), topk_values.to(torch.float32) + k_b_proj, v_b_proj = kv_b_proj.split([ + local_num_heads * local_qk_nope_head_dim, + local_num_heads * local_v_head_dim + ], + dim=0) + k_b_proj = k_b_proj.view( + [local_num_heads, local_qk_nope_head_dim, -1]) + v_b_proj = v_b_proj.view([local_num_heads, local_v_head_dim, -1]) + return k_b_proj, v_b_proj -class DeepseekV3Gate(DeepSeekV3MoeRoutingMethod): + is_lite = self.config.q_lora_rank is None + num_heads = self.config.num_attention_heads + qk_nope_head_dim = self.config.qk_nope_head_dim + v_head_dim = self.config.v_head_dim + kv_lora_rank = self.config.kv_lora_rank - def __init__( - self, - hidden_size: int, - num_experts: int, - top_k: int, - n_group: int, - topk_group: int, - routed_scaling_factor: float, - dtype: Optional[torch.dtype] = None, - fuse_routing_kernel: bool = True, - apply_routing: bool = False, - moe_backend: str = 'CUTLASS', - ): - super().__init__(top_k=top_k) - self.weight = nn.Parameter(torch.empty((num_experts, hidden_size), - dtype=dtype), - requires_grad=False) - self.moe_backend = moe_backend - if moe_backend == 'TRTLLM': - bias_dtype = torch.bfloat16 - else: - bias_dtype = torch.float32 + tp_rank = self.model_config.mapping.tp_rank + tp_size = self.model_config.mapping.tp_size - self.e_score_correction_bias = nn.Parameter(torch.empty( - (num_experts), dtype=bias_dtype), - requires_grad=False) + params_map = {'gate_up_proj': ['gate_proj', 'up_proj']} + all_named_modules = dict(self.model.named_modules()) - assert not apply_routing, "DeepseekV3Gate routing is called inside MoE" + for name, module in tqdm(all_named_modules.items(), + desc="Loading weights"): + if len(module._parameters) <= 0 or name.startswith("draft_model"): + continue + else: + names = name.split('.') + parent_module_name = '.'.join(names[:-1]) + if "model.layers" in name and int( + names[2]) >= self.config.num_hidden_layers: + mtp_layer_idx = int( + names[2]) - 
self.config.num_hidden_layers + names[2] = str(mtp_layer_idx % + self.config.num_nextn_predict_layers + + self.config.num_hidden_layers) + name = '.'.join(names) + if names[-1] == "kv_b_proj": + # TODO: remove weight_dequant after enabling fp8_bmm + dequant_kv_b_proj = self.model_config.quant_config.is_module_excluded_from_quantization( + names[-1]) + if dequant_kv_b_proj: + kv_b_proj, k_b_proj_trans = load_kv_b_proj_and_k_b_proj_trans_dequant( + name) + else: + kv_b_proj, k_b_proj_trans = load_kv_b_proj_and_k_b_proj_trans( + name, is_scale=False) + module.weight.data.copy_( + kv_b_proj.reshape(module.weight.shape)) - # TODO: e_score_correction_bias belongs in this gate class but is required by the routing impl. - # To avoid weight-loading issues, we treat this gate as the BaseMoeRoutingMethod and dispatch to the routing impl. - # This is a temporary hack that should be refactored later. - self.routing_impl = Deepseekv3RoutingImpl( - top_k=top_k, - n_group=n_group, - topk_group=topk_group, - routed_scaling_factor=routed_scaling_factor, - is_fused=fuse_routing_kernel) + attn_module = all_named_modules[parent_module_name] + _, v_b_proj = split_kv_b_proj(module.weight.data, + is_scale=False) + attn_module.v_b_proj = nn.Parameter(v_b_proj, + requires_grad=False) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - logits = torch.ops.trtllm.dsv3_router_gemm_op(hidden_states, - self.weight.t(), - bias=None, - out_dtype=torch.float32) - return logits + attn_module.k_b_proj_trans.data.copy_( + k_b_proj_trans.reshape( + attn_module.k_b_proj_trans.shape)) - def load_weights(self, weights: List[Dict]): - assert len(weights) == 1 + if getattr(module, "weight_scale", + None) is not None and not dequant_kv_b_proj: + kv_b_proj_scale, k_b_proj_trans_scale = load_kv_b_proj_and_k_b_proj_trans( + name, is_scale=True) + module.weight_scale.copy_( + kv_b_proj_scale.reshape(module.weight_scale.shape)) + attn_module.k_b_proj_trans_scale.copy_( + k_b_proj_trans_scale.reshape( + attn_module.k_b_proj_trans_scale.shape)) - self.weight.copy_(weights[0]["weight"][:]) + _, v_b_proj_scale = split_kv_b_proj( + module.weight_scale.data, is_scale=True) + attn_module.v_b_proj_scale = nn.Parameter( + v_b_proj_scale, requires_grad=False) - self.e_score_correction_bias.copy_( - weights[0]["e_score_correction_bias"][:].to( - self.e_score_correction_bias.dtype)) + if attn_module.k_b_proj_trans_dequant is not None: + attn_module.k_b_proj_trans_dequant.data.copy_( + weight_dequant( + k_b_proj_trans.view( + -1, k_b_proj_trans.shape[-1]).cuda(), + k_b_proj_trans_scale.view( + -1, + k_b_proj_trans_scale.shape[-1]).cuda(), + ).view( + *attn_module.k_b_proj_trans_dequant.shape). 
+ to(attn_module.k_b_proj_trans_dequant.dtype)) + if attn_module.v_b_proj_dequant is not None: + attn_module.v_b_proj_dequant.data.copy_( + weight_dequant( + v_b_proj.view(-1, + v_b_proj.shape[-1]).cuda(), + v_b_proj_scale.view( + -1, v_b_proj_scale.shape[-1]).cuda(), + ).view(*attn_module.v_b_proj_dequant.shape).to( + attn_module.v_b_proj_dequant.dtype)) + elif names[-1] == "kv_a_proj_with_mqa": + fused_a = weights[ + f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:] + if not is_lite: + q_a_proj = weights[ + f"{'.'.join(names[:-1])}.q_a_proj.weight"][:] + fused_a = torch.cat([q_a_proj, fused_a], dim=0) - def apply(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - # topk routing - return self.routing_impl.apply(logits, self.e_score_correction_bias) + if f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv" in weights: + fused_a_scale = weights[ + f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv"] + if not is_lite: + q_a_proj_scale = weights[ + f"{'.'.join(names[:-1])}.q_a_proj.weight_scale_inv"][:] + fused_a_scale = torch.cat( + [q_a_proj_scale, fused_a_scale], dim=0) - @property - def routing_method(self) -> DeepSeekV3MoeRoutingMethod: - return self + module.weight_scale.data.copy_(fused_a_scale) - def get_experts_per_token(self): - return self.routing_impl.top_k + module.weight.data.copy_(fused_a) + elif names[-1] in params_map: + module_weights = [] + for new_name in params_map[names[-1]]: + module_weights.append( + filter_weights('.'.join(names[:-1] + [new_name]), + weights)) + module.load_weights(weights=module_weights) + elif names[-1] == "experts": + module_weights = filter_weights(name, weights) + module_weights = rename_moe_weight(module_weights, { + "down_proj": "w2", + "up_proj": "w3", + "gate_proj": "w1", + }) + module.load_weights(weights=[module_weights]) + elif names[-1] == "self_attn": + continue + elif names[-1] == "next_layer_layernorm": + continue + else: + module_weights = filter_weights(name, weights) + if hasattr(module, 'load_weights'): + module.load_weights(weights=[module_weights]) + else: + for n, p in module.named_parameters(): + p.data.copy_(module_weights[n][:]) + if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales( + ) and is_sm_100f() and hasattr( + module, "weight_scale"): + weight, weight_scale = resmooth_to_fp8_e8m0( + module.weight, module.weight_scale) + transfromed_scale = transform_sf_into_required_layout( + weight_scale, + mn=weight.shape[0], + k=weight.shape[1], + recipe=(1, 128, 128), + is_sfa=False) + module.weight = nn.Parameter(weight, requires_grad=False) + module.weight_scale = nn.Parameter(transfromed_scale, + requires_grad=False) + if not self.is_draft_model: + for idx, layer in enumerate( + self.model.model.layers[:self.config.num_hidden_layers]): + if idx == self.config.num_hidden_layers - 1: + layer.next_layer_layernorm = self.model.model.norm + else: + layer.next_layer_layernorm = self.model.model.layers[ + idx + 1].input_layernorm -class Deepseekv3MoE(nn.Module): - def __init__(self, - *, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - shared_expert_intermediate_size: int, - aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], - dtype: Optional[torch.dtype] = None, - model_config: ModelConfig = ModelConfig(), - override_quant_config: Optional[QuantConfig] = None, - layer_idx: Optional[int] = None): - from ..distributed import AllReduce +class DeepseekV3MTPHead(nn.Module): + def __init__(self, model_config: ModelConfig[PretrainedConfig]): 
super().__init__() config = model_config.pretrained_config - self.top_k = top_k - self.use_dp = model_config.mapping.enable_attention_dp - self.gate = DeepseekV3Gate( - hidden_size, - num_experts, - top_k=top_k, - n_group=config.n_group, - topk_group=config.topk_group, - routed_scaling_factor=config.routed_scaling_factor, - dtype=dtype, - fuse_routing_kernel=True, - apply_routing=False, - moe_backend=model_config.moe_backend) - self.experts = create_moe( - num_experts=num_experts, - routing_method=self.gate.routing_method, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - dtype=dtype, - reduce_results= - False, # In both low‑latency and attention‑DP modes, FusedMoE skips the in‑op all‑reduce. - model_config=model_config, - override_quant_config=override_quant_config, - aux_stream_dict=aux_stream_dict, - layer_idx=layer_idx, - # DS-R1 W4A8 is only supported through custom quantization script from - # examples/quantization/quantize_mixed_precision_moe.py - weight_loading_mode=( - MoEWeightLoadingMode.W4A8_CUSTOM - if self._get_experts_quant_config( - model_config, - layer_idx).layer_quant_mode.is_int4_weight_only_per_group() - else MoEWeightLoadingMode.VANILLA), - ) + self.model_config = model_config - self.mapping = model_config.mapping + self.norm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + if self.model_config.mapping.enable_attention_dp and \ + getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False): + self.mapping_lm_head_tp = create_lm_head_tp_mapping( + self.model_config.mapping) + else: + self.mapping_lm_head_tp = self.model_config.mapping - # FIXME: incompatible with mixed quantization mode (including excluding modules from quantization) - block_size = 1 - if model_config.quant_config and model_config.quant_config.group_size is not None: - block_size = model_config.quant_config.group_size + @torch.compile(options={"max-autotune": True}) + def get_last_token_states(self, hidden_states, attn_metadata): + last_tokens = torch.cumsum( + attn_metadata.seq_lens_cuda, + dim=0, + dtype=torch.long, + ) - 1 + return hidden_states[last_tokens] - shared_tp_size, self.shared_output_scale = self._compute_shared_expert_tp_size( - shared_expert_intermediate_size, block_size) + def forward(self, + hidden_states: torch.Tensor, + lm_head: Linear, + attn_metadata: AttentionMetadata, + return_context_logits: bool = False) -> torch.Tensor: + if not return_context_logits: + if attn_metadata is not None: + hidden_states = self.get_last_token_states( + hidden_states, attn_metadata) + else: + hidden_states = hidden_states[-1].unsqueeze(0) - self.shared_experts = GatedMLP( - hidden_size=hidden_size, - intermediate_size=shared_expert_intermediate_size, - bias=False, - dtype=dtype, - config=model_config, - overridden_tp_size=shared_tp_size, - reduce_output=False) + enable_attention_dp = self.model_config.mapping.enable_attention_dp + enable_lm_head_tp_in_adp = self.model_config.mapping.enable_lm_head_tp_in_adp - self.allreduce = AllReduce(mapping=model_config.mapping, - strategy=model_config.allreduce_strategy) - self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared] - self.event_dict = { - key: torch.cuda.Event() - for key in [EventType.Main, EventType.MoeShared] - } + # Add pre-lm gather logic + if enable_lm_head_tp_in_adp: + # ADP + LM TP mode: perform All-Gather before LM_head + hidden_states = allgather(hidden_states, + self.mapping_lm_head_tp, + dim=0) - def _compute_shared_expert_tp_size(self, intermediate_size: 
int, - block_size: int) -> int: - """ - In the case of Deepseek-R1, the TP size of MLP is capped by intermediate_size // block_size. - For example, when the intermediate_size is 2048 and block scaling size is 128, - TP sizes are limited to {1, 2, 4, 8, 16} because of 2048/128 = 16. + # Temporarily disable gather_output when not in ADP mode or (in ADP mode and LM TP is enabled) + if not enable_attention_dp or enable_lm_head_tp_in_adp: + lm_head.gather_output = False + logits = lm_head(hidden_states, is_spec_decoding_head=True) + if not enable_attention_dp or enable_lm_head_tp_in_adp: + lm_head.gather_output = True + return logits - Args: - intermediate_size (int): MLP intermediate size. - block_size (int): The quantization block scale size. In the case of Deepseek FP8 recipe, - it's 128. For NVFP4, it's 16. - Returns: - int: The computed tp_size. - """ +class DeepseekV3Linear(Linear): + """ + A wrapper around Linear because we may optionally use min-latency kernels depending on input shapes. + """ - assert intermediate_size % block_size == 0, "intermediate_size must be divisible by block_size." + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + mapping: Optional[Mapping] = None, + tensor_parallel_mode: Optional[TensorParallelMode] = None, + gather_output: bool = False, # COLUMN parallel only + quant_config: Optional[QuantConfig] = None, + weights_loading_config: Optional[WeightsLoadingConfig] = None, + reduce_output: bool = True, # ROW parallel only + skip_create_weights_in_init: bool = False, + use_custom_cublas_mm: bool = False, + lora: Optional[LoraLayer] = None, + ): + super().__init__( + in_features, + out_features, + bias, + dtype, + mapping, + tensor_parallel_mode, + gather_output, + quant_config, + weights_loading_config, + reduce_output, + skip_create_weights_in_init, + use_custom_cublas_mm, + lora, + ) - shared_output_scale = None - # The block scale size is 128, which requires shared_expert_intermediate_size to be divisible by 128. - if self.use_dp: - # If using attention DP, the shared experts also use DP instead of TP. - shared_tp_size = 1 + def apply_linear(self, + input, + bias, + lora_params: Optional[dict] | None = None, + layer_idx: Optional[int] | None = None): + num_tokens = input.shape[0] + if (not self.has_any_quant and 1 <= num_tokens <= 16 + and get_sm_version() != 120): + output = torch.ops.trtllm.dsv3_fused_a_gemm_op( + input, self.weight.t(), bias, None) else: - # Due to the restriction of block scale size (i.e., 128), the supported TP sizes only include 1, 2, 4, 8, and 16. - # The math.gcd operation ensures that shared_tp_size falls in the supported TP sizes. - shared_tp_size = math.gcd( - intermediate_size // block_size, - self.mapping.tp_size, - ) - # If shared_tp_size has been overridden, the output of shared experts needs to be scaled down accordingly before all-reduce. 
- if shared_tp_size != self.mapping.tp_size: - shared_output_scale = shared_tp_size / self.mapping.tp_size - - return shared_tp_size, shared_output_scale + output = super().apply_linear(input, bias, lora_params, layer_idx) + return output - @staticmethod - def _get_experts_quant_config(model_config, layer_idx: int) -> QuantConfig: - if getattr(model_config, "quant_config_dict", None) is None: - return model_config.quant_config - return model_config.quant_config_dict.get( - f"model.layers.{layer_idx}.mlp.experts", model_config.quant_config) - def compute_routed_output(self, hidden_states, hidden_states_fp4, - all_rank_num_tokens, do_finalize): - # max-throughput - use_dp_padding = False - if self.use_dp and self.mapping.tp_size > 1: - if isinstance(self.experts, TRTLLMGenFusedMoE): - hidden_states = allgather(hidden_states, - self.mapping, - dim=0, - sizes=all_rank_num_tokens) +class DeepseekV3Attention(MLA): - router_logits = self.gate(hidden_states) + def __init__( + self, + model_config: ModelConfig[PretrainedConfig], + layer_idx: Optional[int] = None, + aux_stream: Optional[torch.cuda.Stream] = None, + ): + config = model_config.pretrained_config + predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1 + super().__init__(hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + qk_rope_head_dim=config.qk_rope_head_dim, + qk_nope_head_dim=config.qk_nope_head_dim, + q_lora_rank=config.q_lora_rank, + kv_lora_rank=config.kv_lora_rank, + v_head_dim=config.v_head_dim, + predicted_tokens_per_seq=predicted_tokens_per_seq, + max_position_embeddings=config.max_position_embeddings, + bias=False, + pos_embd_params=PositionalEmbeddingParams( + type=PositionEmbeddingType.yarn, + rope=RopeParams.from_config(config), + is_neox=False, + ), + layer_idx=layer_idx, + dtype=config.torch_dtype, + config=model_config, + aux_stream=aux_stream) + self.kv_a_proj_with_mqa = DeepseekV3Linear( + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim + + (self.q_lora_rank if not self.is_lite else 0), + bias=False, + dtype=config.torch_dtype, + quant_config=model_config.get_quant_config(), + skip_create_weights_in_init=model_config. 
+ skip_create_weights_in_init, + use_custom_cublas_mm=True) - routed_output = self.experts( - hidden_states_fp4 - if hidden_states_fp4 is not None else hidden_states, - router_logits, - do_finalize=do_finalize, - output_dtype=hidden_states.dtype, - all_rank_num_tokens=all_rank_num_tokens, - use_dp_padding=use_dp_padding, - ) - return routed_output +class Deepseekv3RoutingImpl(): - def forward( + def __init__( self, - hidden_states: torch.Tensor, - hidden_states_fp4: Optional[Fp4QuantizedTensor] = None, - all_rank_num_tokens: Optional[list[int]] = None, - final_all_reduce_params: Optional[AllReduceParams] = None, - do_finalize: Optional[bool] = True, - ) -> torch.Tensor: - if not do_finalize: - assert not self.use_dp - - def _compute_shared_output(): - shared_output = self.shared_experts( - hidden_states_fp4 - if hidden_states_fp4 is not None else hidden_states) - if self.shared_output_scale is not None: - shared_output *= self.shared_output_scale - return shared_output + top_k: int, + n_group: int, + topk_group: int, + routed_scaling_factor: float, + is_fused: bool = True, + ): + super().__init__() + self.top_k = top_k + self.topk_group = topk_group + self.n_group = n_group + self.routed_scaling_factor = routed_scaling_factor + self.is_fused = is_fused - def _compute_routed_output(): - routed_output = self.compute_routed_output(hidden_states, - hidden_states_fp4, - all_rank_num_tokens, - do_finalize) - return routed_output + @staticmethod + @torch.compile(options={"max-autotune": True}) + def get_scores(logits, e_score_correction_bias): + scores = F.sigmoid(logits) + scores_with_bias = scores + e_score_correction_bias + return scores, scores_with_bias - # NOTE: define compiled helpers at module scope to avoid defining decorators inside compiled frames + def noaux_tc(self, logits, e_score_correction_bias): + n_group = self.n_group + scores, scores_with_bias = Deepseekv3RoutingImpl.get_scores( + logits, e_score_correction_bias) + scores_shape = list(scores_with_bias.shape) - routed_output, shared_output = maybe_execute_in_parallel( - _compute_routed_output, _compute_shared_output, - self.event_dict[EventType.Main], - self.event_dict[EventType.MoeShared], self.aux_stream) + if enable_llm_debug(): + has_nan = torch.isnan(scores_with_bias).any() + if has_nan: + warnings.warn( + "Detected NAN in the tensor scores_with_bias. Please check if it matches the expectation." 
+ ) - if not do_finalize: - return [shared_output, *routed_output] + if not self.is_fused: + group_scores = torch.sum(torch.topk( + scores_with_bias.view(scores_shape[:-1] + + [n_group, scores_shape[-1] // n_group]), + k=2, + dim=-1, + largest=True, + sorted=True)[0], + dim=-1) + _, group_idx = torch.topk(group_scores, + k=self.topk_group, + dim=-1, + largest=True, + sorted=True) + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(-1, group_idx, 1) + score_mask = group_mask.unsqueeze(-1).expand( + scores_shape[:-1] + + [n_group, scores_shape[-1] // n_group]).reshape(scores_shape) + scores_with_bias = scores_with_bias * score_mask + _, topk_idx = torch.topk(scores_with_bias, + k=self.top_k, + dim=-1, + largest=True, + sorted=True) + new_mask = torch.zeros_like(scores) + new_mask.scatter_(-1, topk_idx, 1) + scores = scores * new_mask + score_sum = torch.sum(scores, dim=-1, keepdim=True) + 1e-20 + scores = scores / score_sum * \ + self.routed_scaling_factor + topk_values, topk_indices = torch.topk(scores, + k=self.top_k, + dim=-1, + largest=True) + return topk_values, topk_indices else: - if routed_output.dim() == 3: - assert shared_output.numel( - ) * self.top_k == routed_output.numel( - ), 'unmatched tensor shape' - final_hidden_states = moe_reduce_add_shared_output( - routed_output, shared_output) - else: - assert shared_output.size() == routed_output.size( - ), 'unmatched tensor shape' - final_hidden_states = shared_output + routed_output - - if not self.use_dp and self.mapping.tp_size > 1: - final_hidden_states = self.allreduce( - final_hidden_states, - all_reduce_params=final_all_reduce_params) - - return final_hidden_states + topk_values, topk_indices = torch.ops.trtllm.noaux_tc_op( + scores, scores_with_bias, n_group, self.topk_group, self.top_k, + self.routed_scaling_factor) + return topk_values, topk_indices + def apply( + self, logits: torch.Tensor, e_score_correction_bias: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + topk_values, topk_indices = self.noaux_tc(logits, + e_score_correction_bias) + return topk_indices.to(torch.int32), topk_values.to(torch.float32) -class DeepseekV3DecoderLayer(DecoderLayer): - def __init__(self, model_config: ModelConfig[PretrainedConfig], - layer_idx: int, aux_stream_dict: Dict[AuxStreamType, - torch.cuda.Stream]): - super().__init__() - self.model_config = model_config - self.config = model_config.pretrained_config - config = self.config +class DeepseekV3Gate(DeepSeekV3MoeRoutingMethod): - self.hidden_size = config.hidden_size - self.moe_intermediate_size = config.moe_intermediate_size - self.num_experts = config.n_routed_experts - self.num_shared_experts = config.n_shared_experts - self.top_k = config.num_experts_per_tok + def __init__( + self, + hidden_size: int, + num_experts: int, + top_k: int, + n_group: int, + topk_group: int, + routed_scaling_factor: float, + dtype: Optional[torch.dtype] = None, + fuse_routing_kernel: bool = True, + apply_routing: bool = False, + moe_backend: str = 'CUTLASS', + ): + super().__init__(top_k=top_k) + self.weight = nn.Parameter(torch.empty((num_experts, hidden_size), + dtype=dtype), + requires_grad=False) + self.moe_backend = moe_backend + if moe_backend == 'TRTLLM': + bias_dtype = torch.bfloat16 + else: + bias_dtype = torch.float32 - self.mapping = model_config.mapping - mapping = self.mapping + self.e_score_correction_bias = nn.Parameter(torch.empty( + (num_experts), dtype=bias_dtype), + requires_grad=False) - self.self_attn = DeepseekV3Attention( - model_config, - 
layer_idx=layer_idx, - aux_stream=aux_stream_dict[AuxStreamType.Attention]) - self.enable_attention_dp = mapping.enable_attention_dp + assert not apply_routing, "DeepseekV3Gate routing is called inside MoE" - self.mlp_tp_size = mapping.tp_size - self.is_p2p_supported = can_access_peer(mapping) + # TODO: e_score_correction_bias belongs in this gate class but is required by the routing impl. + # To avoid weight-loading issues, we treat this gate as the BaseMoeRoutingMethod and dispatch to the routing impl. + # This is a temporary hack that should be refactored later. + self.routing_impl = Deepseekv3RoutingImpl( + top_k=top_k, + n_group=n_group, + topk_group=topk_group, + routed_scaling_factor=routed_scaling_factor, + is_fused=fuse_routing_kernel) - self.fusion_config = EagerFusionConfig() - self.enable_fusion = os.environ.get( - "TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED", "0") == "0" - self.enable_fusion &= not self.enable_attention_dp + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = torch.ops.trtllm.dsv3_router_gemm_op(hidden_states, + self.weight.t(), + bias=None, + out_dtype=torch.float32) + return logits - # FIXME: incompatible with mixed quantization mode - quant_config = self._get_decoder_layer_quant_config( - model_config, layer_idx) - self.is_nvfp4 = quant_config.layer_quant_mode.has_nvfp4() - assert ( - quant_config.quant_algo - is not QuantAlgo.MIXED_PRECISION), "MIXED_PRECISION is ambiguous" + def load_weights(self, weights: List[Dict]): + assert len(weights) == 1 - has_tp = mapping.has_tp() - self.allreduce = AllReduce(mapping=model_config.mapping, - strategy=model_config.allreduce_strategy, - dtype=config.torch_dtype) - self.moe_allreduce = MoEAllReduce(self.mapping) + self.weight.copy_(weights[0]["weight"][:]) - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): + self.e_score_correction_bias.copy_( + weights[0]["e_score_correction_bias"][:].to( + self.e_score_correction_bias.dtype)) - self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp - self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION + def apply(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # topk routing + return self.routing_impl.apply(logits, self.e_score_correction_bias) - self.mlp = Deepseekv3MoE( - num_experts=self.num_experts, - top_k=self.top_k, - hidden_size=self.hidden_size, - intermediate_size=self.moe_intermediate_size, - shared_expert_intermediate_size=self.moe_intermediate_size * - self.num_shared_experts, - dtype=config.torch_dtype, - model_config=model_config, - override_quant_config=quant_config, - aux_stream_dict=aux_stream_dict, - layer_idx=layer_idx) - else: - block_size = 1 - if quant_config and quant_config.group_size is not None: - block_size = quant_config.group_size - self.mlp_tp_size = self._compute_mlp_tp_size( - config.intermediate_size, block_size) + @property + def routing_method(self) -> DeepSeekV3MoeRoutingMethod: + return self - has_mlp_tp = self.mlp_tp_size > 1 - self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_mlp_tp and self.is_nvfp4 - self.fusion_config.POST_MLP_FUSION = self.enable_fusion and has_mlp_tp + def get_experts_per_token(self): + return self.routing_impl.top_k - self.mlp = GatedMLP(hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - bias=False, - dtype=config.torch_dtype, - config=model_config, - overridden_tp_size=self.mlp_tp_size, - reduce_output=True) - 
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) +class Deepseekv3MoE(nn.Module): - self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION - or self.fusion_config.PRE_MLP_FUSION - or self.mapping.tp_size == 1 - or self.enable_attention_dp) + def __init__(self, + *, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + shared_expert_intermediate_size: int, + aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], + dtype: Optional[torch.dtype] = None, + model_config: ModelConfig = ModelConfig(), + override_quant_config: Optional[QuantConfig] = None, + layer_idx: Optional[int] = None): + from ..distributed import AllReduce - self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) - self.layer_idx = layer_idx - self.next_layer_layernorm: RMSNorm = None + super().__init__() + config = model_config.pretrained_config + self.top_k = top_k + self.use_dp = model_config.mapping.enable_attention_dp + self.gate = DeepseekV3Gate( + hidden_size, + num_experts, + top_k=top_k, + n_group=config.n_group, + topk_group=config.topk_group, + routed_scaling_factor=config.routed_scaling_factor, + dtype=dtype, + fuse_routing_kernel=True, + apply_routing=False, + moe_backend=model_config.moe_backend) + self.experts = create_moe( + num_experts=num_experts, + routing_method=self.gate.routing_method, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + reduce_results= + False, # In both low‑latency and attention‑DP modes, FusedMoE skips the in‑op all‑reduce. + model_config=model_config, + override_quant_config=override_quant_config, + aux_stream_dict=aux_stream_dict, + layer_idx=layer_idx, + # DS-R1 W4A8 is only supported through custom quantization script from + # examples/quantization/quantize_mixed_precision_moe.py + weight_loading_mode=( + MoEWeightLoadingMode.W4A8_CUSTOM + if self._get_experts_quant_config( + model_config, + layer_idx).layer_quant_mode.is_int4_weight_only_per_group() + else MoEWeightLoadingMode.VANILLA), + ) - def _get_decoder_layer_quant_config( - self, model_config: ModelConfig[PretrainedConfig], layer_idx: int): - """ - The MTP layer in the nvfp4 checkpoint is unquantized. Because the TRTLLM - moe_backend only supports fp8/fp4 quantization, we need to override - the quant_config for the MTP layer. 
- """ - quant_config = model_config.quant_config + self.mapping = model_config.mapping - layer_name = f"model.layers.{layer_idx}" - if quant_config.is_module_excluded_from_quantization(layer_name): - return QuantConfig( - quant_algo=None, - kv_cache_quant_algo=quant_config.kv_cache_quant_algo, - ) - else: - return model_config.quant_config + # FIXME: incompatible with mixed quantization mode (including excluding modules from quantization) + block_size = 1 + if model_config.quant_config and model_config.quant_config.group_size is not None: + block_size = model_config.quant_config.group_size - def _compute_mlp_tp_size(self, intermediate_size: int, - block_size: int) -> int: + shared_tp_size, self.shared_output_scale = self._compute_shared_expert_tp_size( + shared_expert_intermediate_size, block_size) + + self.shared_experts = GatedMLP( + hidden_size=hidden_size, + intermediate_size=shared_expert_intermediate_size, + bias=False, + dtype=dtype, + config=model_config, + overridden_tp_size=shared_tp_size, + reduce_output=False) + + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy) + self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared] + self.event_dict = { + key: torch.cuda.Event() + for key in [EventType.Main, EventType.MoeShared] + } + + def _compute_shared_expert_tp_size(self, intermediate_size: int, + block_size: int) -> int: """ - For DeepSeek‑R1, MLP TP size is limited by intermediate_size // block_size - and must also be multiples of gpus_per_node to avoid expensive inter‑node allreduce. + In the case of Deepseek-R1, the TP size of MLP is capped by intermediate_size // block_size. + For example, when the intermediate_size is 2048 and block scaling size is 128, + TP sizes are limited to {1, 2, 4, 8, 16} because of 2048/128 = 16. Args: intermediate_size (int): MLP intermediate size. @@ -779,308 +822,335 @@ def _compute_mlp_tp_size(self, intermediate_size: int, """ assert intermediate_size % block_size == 0, "intermediate_size must be divisible by block_size." - if self.enable_attention_dp: - # If using attention DP, the MLP also uses DP instead of TP. - mlp_tp_size = 1 + + shared_output_scale = None + # The block scale size is 128, which requires shared_expert_intermediate_size to be divisible by 128. + if self.use_dp: + # If using attention DP, the shared experts also use DP instead of TP. + shared_tp_size = 1 else: - # The two math.gcd operations ensure that mlp_tp_size falls in the candidate TP sizes. - tp = math.gcd( + # Due to the restriction of block scale size (i.e., 128), the supported TP sizes only include 1, 2, 4, 8, and 16. + # The math.gcd operation ensures that shared_tp_size falls in the supported TP sizes. + shared_tp_size = math.gcd( intermediate_size // block_size, self.mapping.tp_size, ) + # If shared_tp_size has been overridden, the output of shared experts needs to be scaled down accordingly before all-reduce. 
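+            # Illustrative example (hypothetical tp_size=32, using the docstring's intermediate_size=2048 and block_size=128):
+            # shared_tp_size = gcd(2048 // 128, 32) = 16, so the shared output is scaled by 16 / 32 = 0.5 before the final all-reduce.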
+ if shared_tp_size != self.mapping.tp_size: + shared_output_scale = shared_tp_size / self.mapping.tp_size - if tp > self.mapping.gpus_per_node: - mlp_tp_size = math.gcd( - tp, - self.mapping.gpus_per_node, - ) # Avoid costly inter-node TP - else: - mlp_tp_size = tp - return mlp_tp_size + return shared_tp_size, shared_output_scale - def forward( - self, - position_ids: torch.IntTensor, - hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: torch.Tensor, - spec_metadata: Optional[SpecMetadata] = None, - **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - # Self Attention - hidden_states = self.self_attn( - position_ids=position_ids, - hidden_states=hidden_states, - attn_metadata=attn_metadata, - all_reduce_params=AllReduceParams( - enable_allreduce=not (self.disable_attn_allreduce)), - **kwargs, + @staticmethod + def _get_experts_quant_config(model_config, layer_idx: int) -> QuantConfig: + if getattr(model_config, "quant_config_dict", None) is None: + return model_config.quant_config + return model_config.quant_config_dict.get( + f"model.layers.{layer_idx}.mlp.experts", model_config.quant_config) + + def compute_routed_output(self, hidden_states, hidden_states_fp4, + all_rank_num_tokens, do_finalize): + # max-throughput + use_dp_padding = False + if self.use_dp and self.mapping.tp_size > 1: + if isinstance(self.experts, TRTLLMGenFusedMoE): + hidden_states = allgather(hidden_states, + self.mapping, + dim=0, + sizes=all_rank_num_tokens) + + router_logits = self.gate(hidden_states) + + routed_output = self.experts( + hidden_states_fp4 + if hidden_states_fp4 is not None else hidden_states, + router_logits, + do_finalize=do_finalize, + output_dtype=hidden_states.dtype, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=use_dp_padding, ) - if isinstance(self.mlp, Deepseekv3MoE): - if spec_metadata is not None and spec_metadata.is_layer_capture( - self.layer_idx): - self.fusion_config.POST_MOE_FUSION = False - return self.forward_MoE( - hidden_states=hidden_states, - attn_metadata=attn_metadata, - residual=residual, - spec_metadata=spec_metadata, - ) - else: - if spec_metadata is not None and spec_metadata.is_layer_capture( - self.layer_idx): - self.fusion_config.POST_MLP_FUSION = False - assert isinstance(self.mlp, GatedMLP) - return self.forward_mlp( - hidden_states=hidden_states, - residual=residual, - spec_metadata=spec_metadata, - ) - def forward_MoE( + return routed_output + + def forward( self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: torch.Tensor, - spec_metadata: Optional[SpecMetadata] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + hidden_states_fp4: Optional[Fp4QuantizedTensor] = None, + all_rank_num_tokens: Optional[list[int]] = None, + final_all_reduce_params: Optional[AllReduceParams] = None, + do_finalize: Optional[bool] = True, + ) -> torch.Tensor: + if not do_finalize: + assert not self.use_dp - def _run_MoE(hidden_states, hidden_states_fp4, do_finalize): - return self.mlp( - hidden_states, - hidden_states_fp4, - all_rank_num_tokens=attn_metadata.all_rank_num_tokens, - final_all_reduce_params=AllReduceParams( - enable_allreduce=not (self.fusion_config.POST_MOE_FUSION - or self.mapping.tp_size == 1)), - do_finalize=do_finalize, - ) + def _compute_shared_output(): + shared_output = self.shared_experts( + hidden_states_fp4 + if hidden_states_fp4 is not None else hidden_states) + if self.shared_output_scale 
is not None: + shared_output *= self.shared_output_scale + return shared_output - if self.fusion_config.PRE_MOE_FUSION: - # moe_backend can be either CUTLASS or TRTLLM here - # TODO: unify the two min-latency MoE backends by enabling quant fusion - hidden_states, residual = self.allreduce( - hidden_states, - all_reduce_params=AllReduceParams( - fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, - residual=residual, - norm_weight=self.post_attention_layernorm.weight, - eps=self.post_attention_layernorm.variance_epsilon, - trigger_completion_at_end=False, - )) - else: - # No fusion - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + def _compute_routed_output(): + routed_output = self.compute_routed_output(hidden_states, + hidden_states_fp4, + all_rank_num_tokens, + do_finalize) + return routed_output - # Note: this fusion pattern is only supported for single-node TRTLLM-nvfp4 backend now - do_finalize = self.mapping.is_multi_node() or ( - not (hidden_states.shape[0] <= self.moe_allreduce.max_token - and self.fusion_config.POST_MOE_FUSION - and self.model_config.moe_backend == "TRTLLM" - and self.mlp.experts.has_nvfp4 and self.is_p2p_supported)) + # NOTE: define compiled helpers at module scope to avoid defining decorators inside compiled frames - hidden_states = _run_MoE(hidden_states, - hidden_states_fp4=None, - do_finalize=do_finalize) + routed_output, shared_output = maybe_execute_in_parallel( + _compute_routed_output, _compute_shared_output, + self.event_dict[EventType.Main], + self.event_dict[EventType.MoeShared], self.aux_stream) - if self.fusion_config.POST_MOE_FUSION: - if do_finalize: - hidden_states, residual = self.allreduce( - hidden_states, - all_reduce_params=AllReduceParams( - fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, - residual=residual, - norm_weight=self.next_layer_layernorm.weight, - eps=self.next_layer_layernorm.variance_epsilon, - trigger_completion_at_end=False, - )) + if not do_finalize: + return [shared_output, *routed_output] + else: + if routed_output.dim() == 3: + assert shared_output.numel( + ) * self.top_k == routed_output.numel( + ), 'unmatched tensor shape' + final_hidden_states = moe_reduce_add_shared_output( + routed_output, shared_output) else: - assert len( - hidden_states) == 4, "hidden_states must have 4 elements" + assert shared_output.size() == routed_output.size( + ), 'unmatched tensor shape' + final_hidden_states = shared_output + routed_output - shared_output = hidden_states[0] - fc2_output = hidden_states[1] - expert_scale_factor = hidden_states[2] - expanded_idx_to_permuted_idx = hidden_states[3] + if not self.use_dp and self.mapping.tp_size > 1: + final_hidden_states = self.allreduce( + final_hidden_states, + all_reduce_params=final_all_reduce_params) - moe_all_reduce_params = MoEAllReduceParams( - expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx, - expert_scale_factor=expert_scale_factor, - shared_expert_output=shared_output, - residual=residual, - norm_weight=self.next_layer_layernorm.weight, - eps=self.next_layer_layernorm.variance_epsilon, - is_cutlass_min_latency=False, - ) - hidden_states, residual = self.moe_allreduce( - fc2_output, all_reduce_params=moe_all_reduce_params) - else: - if spec_metadata is not None and spec_metadata.is_layer_capture( - self.layer_idx): - spec_metadata.maybe_capture_hidden_states( - self.layer_idx, hidden_states, residual) - if self.next_layer_layernorm is not None: - hidden_states, residual = self.next_layer_layernorm( - hidden_states, residual) + return 
final_hidden_states - return hidden_states, residual - def forward_mlp( - self, - hidden_states: torch.Tensor, - residual: torch.Tensor, - spec_metadata: Optional[SpecMetadata] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: +class DeepseekV3DecoderLayer(DecoderLayer): - if self.fusion_config.PRE_MLP_FUSION: - act_fp4, act_sf, residual = self.allreduce( - hidden_states, - all_reduce_params=AllReduceParams( - fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4, - residual=residual, - norm_weight=self.post_attention_layernorm.weight, - scale=self.mlp.gate_up_proj.input_scale, - eps=self.post_attention_layernorm.variance_epsilon, - ), - ) - hidden_states = Fp4QuantizedTensor(act_fp4, act_sf) - else: - # No fusion - # We need to add twoshot allreduce here to avoid modifying MLA logic - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + def __init__(self, + model_config: ModelConfig[PretrainedConfig], + layer_idx: int, + aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], + is_separate_draft_engine: bool = False): + super().__init__() + self.model_config = model_config + self.config = model_config.pretrained_config + config = self.config - hidden_states = self.mlp( - hidden_states, - final_all_reduce_params=AllReduceParams(enable_allreduce=not ( - self.fusion_config.POST_MLP_FUSION or self.mlp_tp_size == 1)), - ) + self.hidden_size = config.hidden_size + self.moe_intermediate_size = config.moe_intermediate_size + self.num_experts = config.n_routed_experts + self.num_shared_experts = config.n_shared_experts + self.top_k = config.num_experts_per_tok - if self.fusion_config.POST_MLP_FUSION: - hidden_states, residual = self.allreduce( - hidden_states, - all_reduce_params=AllReduceParams( - fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, - residual=residual, - norm_weight=self.next_layer_layernorm.weight, - eps=self.next_layer_layernorm.variance_epsilon, - ), - ) - else: - if spec_metadata is not None and spec_metadata.is_layer_capture( - self.layer_idx): - spec_metadata.maybe_capture_hidden_states( - self.layer_idx, hidden_states, residual) - if self.next_layer_layernorm is not None: - hidden_states, residual = self.next_layer_layernorm( - hidden_states, residual) + self.mapping = model_config.mapping + mapping = self.mapping + layer_idx_for_attention = layer_idx + if is_separate_draft_engine: + #KVCacheManager only support 1 layer for separate draft engine + layer_idx_for_attention = layer_idx - model_config.pretrained_config.num_hidden_layers - return hidden_states, residual + self.self_attn = DeepseekV3Attention( + model_config, + layer_idx=layer_idx_for_attention, + aux_stream=aux_stream_dict[AuxStreamType.Attention]) + self.enable_attention_dp = mapping.enable_attention_dp + self.mlp_tp_size = mapping.tp_size + self.is_p2p_supported = can_access_peer(mapping) -class DeepseekV3MTP(DeepseekV3DecoderLayer): + self.fusion_config = EagerFusionConfig() + self.enable_fusion = os.environ.get( + "TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED", "0") == "0" + self.enable_fusion &= not self.enable_attention_dp - def __init__(self, model_config: ModelConfig[PretrainedConfig], - layer_idx: int, aux_stream_dict: Dict[AuxStreamType, - torch.cuda.Stream]): - super().__init__(model_config, layer_idx, aux_stream_dict) - config = model_config.pretrained_config - self.hidden_dim = config.hidden_size - self.moe_intermediate_size = config.moe_intermediate_size - self.num_experts = config.n_routed_experts - self.num_shared_experts = config.n_shared_experts - self.top_k = 
config.num_experts_per_tok + # FIXME: incompatible with mixed quantization mode + quant_config = self._get_decoder_layer_quant_config( + model_config, layer_idx) + self.is_nvfp4 = quant_config.layer_quant_mode.has_nvfp4() + assert ( + quant_config.quant_algo + is not QuantAlgo.MIXED_PRECISION), "MIXED_PRECISION is ambiguous" - self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared] - self.event_dict = { - key: torch.cuda.Event() - for key in [EventType.Main, EventType.MoeShared] - } + has_tp = mapping.has_tp() + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy, + dtype=config.torch_dtype) + self.moe_allreduce = MoEAllReduce(self.mapping) - self.enorm = RMSNorm(hidden_size=config.hidden_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): - self.hnorm = RMSNorm(hidden_size=config.hidden_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) - if model_config.mapping.enable_attention_dp: - self.eh_proj = Linear( - config.hidden_size * 2, - config.hidden_size, - bias=False, + self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp + self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION + + self.mlp = Deepseekv3MoE( + num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=self.moe_intermediate_size, + shared_expert_intermediate_size=self.moe_intermediate_size * + self.num_shared_experts, dtype=config.torch_dtype, - skip_create_weights_in_init=model_config. - skip_create_weights_in_init, - ) + model_config=model_config, + override_quant_config=quant_config, + aux_stream_dict=aux_stream_dict, + layer_idx=layer_idx) else: - self.eh_proj = Linear( - config.hidden_size * 2, - config.hidden_size, - bias=False, - dtype=config.torch_dtype, - tensor_parallel_mode=TensorParallelMode.ROW, - mapping=model_config.mapping, - reduce_output=True, - skip_create_weights_in_init=model_config. 
- skip_create_weights_in_init, - ) - - self.shared_head = DeepseekV3MTPHead(model_config) + block_size = 1 + if quant_config and quant_config.group_size is not None: + block_size = quant_config.group_size + self.mlp_tp_size = self._compute_mlp_tp_size( + config.intermediate_size, block_size) - def forward( - self, - input_ids: torch.IntTensor, - position_ids: torch.IntTensor, - hidden_states: torch.Tensor, - embed_tokens: Embedding, - attn_metadata: AttentionMetadata, - all_rank_num_tokens: Optional[List[int]] = None, - **kwargs, - ) -> torch.Tensor: + has_mlp_tp = self.mlp_tp_size > 1 + self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_mlp_tp and self.is_nvfp4 + self.fusion_config.POST_MLP_FUSION = self.enable_fusion and has_mlp_tp - def norm_embeds(): - return self.enorm(embed_tokens(input_ids)) #emdedding + self.mlp = GatedMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + bias=False, + dtype=config.torch_dtype, + config=model_config, + overridden_tp_size=self.mlp_tp_size, + reduce_output=True) - def norm_hidden(): - return self.hnorm(hidden_states) + self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) - inputs_embeds, hidden_states = maybe_execute_in_parallel( - norm_embeds, - norm_hidden, - self.event_dict[EventType.Main], - self.event_dict[EventType.MoeShared], - self.aux_stream, - ) - hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1) - # Split hidden_states columnwise based on TP - tp_size = self.model_config.mapping.tp_size - tp_rank = self.model_config.mapping.tp_rank + self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION + or self.fusion_config.PRE_MLP_FUSION + or self.mapping.tp_size == 1 + or self.enable_attention_dp) - if tp_size > 1 and not (self.model_config.mapping.enable_attention_dp): - hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank] - hidden_states = self.eh_proj(hidden_states) + self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + self.layer_idx = layer_idx + self.next_layer_layernorm: RMSNorm = None - # Input layer norm - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + def _get_decoder_layer_quant_config( + self, model_config: ModelConfig[PretrainedConfig], layer_idx: int): + """ + The MTP layer in the nvfp4 checkpoint is unquantized. Because the TRTLLM + moe_backend only supports fp8/fp4 quantization, we need to override + the quant_config for the MTP layer. + """ + quant_config = model_config.quant_config - # Self Attention - hidden_states = self.self_attn( - position_ids=position_ids, - hidden_states=hidden_states, + layer_name = f"model.layers.{layer_idx}" + if quant_config.is_module_excluded_from_quantization(layer_name): + return QuantConfig( + quant_algo=None, + kv_cache_quant_algo=quant_config.kv_cache_quant_algo, + ) + else: + return model_config.quant_config + + def _compute_mlp_tp_size(self, intermediate_size: int, + block_size: int) -> int: + """ + For DeepSeek‑R1, MLP TP size is limited by intermediate_size // block_size + and must also be multiples of gpus_per_node to avoid expensive inter‑node allreduce. + + Args: + intermediate_size (int): MLP intermediate size. + block_size (int): The quantization block scale size. In the case of Deepseek FP8 recipe, + it's 128. For NVFP4, it's 16. + + Returns: + int: The computed tp_size. 
+ """ + + assert intermediate_size % block_size == 0, "intermediate_size must be divisible by block_size." + if self.enable_attention_dp: + # If using attention DP, the MLP also uses DP instead of TP. + mlp_tp_size = 1 + else: + # The two math.gcd operations ensure that mlp_tp_size falls in the candidate TP sizes. + tp = math.gcd( + intermediate_size // block_size, + self.mapping.tp_size, + ) + + if tp > self.mapping.gpus_per_node: + mlp_tp_size = math.gcd( + tp, + self.mapping.gpus_per_node, + ) # Avoid costly inter-node TP + else: + mlp_tp_size = tp + return mlp_tp_size + + def forward( + self, + position_ids: torch.IntTensor, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: torch.Tensor, + spec_metadata: Optional[SpecMetadata] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states = self.self_attn( + position_ids=position_ids, + hidden_states=hidden_states, attn_metadata=attn_metadata, all_reduce_params=AllReduceParams( enable_allreduce=not (self.disable_attn_allreduce)), **kwargs, ) + if isinstance(self.mlp, Deepseekv3MoE): + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + self.fusion_config.POST_MOE_FUSION = False + return self.forward_MoE( + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + spec_metadata=spec_metadata, + ) + else: + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + self.fusion_config.POST_MLP_FUSION = False + assert isinstance(self.mlp, GatedMLP) + return self.forward_mlp( + hidden_states=hidden_states, + residual=residual, + spec_metadata=spec_metadata, + ) + + def forward_MoE( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: torch.Tensor, + spec_metadata: Optional[SpecMetadata] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + def _run_MoE(hidden_states, hidden_states_fp4, do_finalize): + return self.mlp( + hidden_states, + hidden_states_fp4, + all_rank_num_tokens=attn_metadata.all_rank_num_tokens, + final_all_reduce_params=AllReduceParams( + enable_allreduce=not (self.fusion_config.POST_MOE_FUSION + or self.mapping.tp_size == 1)), + do_finalize=do_finalize, + ) - # MTP Layer Must have sparse MOE if self.fusion_config.PRE_MOE_FUSION: + # moe_backend can be either CUTLASS or TRTLLM here + # TODO: unify the two min-latency MoE backends by enabling quant fusion hidden_states, residual = self.allreduce( hidden_states, all_reduce_params=AllReduceParams( @@ -1088,78 +1158,300 @@ def norm_hidden(): residual=residual, norm_weight=self.post_attention_layernorm.weight, eps=self.post_attention_layernorm.variance_epsilon, + trigger_completion_at_end=False, + )) + else: + # No fusion + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + # Note: this fusion pattern is only supported for single-node TRTLLM-nvfp4 backend now + do_finalize = self.mapping.is_multi_node() or ( + not (hidden_states.shape[0] <= self.moe_allreduce.max_token + and self.fusion_config.POST_MOE_FUSION + and self.model_config.moe_backend == "TRTLLM" + and self.mlp.experts.has_nvfp4 and self.is_p2p_supported)) + + hidden_states = _run_MoE(hidden_states, + hidden_states_fp4=None, + do_finalize=do_finalize) + + if self.fusion_config.POST_MOE_FUSION: + if do_finalize: + hidden_states, residual = self.allreduce( + hidden_states, + 
all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.next_layer_layernorm.weight, + eps=self.next_layer_layernorm.variance_epsilon, + trigger_completion_at_end=False, + )) + else: + assert len( + hidden_states) == 4, "hidden_states must have 4 elements" + + shared_output = hidden_states[0] + fc2_output = hidden_states[1] + expert_scale_factor = hidden_states[2] + expanded_idx_to_permuted_idx = hidden_states[3] + + moe_all_reduce_params = MoEAllReduceParams( + expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx, + expert_scale_factor=expert_scale_factor, + shared_expert_output=shared_output, + residual=residual, + norm_weight=self.next_layer_layernorm.weight, + eps=self.next_layer_layernorm.variance_epsilon, + is_cutlass_min_latency=False, + ) + hidden_states, residual = self.moe_allreduce( + fc2_output, all_reduce_params=moe_all_reduce_params) + else: + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + spec_metadata.maybe_capture_hidden_states( + self.layer_idx, hidden_states, residual) + if self.next_layer_layernorm is not None: + hidden_states, residual = self.next_layer_layernorm( + hidden_states, residual) + + return hidden_states, residual + + def forward_mlp( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + spec_metadata: Optional[SpecMetadata] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + if self.fusion_config.PRE_MLP_FUSION: + act_fp4, act_sf, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4, + residual=residual, + norm_weight=self.post_attention_layernorm.weight, + scale=self.mlp.gate_up_proj.input_scale, + eps=self.post_attention_layernorm.variance_epsilon, ), ) + hidden_states = Fp4QuantizedTensor(act_fp4, act_sf) else: + # No fusion + # We need to add twoshot allreduce here to avoid modifying MLA logic hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) - # MoE hidden_states = self.mlp( hidden_states, - all_rank_num_tokens=all_rank_num_tokens, - final_all_reduce_params=AllReduceParams( - enable_allreduce=not (self.fusion_config.POST_MOE_FUSION - or self.mapping.tp_size == 1)), + final_all_reduce_params=AllReduceParams(enable_allreduce=not ( + self.fusion_config.POST_MLP_FUSION or self.mlp_tp_size == 1)), ) - if self.fusion_config.POST_MOE_FUSION: + if self.fusion_config.POST_MLP_FUSION: hidden_states, residual = self.allreduce( hidden_states, all_reduce_params=AllReduceParams( fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, residual=residual, - norm_weight=self.shared_head.norm.weight, - eps=self.shared_head.norm.variance_epsilon, + norm_weight=self.next_layer_layernorm.weight, + eps=self.next_layer_layernorm.variance_epsilon, ), ) else: - hidden_states, _ = self.shared_head.norm(hidden_states, residual) + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + spec_metadata.maybe_capture_hidden_states( + self.layer_idx, hidden_states, residual) + if self.next_layer_layernorm is not None: + hidden_states, residual = self.next_layer_layernorm( + hidden_states, residual) - return hidden_states + return hidden_states, residual -class DeepseekV3Model(DecoderModel): +class DeepseekV3MTP(DeepseekV3DecoderLayer): - def __init__(self, model_config: ModelConfig[PretrainedConfig]): - super().__init__(model_config) + def __init__(self, + model_config: ModelConfig[PretrainedConfig], + layer_idx: int, + 
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], + is_separate_draft_engine: bool = False): + super().__init__(model_config, layer_idx, aux_stream_dict, + is_separate_draft_engine) config = model_config.pretrained_config - self.vocab_size = config.vocab_size - self.num_hidden_layers = config.num_hidden_layers - aux_stream_list = [torch.cuda.Stream() for _ in range(3)] - self.aux_stream_dict = { - AuxStreamType.Attention: aux_stream_list[0], - AuxStreamType.MoeShared: aux_stream_list[0], - AuxStreamType.MoeChunkingOverlap: aux_stream_list[1], - AuxStreamType.MoeBalancer: aux_stream_list[2], - } + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = config.moe_intermediate_size + self.num_experts = config.n_routed_experts + self.num_shared_experts = config.n_shared_experts + self.top_k = config.num_experts_per_tok - self.embed_tokens = Embedding( - config.vocab_size, - config.hidden_size, - dtype=config.torch_dtype, - ) + self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared] + self.event_dict = { + key: torch.cuda.Event() + for key in [EventType.Main, EventType.MoeShared] + } - self.layers = nn.ModuleList([ - DeepseekV3DecoderLayer(model_config, layer_idx, - self.aux_stream_dict) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = RMSNorm(hidden_size=config.hidden_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) + self.enorm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) - def forward( - self, - attn_metadata: AttentionMetadata, - input_ids: Optional[torch.IntTensor] = None, - position_ids: Optional[torch.IntTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - spec_metadata: Optional[SpecMetadata] = None, - **kwargs, - ) -> torch.Tensor: - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( + self.hnorm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + if model_config.mapping.enable_attention_dp: + self.eh_proj = Linear( + config.hidden_size * 2, + config.hidden_size, + bias=False, + dtype=config.torch_dtype, + skip_create_weights_in_init=model_config. + skip_create_weights_in_init, + ) + else: + self.eh_proj = Linear( + config.hidden_size * 2, + config.hidden_size, + bias=False, + dtype=config.torch_dtype, + tensor_parallel_mode=TensorParallelMode.ROW, + mapping=model_config.mapping, + reduce_output=True, + skip_create_weights_in_init=model_config. 
+ skip_create_weights_in_init, + ) + + self.shared_head = DeepseekV3MTPHead(model_config) + + def forward( + self, + input_ids: torch.IntTensor, + position_ids: torch.IntTensor, + hidden_states: torch.Tensor, + embed_tokens: Embedding, + attn_metadata: AttentionMetadata, + all_rank_num_tokens: Optional[List[int]] = None, + **kwargs, + ) -> torch.Tensor: + + def norm_embeds(): + return self.enorm(embed_tokens(input_ids)) #emdedding + + def norm_hidden(): + return self.hnorm(hidden_states) + + inputs_embeds, hidden_states = maybe_execute_in_parallel( + norm_embeds, + norm_hidden, + self.event_dict[EventType.Main], + self.event_dict[EventType.MoeShared], + self.aux_stream, + ) + hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1) + # Split hidden_states columnwise based on TP + tp_size = self.model_config.mapping.tp_size + tp_rank = self.model_config.mapping.tp_rank + + if tp_size > 1 and not (self.model_config.mapping.enable_attention_dp): + hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank] + hidden_states = self.eh_proj(hidden_states) + + # Input layer norm + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + all_reduce_params=AllReduceParams( + enable_allreduce=not (self.disable_attn_allreduce)), + **kwargs, + ) + + # MTP Layer Must have sparse MOE + if self.fusion_config.PRE_MOE_FUSION: + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.post_attention_layernorm.weight, + eps=self.post_attention_layernorm.variance_epsilon, + ), + ) + else: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + # MoE + hidden_states = self.mlp( + hidden_states, + all_rank_num_tokens=all_rank_num_tokens, + final_all_reduce_params=AllReduceParams( + enable_allreduce=not (self.fusion_config.POST_MOE_FUSION + or self.mapping.tp_size == 1)), + ) + + if self.fusion_config.POST_MOE_FUSION: + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.shared_head.norm.weight, + eps=self.shared_head.norm.variance_epsilon, + ), + ) + else: + hidden_states, _ = self.shared_head.norm(hidden_states, residual) + + return hidden_states + + +class DeepseekV3Model(DecoderModel): + + def __init__(self, model_config: ModelConfig[PretrainedConfig]): + super().__init__(model_config) + config = model_config.pretrained_config + self.vocab_size = config.vocab_size + self.num_hidden_layers = config.num_hidden_layers + aux_stream_list = [torch.cuda.Stream() for _ in range(3)] + self.aux_stream_dict = { + AuxStreamType.Attention: aux_stream_list[0], + AuxStreamType.MoeShared: aux_stream_list[0], + AuxStreamType.MoeChunkingOverlap: aux_stream_list[1], + AuxStreamType.MoeBalancer: aux_stream_list[2], + } + + self.embed_tokens = Embedding( + config.vocab_size, + config.hidden_size, + dtype=config.torch_dtype, + ) + + self.layers = nn.ModuleList([ + DeepseekV3DecoderLayer(model_config, layer_idx, + self.aux_stream_dict) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + + def forward( + self, + attn_metadata: 
AttentionMetadata, + input_ids: Optional[torch.IntTensor] = None, + position_ids: Optional[torch.IntTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + spec_metadata: Optional[SpecMetadata] = None, + **kwargs, + ) -> torch.Tensor: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) @@ -1203,7 +1495,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): model_config=model_config) self.model_nextn = 0 - if model_config.spec_config is not None and model_config.spec_config.spec_dec_mode.is_mtp( + if model_config.spec_config is not None and model_config.spec_config.spec_dec_mode.is_mtp_one_model( ): model_nextn = model_config.spec_config.num_nextn_predict_layers ckpt_nextn = self.config.num_nextn_predict_layers @@ -1252,274 +1544,5 @@ def forward( **kwargs) def load_weights(self, weights: Dict): - - def rename_moe_weight(weights: Dict, rename_rules: Dict): - result = {} - for key, value in weights.items(): - new_key = key - for old, new in rename_rules.items(): - new_key = new_key.replace(old, new) - result[new_key] = value - return result - - ## Prepare weights for TP - def split(v, tp_size, idx, dim=0): - if tp_size == 1: - return v - if len(v.shape) == 1: - return torch.chunk(v, tp_size)[idx].contiguous() - else: - return torch.chunk(v, tp_size, dim=dim)[idx].contiguous() - - def split_matrix_tp(v, tensor_parallel, rank, dim): - return split(v, tensor_parallel, rank, dim=dim) - - def load_kv_b_proj_and_k_b_proj_trans(module_name: str, - is_scale: bool) -> torch.Tensor: - weight_name = "weight" if not is_scale else "weight_scale_inv" - local_qk_nope_head_dim = qk_nope_head_dim if not is_scale else qk_nope_head_dim // 128 - local_v_head_dim = v_head_dim if not is_scale else v_head_dim // 128 - local_kv_lora_rank = kv_lora_rank if not is_scale else kv_lora_rank // 128 - - kv_b_proj = weights[f"{module_name}.{weight_name}"][:].unflatten( - 0, - [ - num_heads, - local_qk_nope_head_dim + local_v_head_dim, - ], - ) - - if not self.model_config.mapping.enable_attention_dp: - kv_b_proj = split_matrix_tp(kv_b_proj, tp_size, tp_rank, 0) - k_nope_weight, v_weight = kv_b_proj.split( - [local_qk_nope_head_dim, local_v_head_dim], - dim=1, - ) - weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size - local_num_heads = num_heads // weight_divisor - - k_nope_weight_trans = k_nope_weight.transpose(2, 1).contiguous() - - kv_b_proj = torch.concat([ - k_nope_weight.reshape(local_num_heads * local_qk_nope_head_dim, - local_kv_lora_rank), - v_weight.reshape(local_num_heads * local_v_head_dim, - local_kv_lora_rank) - ], - dim=0) - - return kv_b_proj, k_nope_weight_trans - - def load_kv_b_proj_and_k_b_proj_trans_dequant( - module_name: str) -> torch.Tensor: - weight_name = "weight" - local_qk_nope_head_dim = qk_nope_head_dim - local_v_head_dim = v_head_dim - local_kv_lora_rank = kv_lora_rank - - kv_b_proj = weights[f"{module_name}.{weight_name}"][:].cuda() - - weight_name = "weight_scale_inv" - kv_b_proj_scale = weights[f"{module_name}.{weight_name}"][:].cuda() - - kv_b_proj = weight_dequant(kv_b_proj, kv_b_proj_scale) - kv_b_proj = kv_b_proj.unflatten( - 0, - [ - num_heads, - local_qk_nope_head_dim + local_v_head_dim, - ], - ) - if not self.model_config.mapping.enable_attention_dp: - kv_b_proj = split_matrix_tp(kv_b_proj, tp_size, tp_rank, 0) - k_nope_weight, v_weight = kv_b_proj.split( - [local_qk_nope_head_dim, local_v_head_dim], 
- dim=1, - ) - weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size - local_num_heads = num_heads // weight_divisor - - k_nope_weight_trans = k_nope_weight.transpose(2, 1).contiguous() - - kv_b_proj = torch.concat([ - k_nope_weight.reshape(local_num_heads * local_qk_nope_head_dim, - local_kv_lora_rank), - v_weight.reshape(local_num_heads * local_v_head_dim, - local_kv_lora_rank) - ], - dim=0) - - return kv_b_proj, k_nope_weight_trans - - def split_kv_b_proj(kv_b_proj: torch.Tensor, - is_scale: bool) -> torch.Tensor: - local_qk_nope_head_dim = qk_nope_head_dim if not is_scale else qk_nope_head_dim // 128 - local_v_head_dim = v_head_dim if not is_scale else v_head_dim // 128 - - weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size - local_num_heads = num_heads // weight_divisor - - k_b_proj, v_b_proj = kv_b_proj.split([ - local_num_heads * local_qk_nope_head_dim, - local_num_heads * local_v_head_dim - ], - dim=0) - k_b_proj = k_b_proj.view( - [local_num_heads, local_qk_nope_head_dim, -1]) - v_b_proj = v_b_proj.view([local_num_heads, local_v_head_dim, -1]) - - return k_b_proj, v_b_proj - - is_lite = self.config.q_lora_rank is None - num_heads = self.config.num_attention_heads - qk_nope_head_dim = self.config.qk_nope_head_dim - v_head_dim = self.config.v_head_dim - kv_lora_rank = self.config.kv_lora_rank - - tp_rank = self.model_config.mapping.tp_rank - tp_size = self.model_config.mapping.tp_size - - params_map = {'gate_up_proj': ['gate_proj', 'up_proj']} - all_named_modules = dict(self.named_modules()) - - for name, module in tqdm(all_named_modules.items(), - desc="Loading weights"): - if len(module._parameters) <= 0 or name.startswith("draft_model"): - continue - else: - names = name.split('.') - parent_module_name = '.'.join(names[:-1]) - if "model.layers" in name and int( - names[2]) >= self.config.num_hidden_layers: - mtp_layer_idx = int( - names[2]) - self.config.num_hidden_layers - names[2] = str(mtp_layer_idx % - self.config.num_nextn_predict_layers + - self.config.num_hidden_layers) - name = '.'.join(names) - if names[-1] == "kv_b_proj": - # TODO: remove weight_dequant after enabling fp8_bmm - dequant_kv_b_proj = self.model_config.quant_config.is_module_excluded_from_quantization( - names[-1]) - if dequant_kv_b_proj: - kv_b_proj, k_b_proj_trans = load_kv_b_proj_and_k_b_proj_trans_dequant( - name) - else: - kv_b_proj, k_b_proj_trans = load_kv_b_proj_and_k_b_proj_trans( - name, is_scale=False) - module.weight.data.copy_( - kv_b_proj.reshape(module.weight.shape)) - - attn_module = all_named_modules[parent_module_name] - _, v_b_proj = split_kv_b_proj(module.weight.data, - is_scale=False) - attn_module.v_b_proj = nn.Parameter(v_b_proj, - requires_grad=False) - - attn_module.k_b_proj_trans.data.copy_( - k_b_proj_trans.reshape( - attn_module.k_b_proj_trans.shape)) - - if getattr(module, "weight_scale", - None) is not None and not dequant_kv_b_proj: - kv_b_proj_scale, k_b_proj_trans_scale = load_kv_b_proj_and_k_b_proj_trans( - name, is_scale=True) - module.weight_scale.copy_( - kv_b_proj_scale.reshape(module.weight_scale.shape)) - attn_module.k_b_proj_trans_scale.copy_( - k_b_proj_trans_scale.reshape( - attn_module.k_b_proj_trans_scale.shape)) - - _, v_b_proj_scale = split_kv_b_proj( - module.weight_scale.data, is_scale=True) - attn_module.v_b_proj_scale = nn.Parameter( - v_b_proj_scale, requires_grad=False) - - if attn_module.k_b_proj_trans_dequant is not None: - attn_module.k_b_proj_trans_dequant.data.copy_( - weight_dequant( - 
k_b_proj_trans.view( - -1, k_b_proj_trans.shape[-1]).cuda(), - k_b_proj_trans_scale.view( - -1, - k_b_proj_trans_scale.shape[-1]).cuda(), - ).view( - *attn_module.k_b_proj_trans_dequant.shape). - to(attn_module.k_b_proj_trans_dequant.dtype)) - if attn_module.v_b_proj_dequant is not None: - attn_module.v_b_proj_dequant.data.copy_( - weight_dequant( - v_b_proj.view(-1, - v_b_proj.shape[-1]).cuda(), - v_b_proj_scale.view( - -1, v_b_proj_scale.shape[-1]).cuda(), - ).view(*attn_module.v_b_proj_dequant.shape).to( - attn_module.v_b_proj_dequant.dtype)) - elif names[-1] == "kv_a_proj_with_mqa": - fused_a = weights[ - f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:] - if not is_lite: - q_a_proj = weights[ - f"{'.'.join(names[:-1])}.q_a_proj.weight"][:] - fused_a = torch.cat([q_a_proj, fused_a], dim=0) - - if f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv" in weights: - fused_a_scale = weights[ - f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv"] - if not is_lite: - q_a_proj_scale = weights[ - f"{'.'.join(names[:-1])}.q_a_proj.weight_scale_inv"][:] - fused_a_scale = torch.cat( - [q_a_proj_scale, fused_a_scale], dim=0) - - module.weight_scale.data.copy_(fused_a_scale) - - module.weight.data.copy_(fused_a) - elif names[-1] in params_map: - module_weights = [] - for new_name in params_map[names[-1]]: - module_weights.append( - filter_weights('.'.join(names[:-1] + [new_name]), - weights)) - module.load_weights(weights=module_weights) - elif names[-1] == "experts": - module_weights = filter_weights(name, weights) - module_weights = rename_moe_weight(module_weights, { - "down_proj": "w2", - "up_proj": "w3", - "gate_proj": "w1", - }) - module.load_weights(weights=[module_weights]) - elif names[-1] == "self_attn": - continue - elif names[-1] == "next_layer_layernorm": - continue - else: - module_weights = filter_weights(name, weights) - if hasattr(module, 'load_weights'): - module.load_weights(weights=[module_weights]) - else: - for n, p in module.named_parameters(): - p.data.copy_(module_weights[n][:]) - - if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales( - ) and is_sm_100f() and hasattr(module, "weight_scale"): - weight, weight_scale = resmooth_to_fp8_e8m0( - module.weight, module.weight_scale) - transfromed_scale = transform_sf_into_required_layout( - weight_scale, - mn=weight.shape[0], - k=weight.shape[1], - recipe=(1, 128, 128), - is_sfa=False) - module.weight = nn.Parameter(weight, requires_grad=False) - module.weight_scale = nn.Parameter(transfromed_scale, - requires_grad=False) - - for idx, layer in enumerate( - self.model.layers[:self.config.num_hidden_layers]): - if idx == self.config.num_hidden_layers - 1: - layer.next_layer_layernorm = self.model.norm - else: - layer.next_layer_layernorm = self.model.layers[ - idx + 1].input_layernorm + weight_loader = DeepseekV3WeightLoader(self) + weight_loader.load_weights(weights) diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index b60dd240336..0b92b24df46 100644 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -1,4 +1,4 @@ -from typing import Dict, Generic, Optional, Tuple +from typing import Dict, Generic, List, Optional, Tuple import torch from torch import nn @@ -18,6 +18,7 @@ from ..modules.rms_norm import RMSNorm from ..pyexecutor.guided_decoder import CapturableGuidedDecoder from ..speculative import SpecMetadata, get_spec_worker +from ..utils import 
AuxStreamType from .checkpoints.base_weight_mapper import BaseWeightMapper from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel, register_auto_model) @@ -342,8 +343,8 @@ def __init__( from .modeling_deepseekv3 import DeepseekV3MTP spec_dec_mode = model_config.spec_config.spec_dec_mode - assert spec_dec_mode.is_mtp() - mtp_num_layers = 1 if spec_dec_mode.is_mtp_eagle( + assert spec_dec_mode.is_mtp_one_model() + mtp_num_layers = 1 if spec_dec_mode.is_mtp_eagle_one_model( ) else model_config.spec_config.num_nextn_predict_layers moe_load_balancer_set_repeated_for_next_layer( @@ -358,16 +359,127 @@ def __init__( self.embed_tokens = model.embed_tokens +class MTPDraftModel(nn.Module): + + def __init__(self, model_config: ModelConfig[PretrainedConfig], + layer_idx: int, aux_stream_dict: Dict[AuxStreamType, + torch.cuda.Stream]): + super().__init__() + # Import here to avoid circular import + from .modeling_deepseekv3 import DeepseekV3MTP + + mtp_layer = DeepseekV3MTP(model_config, + layer_idx, + aux_stream_dict, + is_separate_draft_engine=True) + setattr(self, f"layers.{layer_idx}", mtp_layer) + self.layers = mtp_layer + self.layer_idx = layer_idx + self.config = model_config.pretrained_config + self.embed_tokens = Embedding( + self.config.vocab_size, + self.config.hidden_size, + dtype=self.config.torch_dtype, + ) + + def __repr__(self): + """Custom string representation to display layer index""" + return f"(layers): ({self.layer_idx}): {repr(self.layers)}" + + def forward( + self, + input_ids: torch.IntTensor, + position_ids: torch.IntTensor, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + all_rank_num_tokens: Optional[List[int]] = None, + all_rank_max_num_tokens: Optional[int] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + hidden_states = self.layers( + input_ids, + position_ids, + hidden_states, + embed_tokens=self.embed_tokens, + attn_metadata=attn_metadata, + all_rank_num_tokens=all_rank_num_tokens, + all_rank_max_num_tokens=all_rank_max_num_tokens, + ) + + return hidden_states + + +@register_auto_model("MTPDraftModelForCausalLM") +class MTPDraftModelForCausalLM(DecoderModelForCausalLM[MTPDraftModel, + PretrainedConfig]): + + def __init__(self, model_config: ModelConfig[PretrainedConfig]): + self.model_config = model_config + aux_stream_list = [torch.cuda.Stream() for _ in range(2)] + self.aux_stream_dict = { + AuxStreamType.Attention: aux_stream_list[0], + AuxStreamType.MoeShared: aux_stream_list[0], + AuxStreamType.MoeChunkingOverlap: aux_stream_list[1], + } + super().__init__( + MTPDraftModel(self.model_config, + self.model_config.pretrained_config.num_hidden_layers, + self.aux_stream_dict), + config=self.model_config, + hidden_size=self.model_config.pretrained_config.hidden_size, + vocab_size=self.model_config.pretrained_config.vocab_size) + + def load_weights(self, weights: Dict): + # Import here to avoid circular import + from .modeling_deepseekv3 import DeepseekV3WeightLoader + weight_loader = DeepseekV3WeightLoader(self, is_draft_model=True) + weight_loader.load_weights(weights) + + def load_weights_from_target_model(self, + target_model: torch.nn.Module) -> None: + if self.model.embed_tokens is None: + self.model.embed_tokens = target_model.model.embed_tokens + self.lm_head = target_model.lm_head + + def forward(self, + attn_metadata: AttentionMetadata, + input_ids: torch.IntTensor = None, + position_ids: torch.IntTensor = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + return_context_logits: bool = False, + 
spec_metadata: Optional[SpecMetadata] = None, + hidden_states: torch.Tensor = None, + **kwargs) -> torch.Tensor: + + hidden_states = spec_metadata.get_hidden_states() + output = self.model( + input_ids=input_ids, + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + all_rank_num_tokens=attn_metadata.all_rank_num_tokens, + all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens, + **kwargs) + return self.logits_processor.forward( + output, + self.lm_head, + attn_metadata, + return_context_logits, + ) + + def get_draft_model(model_config, draft_config, lm_head, model): assert getattr(model_config, 'spec_config', None) != None spec_dec_mode = model_config.spec_config.spec_dec_mode if spec_dec_mode.is_eagle3_one_model(): return Eagle3ForCausalLM( draft_config, model_config.pretrained_config.num_hidden_layers) - elif spec_dec_mode.is_mtp(): + elif spec_dec_mode.is_mtp_one_model(): return MTPForCausalLM(model_config, model_config.pretrained_config.num_hidden_layers, lm_head, model) + elif spec_dec_mode.is_mtp_eagle(): + return MTPDraftModelForCausalLM(model_config) else: raise NotImplementedError( f"get_draft_model does not support speculative decoding mode {spec_dec_mode}." diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index bd5b3aacf36..c5b453bda06 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -330,6 +330,9 @@ def drafting_loop_wrapper(model): is_draft_model=True, drafting_loop_wrapper=drafting_loop_wrapper, ) + # For DeepseekV3 MTP, we need to set the num_hidden_layers to 1 for the draft model + if spec_config.spec_dec_mode.is_mtp_eagle(): + draft_model_engine.model.model_config.pretrained_config.num_hidden_layers = 1 draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER draft_model_engine.load_weights_from_target_model( model_engine.model) diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 684bb9450f9..7fb7d9f0736 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -92,18 +92,18 @@ class Eagle3SpecMetadata(SpecMetadata): is_draft_model: bool = False is_first_draft: bool = False eagle3_resource_manager: Optional[Eagle3ResourceManager] = None + is_mtp_eagle: bool = False def __post_init__(self): if self.is_draft_model: self.layers_to_capture = (self.num_layers - 1, ) elif self.layers_to_capture is None: - if self.num_layers == 1: + if self.num_layers == 1 or self.is_mtp_eagle: self.layers_to_capture = (self.num_layers - 1, ) else: if self.num_layers <= 5: raise ValueError( "Not enough hidden layers for default EAGLE3 capture") - self.layers_to_capture = (1, self.num_layers // 2 - 1, self.num_layers - 4) else: diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 4b3723e2ca0..271aee20d76 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -12,6 +12,7 @@ class SpeculativeDecodingMode(IntEnum): MTP = auto() MTP_EAGLE = auto() + MTP_EAGLE_ONE_MODEL = auto() EAGLE3 = auto() EAGLE3_ONE_MODEL = auto() NGRAM = auto() @@ -20,8 +21,11 @@ class SpeculativeDecodingMode(IntEnum): NONE = auto() AUTO = auto() - def is_mtp(self): - return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE + def is_mtp_one_model(self): + return 
self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE_ONE_MODEL + + def is_mtp_eagle_one_model(self): + return self == SpeculativeDecodingMode.MTP_EAGLE_ONE_MODEL def is_mtp_vanilla(self): return self == SpeculativeDecodingMode.MTP @@ -33,7 +37,7 @@ def is_eagle3(self): return self == SpeculativeDecodingMode.EAGLE3 def use_one_engine(self): - return self.is_mtp() or self.is_eagle3_one_model() + return self.is_eagle3_one_model() or self.is_mtp_one_model() def is_eagle3_one_model(self): return self == SpeculativeDecodingMode.EAGLE3_ONE_MODEL @@ -51,23 +55,24 @@ def is_draft_target(self): return self == SpeculativeDecodingMode.DRAFT_TARGET def without_logits(self): - return self.is_mtp() or self.is_eagle3_one_model() + return self.is_mtp_one_model() or self.is_eagle3_one_model() def needs_kv_cache_rewind(self): - return self.is_mtp() or self.is_eagle3_one_model() or self.is_ngram() + return self.is_mtp_one_model() or self.is_eagle3_one_model( + ) or self.is_ngram() def support_overlap_scheduler(self): - return self.is_mtp() or self.is_eagle3_one_model( + return self.is_mtp_one_model() or self.is_eagle3_one_model( ) or self.has_draft_model() def support_guided_decoder(self): return self.is_none() or self.has_spec_drafter() def support_capturable_guided_decoder(self): - return self.is_mtp() or self.is_eagle3_one_model() + return self.is_mtp_one_model() or self.is_eagle3_one_model() def has_draft_model(self): - return self.is_eagle3() or self.is_draft_target() + return self.is_eagle3() or self.is_draft_target() or self.is_mtp_eagle() def needs_kv_cache_recompute(self): """ @@ -75,7 +80,7 @@ def needs_kv_cache_recompute(self): If true, the 1st draft model forward will recompute the kv cache for the accepted draft tokens. """ - return self.is_eagle3() + return self.is_eagle3() or self.is_mtp_eagle() def need_load_draft_weights(self): """ @@ -85,11 +90,12 @@ def need_load_draft_weights(self): return self.is_eagle3_one_model() def has_spec_decoder(self): - return self.is_mtp() or self.is_eagle3() or self.is_eagle3_one_model() + return self.is_mtp_one_model() or self.is_mtp_eagle() or self.is_eagle3( + ) or self.is_eagle3_one_model() def has_spec_drafter(self): return self.is_eagle3() or self.is_draft_target() or self.is_ngram( - ) or self.is_user_provided() + ) or self.is_user_provided() or self.is_mtp_eagle() def extend_ctx(self, attention_backend: Type[AttentionBackend]): """ diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py index cde12417f50..692c1e30a62 100644 --- a/tensorrt_llm/_torch/speculative/model_drafter.py +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -31,7 +31,7 @@ def get_draft_model_prompt(spec_dec_mode: SpeculativeDecodingMode, Can be used to modify prompts for speculative algorithms that need to update tokens before drafting. 
""" - if spec_dec_mode.is_eagle3(): + if spec_dec_mode.is_eagle3() or spec_dec_mode.is_mtp_eagle(): # EAGLE3 always throws away the first token when processing draft inputs return input_tokens[1:] return input_tokens diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index de22b9da388..944ad38ee19 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -155,7 +155,7 @@ def all_rank_num_seqs(self): @all_rank_num_seqs.setter def all_rank_num_seqs(self, value: List[int]): self._all_rank_num_seqs = value - if self.spec_dec_mode.is_mtp_eagle(): + if self.spec_dec_mode.is_mtp_eagle_one_model(): self.subseq_all_rank_num_tokens = value def prepare(self): @@ -172,7 +172,7 @@ def prepare(self): # while MTP Eagle worker uses (max_draft_len + 1) input tokens in the 1st draft # forward and only one input token in the following draft forward. # This num_tokens is used to set the all_rank_num_tokens for attention dp. - if not self.spec_dec_mode.is_mtp_eagle(): + if not self.spec_dec_mode.is_mtp_eagle_one_model(): self.num_tokens -= self.num_generations if self.mtp_hidden_states_manager is not None: # MTP vanilla or use relaxed acceptance @@ -183,7 +183,7 @@ def prepare(self): mtp_slot_ids.append(slot_id) # MTP Vanilla: Update mtp hidden states and past tokens - if self.spec_dec_mode.is_mtp(): + if self.spec_dec_mode.is_mtp_one_model(): mtp_hidden_states_ptrs = [] mtp_past_tokens_ptrs = [] for slot_id in mtp_slot_ids: diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index 96fbd83a1cb..5a573b2950a 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -19,7 +19,7 @@ def get_spec_metadata(spec_config, max_num_tokens, spec_resource_manager=None, is_draft_model=False): - if spec_config.spec_dec_mode.is_mtp(): + if spec_config.spec_dec_mode.is_mtp_one_model(): return MTPSpecMetadata( max_draft_len=spec_config.max_draft_len, spec_dec_mode=spec_config.spec_dec_mode, @@ -39,6 +39,21 @@ def get_spec_metadata(spec_config, is_draft_model=is_draft_model, eagle3_resource_manager=spec_resource_manager, layers_to_capture=spec_config.eagle3_layers_to_capture, + is_mtp_eagle=False, + ) + if spec_config.spec_dec_mode.is_mtp_eagle(): + return Eagle3SpecMetadata( + max_draft_len=spec_config.max_draft_len, + spec_dec_mode=spec_config.spec_dec_mode, + max_num_requests=max_num_requests, + num_layers=model_config.num_hidden_layers, + hidden_size=model_config.hidden_size, + max_num_tokens=max_num_tokens, + dtype=model_config.torch_dtype, + is_draft_model=is_draft_model, + eagle3_resource_manager=spec_resource_manager, + layers_to_capture=None, + is_mtp_eagle=True, ) if spec_config.spec_dec_mode.is_eagle3_one_model(): return Eagle3OneModelSpecMetadata( @@ -70,7 +85,7 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None): max_seq_len = model_engine.max_seq_len max_num_tokens = model_engine.max_num_tokens spec_dec_mode = spec_config.spec_dec_mode - if spec_dec_mode.is_mtp_eagle(): + if spec_dec_mode.is_mtp_eagle_one_model(): if spec_config.use_relaxed_acceptance_for_thinking: return MTPHiddenStatesManager( spec_config, @@ -80,15 +95,15 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None): ) else: return None - if spec_dec_mode.is_mtp(): + if spec_dec_mode.is_mtp_one_model(): return MTPHiddenStatesManager( spec_config, model_config.torch_dtype, model_config.hidden_size, max_num_requests, ) - if 
spec_dec_mode.is_eagle3(): - assert draft_model_engine is not None, "Draft model engine is required for Eagle3 two model flow." + if spec_dec_mode.is_eagle3() or spec_dec_mode.is_mtp_eagle(): + assert draft_model_engine is not None, "Draft model engine is required for Eagle3 and MTP Eagle two model flow." return Eagle3ResourceManager( spec_config, draft_model_engine.model.config.torch_dtype, @@ -106,10 +121,11 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None): def get_spec_decoder(sampler_args: TorchSampler.Args, spec_config: "DecodingBaseConfig"): - if spec_config.spec_dec_mode.is_mtp(): + if spec_config.spec_dec_mode.is_mtp_one_model(): return MTPSampler(sampler_args, nextn=spec_config.num_nextn_predict_layers) - if spec_config.spec_dec_mode.is_eagle3(): + if spec_config.spec_dec_mode.is_eagle3( + ) or spec_config.spec_dec_mode.is_mtp_eagle(): # TorchSampler handles Eagle3 gracefully, by integrating d2t into the sampling process return TorchSampler(sampler_args) if spec_config.spec_dec_mode.is_eagle3_one_model(): @@ -132,7 +148,8 @@ def get_spec_drafter(model_engine, max_num_requests = model_engine.batch_size if spec_config.spec_dec_mode.is_draft_target( - ) or spec_config.spec_dec_mode.is_eagle3(): + ) or spec_config.spec_dec_mode.is_eagle3( + ) or spec_config.spec_dec_mode.is_mtp_eagle(): return ModelDrafter(spec_config, draft_model_engine, spec_config.max_draft_len, @@ -148,7 +165,7 @@ def get_spec_drafter(model_engine, def get_num_spec_layers(spec_config): - if spec_config.spec_dec_mode.is_mtp(): + if spec_config.spec_dec_mode.is_mtp_one_model(): return spec_config.num_nextn_predict_layers if spec_config.spec_dec_mode.is_eagle3_one_model(): num_eagle_layers = spec_config.num_eagle_layers @@ -160,7 +177,7 @@ def get_spec_worker(spec_config, model_config, mapping): spec_dec_mode = spec_config.spec_dec_mode if spec_dec_mode.is_mtp_vanilla(): return MTPWorker(spec_config, model_config) - if spec_dec_mode.is_mtp_eagle(): + if spec_dec_mode.is_mtp_eagle_one_model(): return MTPEagleWorker(spec_config, model_config) if spec_dec_mode.is_eagle3_one_model(): return Eagle3OneModelWorker(spec_config, mapping) @@ -174,14 +191,13 @@ def get_num_extra_kv_tokens(spec_config): """ if spec_config is None: return 0 - if spec_config.spec_dec_mode.is_eagle3_one_model( - ) or spec_config.spec_dec_mode.is_mtp_eagle(): + if spec_config.spec_dec_mode.is_eagle3_one_model(): return spec_config.max_draft_len - 1 return 0 def update_spec_config_from_model_config(spec_config, model_config): - if spec_config.spec_dec_mode.is_mtp(): + if spec_config.spec_dec_mode.is_mtp_one_model(): # Use `max_draft_len` for several low-level APIs. TODO: Remove this after distinguishing them. spec_config.max_draft_len = spec_config.num_nextn_predict_layers # Use `num_nextn_predict_layers_from_model_config` to decide decoding mode MTP / MTP_EAGLE. diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 45929ff60d7..ca95ab90184 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -465,7 +465,7 @@ def spec_dec_mode(self): return TorchSpeculativeDecodingMode.EAGLE3 @functools.cached_property - def num_capture_layers(self): + def num_capture_layers(self) -> int: """ Returns the number of layers to capture of the target model. If eagle3_layers_to_capture is not None, return the length of the set. @@ -541,6 +541,7 @@ class MTPDecodingConfig(DecodingBaseConfig): relaxed_topk: int = 1 relaxed_delta: float = 0. 
use_mtp_vanilla: bool = False + mtp_eagle_one_model: bool = True # TODO: remove this after distinguishing `max_draft_len` and `num_nextn_predict_layers` # Now we need a flag when MTPDecodingConfig is updated by PyTorchModelEngine. @@ -564,11 +565,19 @@ def from_dict(cls, data: dict): def supports_backend(self, backend: str) -> bool: return backend == "pytorch" + @functools.cached_property + def num_capture_layers(self) -> int: + if not self.use_mtp_vanilla and not self.mtp_eagle_one_model: + return 1 + return 0 + @functools.cached_property def spec_dec_mode(self): from tensorrt_llm._torch.speculative.interface import \ SpeculativeDecodingMode as TorchSpeculativeDecodingMode - if self.num_nextn_predict_layers_from_model_config == 1 and not self.use_mtp_vanilla: + if self.num_nextn_predict_layers_from_model_config == 1 and not self.use_mtp_vanilla and self.mtp_eagle_one_model: + return TorchSpeculativeDecodingMode.MTP_EAGLE_ONE_MODEL + elif self.num_nextn_predict_layers_from_model_config == 1 and not self.use_mtp_vanilla and not self.mtp_eagle_one_model: return TorchSpeculativeDecodingMode.MTP_EAGLE return TorchSpeculativeDecodingMode.MTP diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index 2a9808b4c6e..594f4146de8 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -1913,11 +1913,40 @@ def test_ptp_quickstart_advanced_mtp(llm_root, llm_venv, model_name, "MTP", "--model_dir", f"{llm_models_root()}/{model_path}", + "--use_one_model", ], stdout=running_log) _check_mem_usage(running_log, [54.60, 0, 0, 0]) +@pytest.mark.parametrize("model_name,model_path", [ + ("DeepSeek-V3-Lite-BF16", "DeepSeek-V3-Lite/bf16"), +]) +def test_ptp_quickstart_advanced_mtp_eagle(llm_root, llm_venv, model_name, + model_path): + print(f"Testing {model_name}.") + example_root = Path(os.path.join(llm_root, "examples", "llm-api")) + with tempfile.NamedTemporaryFile(mode='w+t', + suffix=f".{model_name}.log", + dir="./", + delete=True, + delete_on_close=True) as running_log: + llm_venv.run_cmd( + [ + str(example_root / "quickstart_advanced.py"), + "--use_cuda_graph", + "--spec_decode_max_draft_len", + "1", # test 1 MTP module + "--spec_decode_algo", + "MTP", + "--model_dir", + f"{llm_models_root()}/{model_path}", + ], + stdout=running_log) + # 74.60 is the memory usage for DeepSeek-V3-Lite-BF16 with MTP Eagle 2 two model style as one extra kv cache is needed for draft model. 
+ _check_mem_usage(running_log, [74.60, 0, 0, 0]) + + @pytest.mark.skip_less_device(4) def test_ptp_quickstart_advanced_bs1(llm_root, llm_venv): model_name = "DeepSeek-V3-Lite-FP8" @@ -2169,6 +2198,7 @@ def test_relaxed_acceptance_quickstart_advanced_deepseek_r1_8gpus( "--relaxed_topk=10", "--relaxed_delta=0.5", "--enable_attention_dp", + "--use_one_model", ], stdout=running_log) _check_mem_usage(running_log, [85.6, 0, 0, 0], 8) From 0716a6be3e373a592fe534e415ffb1073ace1f3f Mon Sep 17 00:00:00 2001 From: qgai Date: Tue, 16 Sep 2025 17:13:16 +0000 Subject: [PATCH 2/2] remove all_rank_max_num_tokens Signed-off-by: qgai --- tensorrt_llm/_torch/models/modeling_deepseekv3.py | 12 +++++++----- tensorrt_llm/_torch/models/modeling_speculative.py | 3 --- 2 files changed, 7 insertions(+), 8 deletions(-) mode change 100644 => 100755 tensorrt_llm/_torch/models/modeling_deepseekv3.py mode change 100644 => 100755 tensorrt_llm/_torch/models/modeling_speculative.py diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py old mode 100644 new mode 100755 index ae008978437..2dd93b36627 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -395,8 +395,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, p.data.copy_(module_weights[n][:]) if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales( - ) and is_sm_100f() and hasattr( - module, "weight_scale"): + ) and is_sm_100f() and hasattr(module, "weight_scale"): weight, weight_scale = resmooth_to_fp8_e8m0( module.weight, module.weight_scale) transfromed_scale = transform_sf_into_required_layout( @@ -805,8 +804,9 @@ def __init__(self, for key in [EventType.Main, EventType.MoeShared] } - def _compute_shared_expert_tp_size(self, intermediate_size: int, - block_size: int) -> int: + def _compute_shared_expert_tp_size( + self, intermediate_size: int, + block_size: int) -> tuple[int, float | None]: """ In the case of Deepseek-R1, the TP size of MLP is capped by intermediate_size // block_size. For example, when the intermediate_size is 2048 and block scaling size is 128, @@ -818,7 +818,9 @@ def _compute_shared_expert_tp_size(self, intermediate_size: int, it's 128. For NVFP4, it's 16. Returns: - int: The computed tp_size. + tuple[int, float | None]: A tuple containing (shared_tp_size, shared_output_scale). + - shared_tp_size: The computed TP size. + - shared_output_scale: The output scale factor, or None if not needed. """ assert intermediate_size % block_size == 0, "intermediate_size must be divisible by block_size." 
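Reviewer note (not part of the patch): the hunk above only reworks the `_compute_shared_expert_tp_size` signature and docstring, but the TP-size capping it describes (shared between `_compute_mlp_tp_size` and `_compute_shared_expert_tp_size`) reduces to two `math.gcd` steps. Below is a minimal standalone sketch of that logic under simplifying assumptions: the function name and arguments are hypothetical, and the attention-DP branch (which forces a TP size of 1) is ignored.

```python
import math


def capped_tp_size(intermediate_size: int, block_size: int,
                   tp_size: int, gpus_per_node: int) -> int:
    """Sketch of the capping logic: the TP size may not exceed
    intermediate_size // block_size, and is further reduced so the
    TP group stays within a single node (no inter-node allreduce)."""
    assert intermediate_size % block_size == 0
    # Cap by the number of quantization blocks along the intermediate dim.
    tp = math.gcd(intermediate_size // block_size, tp_size)
    # If the capped TP still spans nodes, shrink it to fit one node.
    if tp > gpus_per_node:
        tp = math.gcd(tp, gpus_per_node)
    return tp


# Using the docstring's example: intermediate_size=2048 with FP8 block size 128
# gives 16 blocks, so tp_size=8 on an 8-GPU node keeps TP=8, while a
# hypothetical tp_size=32 is capped to gcd(16, 32)=16 and then to
# gcd(16, 8)=8 to avoid inter-node allreduce.
assert capped_tp_size(2048, 128, 8, 8) == 8
assert capped_tp_size(2048, 128, 32, 8) == 8
```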
diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py old mode 100644 new mode 100755 index 0b92b24df46..6eb989af33f --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -393,7 +393,6 @@ def forward( hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, all_rank_num_tokens: Optional[List[int]] = None, - all_rank_max_num_tokens: Optional[int] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: hidden_states = self.layers( @@ -403,7 +402,6 @@ def forward( embed_tokens=self.embed_tokens, attn_metadata=attn_metadata, all_rank_num_tokens=all_rank_num_tokens, - all_rank_max_num_tokens=all_rank_max_num_tokens, ) return hidden_states @@ -458,7 +456,6 @@ def forward(self, hidden_states=hidden_states, attn_metadata=attn_metadata, all_rank_num_tokens=attn_metadata.all_rank_num_tokens, - all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens, **kwargs) return self.logits_processor.forward( output,