diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
index 08b747b06a1..100fa57fb2b 100644
--- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -184,11 +184,6 @@ def dispatch_b(
             (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
         )
 
-        # TODO
-        # masked_m = torch.empty(
-        #     (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
-        # )
-        # expected_m = 0
         masked_m = expected_m = None
 
         return (
@@ -327,6 +322,8 @@ def combine_a(
     def combine_b(self, output, previous_event):
         hidden_states, event = self._combine_core(output, previous_event)
         event.current_stream_wait() if self.async_finish else ()
+        self.handle = None
+        self.src2dst = None
         return hidden_states
 
     def _combine_core(self, x: torch.Tensor, previous_event):
@@ -402,13 +399,6 @@ def dispatch_b(
     ):
         hook() if self.return_recv_hook else event.current_stream_wait()
 
-        # TODO
-        # reorder_topk_ids = torch.empty(
-        #     (0,), device=hidden_states.device, dtype=torch.int64
-        # )
-        # seg_indptr = torch.zeros(
-        #     (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
-        # )
        reorder_topk_ids = seg_indptr = None
 
         return (
@@ -508,6 +498,7 @@ def _combine_core(
                 return_recv_hook=self.return_recv_hook,
             )
         )
+        self.handle = None
 
         return combined_hidden_states, event, hook
 