@@ -214,7 +214,9 @@ def forward(
214214 if residual is None :
215215 residual = hidden_states
216216 hidden_states = self .input_layernorm (hidden_states )
217-
217+ if spec_metadata is not None and spec_metadata .is_layer_capture (
218+ self .layer_idx ):
219+ self .fusion_config .POST_MOE_FUSION = False
218220 # Self Attention
219221 hidden_states = self .self_attn (
220222 position_ids = position_ids ,
@@ -257,9 +259,6 @@ def forward(
257259
258260 if self .fusion_config .POST_MOE_FUSION :
259261 if do_finalize :
260- if spec_metadata :
261- spec_metadata .maybe_capture_hidden_states (
262- self .layer_idx , hidden_states , residual )
263262 hidden_states , residual = self .allreduce (
264263 hidden_states ,
265264 all_reduce_params = AllReduceParams (
@@ -289,12 +288,8 @@ def forward(
289288 hidden_states , residual = self .moe_allreduce (
290289 fc2_output , all_reduce_params = moe_all_reduce_params )
291290
292- if spec_metadata :
293- spec_metadata .maybe_capture_hidden_states (
294- self .layer_idx , hidden_states , residual )
295-
296291 else :
297- if spec_metadata :
292+ if spec_metadata and spec_metadata . is_layer_capture ( self . layer_idx ) :
298293 spec_metadata .maybe_capture_hidden_states (
299294 self .layer_idx , hidden_states , residual )
300295 if self .next_layer_layernorm is not None :
0 commit comments