@@ -184,11 +184,6 @@ def dispatch_b(
             (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
         )

-        # TODO
-        # masked_m = torch.empty(
-        #     (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
-        # )
-        # expected_m = 0
         masked_m = expected_m = None

         return (
@@ -327,6 +322,8 @@ def combine_a(
     def combine_b(self, output, previous_event):
         hidden_states, event = self._combine_core(output, previous_event)
         event.current_stream_wait() if self.async_finish else ()
+        self.handle = None
+        self.src2dst = None
         return hidden_states

     def _combine_core(self, x: torch.Tensor, previous_event):
@@ -402,13 +399,6 @@ def dispatch_b(
         ):
             hook() if self.return_recv_hook else event.current_stream_wait()

-        # TODO
-        # reorder_topk_ids = torch.empty(
-        #     (0,), device=hidden_states.device, dtype=torch.int64
-        # )
-        # seg_indptr = torch.zeros(
-        #     (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
-        # )
         reorder_topk_ids = seg_indptr = None

         return (
@@ -508,6 +498,7 @@ def _combine_core(
                 return_recv_hook=self.return_recv_hook,
             )
         )
+        self.handle = None
         return combined_hidden_states, event, hook
