 import copy
+import os
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
@@ -337,7 +338,7 @@ def forward(
         assert shared_output.size() == routed_output.size(
         ), f'unmatched tensor shape'
         final_hidden_states = shared_output + routed_output
-        if not self.enable_attention_dp and self.mapping.tp_size > 1:
+        if not self.enable_attention_dp and self.mapping.has_tp():
             final_hidden_states = self.all_reduce(
                 final_hidden_states, all_reduce_params=final_all_reduce_params)
 
@@ -367,9 +368,6 @@ def __init__(
         self.fusion_config = EagerFusionConfig()
         # self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp(
         # )
-        # TODO: re-enable these fusions
-        self.fusion_config.PRE_MOE_FUSION = False
-        self.fusion_config.POST_MLP_FUSION = False
 
         nope_layer = config.no_rope_layers[layer_idx] == 0
         attention_chunk_size = getattr(config, "attention_chunk_size",
@@ -387,6 +385,20 @@ def __init__(
         self.is_mlp_layer = (layer_idx +
                              1) % config.interleave_moe_layer_step != 0
 
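+        # Eager allreduce fusion is on by default; setting TRTLLM_LLAMA_EAGER_FUSION_DISABLED=1 turns it off.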
+        self.enable_fusion = os.environ.get(
+            "TRTLLM_LLAMA_EAGER_FUSION_DISABLED", "0") == "0"
+
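+        # Select the fused allreduce op: NVFP4 checkpoints fuse residual add + RMSNorm + FP4 quantization;
+        # everything else falls back to residual add + RMSNorm (FP8 quant fusion is left as a TODO below).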
+        if self.is_nvfp4:
+            self.pre_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
+            self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
+        # TODO: enable fp8 quant fusion later
+        # elif self.is_fp8_quant:
+        #     self.pre_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8
+        #     self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8
+        else:
+            self.pre_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+            self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+
         if self.is_mlp_layer:
             self.feed_forward = GatedMLP(
                 hidden_size=config.hidden_size,
@@ -399,8 +411,11 @@ def __init__(
                 layer_idx=layer_idx,
             )
 
-            # self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp(
-            # )
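+            # Fuse the MLP-path allreduces only when TP is active, attention DP is off, and fusion is not disabled.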
+            self.fusion_config.PRE_MLP_FUSION = model_config.mapping.has_tp(
+            ) and not self.enable_attention_dp and self.enable_fusion
+            self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp(
+            ) and not self.enable_attention_dp and self.enable_fusion
+
         else:
             self.feed_forward = Llama4MoE(
                 num_experts=config.num_local_experts,
@@ -413,8 +428,10 @@ def __init__(
                 dtype=config.torch_dtype,
                 layer_idx=layer_idx)
 
-            # self.fusion_config.POST_MOE_FUSION = model_config.mapping.has_tp(
-            # )
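+            # Same gating for the MoE path: TP active, no attention DP, and fusion not disabled.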
+            self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp(
+            ) and not self.enable_attention_dp and self.enable_fusion
+            self.fusion_config.POST_MOE_FUSION = model_config.mapping.has_tp(
+            ) and not self.enable_attention_dp and self.enable_fusion
 
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
@@ -432,6 +449,17 @@ def __init__(
 
         self.moe_allreduce = MoEAllReduce(self.mapping)
 
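+        # Skip the separate allreduce inside attention / feed-forward whenever the fused allreduce
+        # below will handle it, when tp_size == 1, or when attention DP is enabled.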
+        self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION
+                                       or self.fusion_config.PRE_MLP_FUSION
+                                       or self.mapping.tp_size == 1
+                                       or self.enable_attention_dp)
+        self.disable_feed_forward_allreduce = (
+            self.fusion_config.POST_MOE_FUSION
+            or self.fusion_config.POST_MLP_FUSION or self.mapping.tp_size == 1
+            or self.enable_attention_dp)
+
     def forward(
         self,
         position_ids: torch.IntTensor,
@@ -461,34 +489,43 @@ def forward(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
-            all_reduce_params=AllReduceParams(enable_allreduce=not (
-                self.fusion_config.PRE_MOE_FUSION or self.mapping.tp_size == 1
-                or self.enable_attention_dp)),
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_attn_allreduce),
             **kwargs,
         )
 
-        if self.fusion_config.PRE_MOE_FUSION:
-            hidden_states, residual = self.all_reduce(
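+        # Quantizing fusion ops need an activation scale; reuse the attention QKV projection's input scale.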
+        if self.is_nvfp4 or self.is_fp8_quant:
+            scale = self.self_attn.qkv_proj.input_scale
+        else:
+            scale = None
+
+        if self.fusion_config.PRE_MLP_FUSION or self.fusion_config.PRE_MOE_FUSION:
+            allreduce_output = self.all_reduce(
                 hidden_states,
                 all_reduce_params=AllReduceParams(
-                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                    fusion_op=self.pre_feed_forward_fusion_op,
                     residual=residual,
                     norm_weight=self.post_attention_layernorm.weight,
+                    scale=scale,
                     eps=self.post_attention_layernorm.variance_epsilon,
                 ))
         else:
             # Fully Connected
-            hidden_states, residual = self.post_attention_layernorm(
+            allreduce_output = self.post_attention_layernorm(
                 hidden_states, residual)
 
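+        # The NVFP4 fused op returns (fp4 activations, scale factors, residual); wrap them so the
+        # feed-forward GEMMs can consume the pre-quantized activations directly.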
+        if self.is_nvfp4:
+            act_fp4, act_sf, residual = allreduce_output
+            hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+        else:
+            hidden_states, residual = allreduce_output
+
         hidden_states = self.feed_forward(
             hidden_states,
             all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
             all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
-            final_all_reduce_params=AllReduceParams(enable_allreduce=not (
-                self.fusion_config.POST_MOE_FUSION
-                or self.fusion_config.POST_MLP_FUSION
-                or self.mapping.tp_size == 1 or self.enable_attention_dp)),
+            final_all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_feed_forward_allreduce),
             cutlass_min_latency_mode=cutlass_min_latency_mode,
         )
 
@@ -500,16 +537,23 @@ def forward(
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
 
-        if (self.fusion_config.POST_MOE_FUSION
+        if (
+                self.fusion_config.POST_MOE_FUSION
                 or self.fusion_config.POST_MLP_FUSION
-           ) and self.next_layer_layernorm is not None:
+        ) and self.next_layer_layernorm is not None and self.next_attn is not None:
+            # Get the scale for the next allreduce fusion op
+            if self.is_nvfp4 or self.is_fp8_quant:
+                scale = self.next_attn.qkv_proj.input_scale
+            else:
+                scale = None
+
             if cutlass_min_latency_mode:
                 shared_output = hidden_states[0]
                 hidden_states_activated_experts = hidden_states[1]
                 num_activated_experts_per_node = hidden_states[2]
                 experts_to_token_score = hidden_states[3]
 
-                hidden_states, residual = self.moe_allreduce(
+                allreduce_output = self.moe_allreduce(
                     residual,
                     self.next_layer_layernorm.weight,
                     device_num_experts=num_activated_experts_per_node,
@@ -519,18 +563,30 @@ def forward(
                     eps=self.next_layer_layernorm.variance_epsilon,
                 )
             else:
-                hidden_states, residual = self.all_reduce(
-                    hidden_states,
-                    all_reduce_params=AllReduceParams(
-                        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-                        residual=residual,
-                        norm_weight=self.next_layer_layernorm.weight,
-                        eps=self.next_layer_layernorm.variance_epsilon,
-                    ))
+                if self.fusion_config.POST_MLP_FUSION or self.fusion_config.POST_MOE_FUSION:
+                    allreduce_output = self.all_reduce(
+                        hidden_states,
+                        all_reduce_params=AllReduceParams(
+                            fusion_op=self.post_feed_forward_fusion_op,
+                            residual=residual,
+                            norm_weight=self.next_layer_layernorm.weight,
+                            scale=scale,
+                            eps=self.next_layer_layernorm.variance_epsilon,
+                        ))
+                else:
+                    raise ValueError("Unknown fusion config")
         elif self.next_layer_layernorm:
-            hidden_states, residual = self.next_layer_layernorm(
+            allreduce_output = self.next_layer_layernorm(
                 hidden_states, residual)
 
+        if self.is_nvfp4:
+            act_fp4, act_sf, residual = allreduce_output
+            hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+        else:
+            hidden_states, residual = allreduce_output
+
         return hidden_states, residual
 
 
@@ -544,6 +600,14 @@ def __init__(
         super().__init__()
         config = model_config.pretrained_config
         self.layer_idx = layer_idx
+        self.mapping = model_config.mapping
+        self.enable_attention_dp = model_config.mapping.enable_attention_dp
+        self.is_quanted = model_config.quant_config and model_config.quant_config.quant_mode.has_any_quant(
+        )
+        self.is_fp8_quant = self.is_quanted and model_config.quant_config.quant_mode.has_fp8_qdq(
+        )
+        self.is_nvfp4 = self.is_quanted and model_config.quant_config.quant_mode.has_nvfp4(
+        )
 
         self.self_attn = LlamaAttention(
             model_config,
@@ -566,11 +630,43 @@ def __init__(
                                                 eps=config.rms_norm_eps,
                                                 dtype=config.torch_dtype)
 
+        self.all_reduce = AllReduce(mapping=model_config.mapping)
+
+        self.next_layer_layernorm: RMSNorm = None
+        self.next_attn: LlamaAttention = None
+
         self.attention_mask = PredefinedAttentionMask.CAUSAL
         # If the model is being used as an encoder model (prefill only) we use a full attention mask
         if not model_config.is_generation:
             self.attention_mask = PredefinedAttentionMask.FULL
 
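+        # Same eager-fusion setup as the Llama4 layer: gated by TRTLLM_LLAMA_EAGER_FUSION_DISABLED,
+        # with NVFP4 checkpoints using the quantizing RESIDUAL_RMS_NORM_QUANT_NVFP4 fusion op.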
+        self.enable_fusion = os.environ.get(
+            "TRTLLM_LLAMA_EAGER_FUSION_DISABLED", "0") == "0"
+        self.PRE_MLP_FUSION = self.mapping.has_tp(
+        ) and not self.enable_attention_dp and self.enable_fusion
+        self.POST_MLP_FUSION = self.mapping.has_tp() and self.enable_fusion
+
+        if self.is_nvfp4:
+            self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
+            self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
+        # TODO: enable fp8 quant fusion later
+        # elif self.is_fp8_quant:
+        #     self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8
+        #     self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8
+        else:
+            self.pre_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+            self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+
+        # POST_MLP_FUSION is disabled for now to avoid a large accuracy drop.
+        # TODO: re-enable it once the accuracy issue is resolved.
+        self.POST_MLP_FUSION = False
+
+        self.disable_attn_allreduce = (self.PRE_MLP_FUSION
+                                       or self.mapping.tp_size == 1
+                                       or self.enable_attention_dp)
+        self.disable_mlp_allreduce = (self.POST_MLP_FUSION
+                                      or self.mapping.tp_size == 1
+                                      or self.enable_attention_dp)
+
     def forward(
         self,
         position_ids: torch.IntTensor,
@@ -583,30 +679,78 @@ def forward(
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
 
         # Self Attention
         hidden_states = self.self_attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
             attention_mask=self.attention_mask,
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_attn_allreduce),
             **kwargs,
         )
-
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
-        hidden_states = self.mlp(hidden_states, **kwargs)
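+        # With PRE_MLP_FUSION, the TP allreduce after attention is fused with the residual add,
+        # the post-attention RMSNorm and, for NVFP4, quantization of the MLP input.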
+        if self.PRE_MLP_FUSION:
+            if self.is_nvfp4:
+                scale = self.mlp.gate_up_proj.input_scale
+            else:
+                scale = None
+            all_reduce_output = self.all_reduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=self.pre_mlp_fusion_op,
+                    residual=residual,
+                    norm_weight=self.post_attention_layernorm.weight,
+                    scale=scale,
+                    eps=self.post_attention_layernorm.variance_epsilon,
+                ))
+            if self.is_nvfp4:
+                act_fp4, act_sf, residual = all_reduce_output
+                hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+            else:
+                hidden_states, residual = all_reduce_output
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual)
+
+        hidden_states = self.mlp(
+            hidden_states,
+            final_all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_mlp_allreduce),
+            **kwargs,
+        )
+
         if spec_metadata is not None:
             # We save the hidden states in the spec metadata here. In _prepare_draft_tokens,
             # PyExecutor will extract these from the model engine's spec metadata.
             # They will be passed to the draft model engine on the first draft iteration.
             # TODO: can we support multiple model outputs instead?
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
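+        # With POST_MLP_FUSION, the allreduce after the MLP is fused with the residual add and the
+        # next layer's input RMSNorm; otherwise only the next layer's layernorm is applied here.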
+        if self.POST_MLP_FUSION and self.next_attn is not None:
+            if self.is_nvfp4:
+                scale = self.next_attn.qkv_proj.input_scale
+            else:
+                scale = None
+            all_reduce_output = self.all_reduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=self.post_mlp_fusion_op,
+                    residual=residual,
+                    norm_weight=self.next_layer_layernorm.weight,
+                    scale=scale,
+                    eps=self.next_layer_layernorm.variance_epsilon,
+                ))
+            if self.is_nvfp4:
+                act_fp4, act_sf, residual = all_reduce_output
+                hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+            else:
+                hidden_states, residual = all_reduce_output
+        elif self.next_layer_layernorm:
+            hidden_states, residual = self.next_layer_layernorm(
+                hidden_states, residual)
+
         return hidden_states, residual
 
 
@@ -727,7 +871,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
 
         if self.has_custom_embed_tokens:
             with torch.no_grad():
-                if model_config.mapping.tp_size > 1:
+                if model_config.mapping.has_tp():
                     weight = split_matrix_tp(
                         weight,
                         model_config.mapping.tp_size,
@@ -775,7 +919,6 @@ def forward(
             lora_params=lora_params,
         )
 
-        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
@@ -788,6 +931,18 @@ def __init__(
     ):
         super().__init__(LlamaModel(model_config), model_config)
 
+    def load_weights(self, weights: Dict):
+        super().load_weights(weights)
+
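+        # Wire each decoder layer to the next layer's input layernorm (and attention, for quantization
+        # scales) so the post-feed-forward fused allreduce can apply the next layer's normalization;
+        # the last layer points at the model's final norm instead.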
+        for idx, layer in enumerate(
+                self.model.layers[:self.config.num_hidden_layers]):
+            if idx == self.config.num_hidden_layers - 1:
+                layer.next_layer_layernorm = self.model.norm
+            else:
+                layer.next_layer_layernorm = self.model.layers[
+                    idx + 1].input_layernorm
+                layer.next_attn = self.model.layers[idx + 1].self_attn
+
 
 class Llama4InputProcessor(InputProcessor):
 