 from ..attention_backend.interface import (AttentionMask, CustomAttentionMask,
                                             PositionalEmbeddingParams,
                                             PredefinedAttentionMask, RopeParams)
-from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
 from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
@@ -105,9 +104,6 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
-        mrope_config: Optional[dict] = None,
-        all_reduce_params: Optional[AllReduceParams] = None,
-        lora_params: Optional[dict] = None,
         attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -121,9 +117,6 @@ def forward(
                                hidden_states=hidden_states,
                                attn_metadata=attn_metadata,
                                attention_mask=attention_mask,
-                               mrope_config=mrope_config,
-                               all_reduce_params=all_reduce_params,
-                               lora_params=lora_params,
                                attention_window_size=self.attention_window_size,
                                attention_mask_data=attention_mask_data,
                                **kwargs)
@@ -209,7 +202,6 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor] = None,
         attention_mask_data: Optional[torch.Tensor] = None,
-        lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -222,14 +214,14 @@ def forward(
             attention_mask=CustomAttentionMask.CUSTOM if attention_mask_data
             is not None else PredefinedAttentionMask.CAUSAL,
             attention_mask_data=attention_mask_data,
-            lora_params=lora_params,
             **kwargs,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = residual + hidden_states
         residual = hidden_states
         hidden_states = self.pre_feedforward_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states, lora_params=lora_params)
+        hidden_states = self.mlp(hidden_states,
+                                 lora_params=kwargs.get("lora_params", None))
         hidden_states = self.post_feedforward_layernorm(hidden_states)
         hidden_states = residual + hidden_states
 
@@ -270,7 +262,6 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         local_attention_mask_data: Optional[torch.Tensor] = None,
         global_attention_mask_data: Optional[torch.Tensor] = None,
-        lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -291,7 +282,7 @@ def forward(
                 attention_mask_data=local_attention_mask_data
                 if decoder_layer.self_attn.is_sliding else
                 global_attention_mask_data,
-                lora_params=lora_params,
+                **kwargs,
             )
 
         hidden_states = self.norm(hidden_states)
@@ -465,6 +456,7 @@ def forward(
             inputs_embeds=inputs_embeds,
             local_attention_mask_data=local_attention_mask_data,
             global_attention_mask_data=global_attention_mask_data,
+            **kwargs,
         )
 
         return self.logits_processor.forward(
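
Not part of the diff itself: a minimal sketch of the pattern this change adopts, where per-call options such as lora_params are no longer declared as explicit forward() parameters at every level but are threaded through **kwargs and fetched with kwargs.get(...) only where they are consumed. ToyLayer and its mlp method below are hypothetical stand-ins for illustration, not the real Attention/MLP modules.

# Hypothetical sketch only -- illustrates the **kwargs pass-through shown above.
from typing import Optional

import torch


class ToyLayer(torch.nn.Module):

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        # Optional arguments ride along in **kwargs; the consumer pulls out
        # the one it needs and defaults to None when the caller omitted it.
        lora_params: Optional[dict] = kwargs.get("lora_params", None)
        return self.mlp(hidden_states, lora_params=lora_params)

    def mlp(self, x: torch.Tensor,
            lora_params: Optional[dict] = None) -> torch.Tensor:
        return x  # placeholder body


layer = ToyLayer()
layer(torch.zeros(1, 4), lora_params={"rank": 8})  # option supplied via kwargs
layer(torch.zeros(1, 4))                           # option omitted -> None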