@@ -1038,6 +1038,9 @@ def __init__(
         expert_mapping: Optional[list[tuple[str, str, int, str]]] = None,
     ):
         super().__init__()
+
+        self.shared_experts_stream = torch.cuda.Stream()
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
@@ -1275,6 +1278,10 @@ def __init__(
     def shared_experts(self) -> Optional[torch.nn.Module]:
         return None
 
+    @property
+    def gate(self) -> Optional[torch.nn.Module]:
+        return None
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
@@ -2058,6 +2065,7 @@ def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,
         full_router_logits: torch.Tensor,
+        has_separate_shared_experts: bool,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         assert self.batched_hidden_states is not None
         assert self.batched_router_logits is not None
@@ -2106,11 +2114,16 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
 
             # If there are shared experts but we are not using a modular kernel,
             # the shared experts must be called here
-            if (
-                not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-                and self.shared_experts is not None
-            ):
-                shared_output = self.shared_experts(staged_hidden_states)
+            if has_separate_shared_experts:
+                assert self.shared_experts is not None
+
+                current_stream = torch.cuda.current_stream()
+                self.shared_experts_stream.wait_stream(current_stream)
+                with torch.cuda.stream(self.shared_experts_stream):
+                    # Note that staged_hidden_states clone() is necessary
+                    # here to avoid conflict with the main stream
+                    shared_output = self.shared_experts(staged_hidden_states.clone())
+
             else:
                 shared_output = None
@@ -2137,9 +2150,12 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 logical_replica_count=self.logical_replica_count,
             )
 
-            if shared_output is not None:
+            if has_separate_shared_experts:
                 assert not isinstance(final_hidden_states, tuple)
                 assert self.shared_experts is not None
+
+                current_stream.wait_stream(self.shared_experts_stream)
+
                 final_hidden_states = (
                     shared_output,
                     final_hidden_states,
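The chunked path above uses the same two-stream handshake as the non-chunked path in `forward_impl` further down: `shared_experts_stream` first waits on the current stream, runs the shared experts on a `clone()` of the input, and the current stream waits on `shared_experts_stream` before the result is consumed. A minimal, self-contained sketch of that handshake (the generic `main_fn`/`side_fn` callables are illustrative assumptions, not part of this diff):

```python
import torch

side_stream = torch.cuda.Stream()

def run_overlapped(main_fn, side_fn, x: torch.Tensor):
    """Run side_fn(x) on side_stream while main_fn(x) runs on the current stream."""
    current = torch.cuda.current_stream()
    # Make the side stream see all work already queued on the current stream.
    side_stream.wait_stream(current)
    with torch.cuda.stream(side_stream):
        # clone() gives the side stream its own buffer, so the caching
        # allocator on the current stream cannot recycle x's memory while
        # the side stream is still reading it.
        side_out = side_fn(x.clone())
    main_out = main_fn(x)  # overlaps with side_fn on the GPU
    # Do not consume side_out on the current stream until side_stream is done.
    current.wait_stream(side_stream)
    return main_out, side_out
```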
@@ -2215,12 +2231,31 @@ def forward_impl(
             self.dp_size > 1 and self.use_flashinfer_cutlass_kernels
         )
 
-        if (
+        has_separate_shared_experts = (
+            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
+            and self.shared_experts is not None
+        )
+
+        use_chunked_impl = (
             self.moe_parallel_config.use_pplx_kernels
             or self.moe_parallel_config.use_deepep_ll_kernels
             or _use_flashinfer_cutlass_kernels
-        ):
-            return self.forward_impl_chunked(hidden_states, router_logits)
+        )
+
+        if has_separate_shared_experts and not use_chunked_impl:
+            # Start the separate shared experts stream here since we want
+            # to run in parallel with the router/gate (next op below)
+            current_stream = torch.cuda.current_stream()
+            self.shared_experts_stream.wait_stream(current_stream)
+
+        # If router/gate provided, then apply it here
+        if self.gate is not None:
+            router_logits, _ = self.gate(hidden_states)
+
+        if use_chunked_impl:
+            return self.forward_impl_chunked(
+                hidden_states, router_logits, has_separate_shared_experts
+            )
 
         do_naive_dispatch_combine: bool = (
             self.dp_size > 1
@@ -2230,11 +2265,14 @@ def forward_impl(
 
         # If there are shared experts but we are not using a modular kernel, the
         # shared experts must be called here
-        if (
-            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-            and self.shared_experts is not None
-        ):
-            shared_output = self.shared_experts(hidden_states)
+        if has_separate_shared_experts:
+            assert self.shared_experts is not None
+
+            # Run shared experts in parallel on a separate stream
+            with torch.cuda.stream(self.shared_experts_stream):
+                # Note that hidden_states clone() is necessary here to avoid
+                # conflict with the main stream
+                shared_output = self.shared_experts(hidden_states.clone())
         else:
             shared_output = None
 
@@ -2275,9 +2313,13 @@ def forward_impl(
             logical_replica_count=self.logical_replica_count,
         )
 
-        if shared_output is not None:
+        if has_separate_shared_experts:
             assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None
+
+            # Wait for the parallel shared experts stream to finish here
+            current_stream.wait_stream(self.shared_experts_stream)
+
             final_hidden_states = (
                 shared_output,
                 final_hidden_states,
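The new `gate` property pairs with the stream change: a model can hand both its router and its shared-expert MLP to the layer, and `forward_impl` will kick off the shared experts on the side stream before applying the gate, so the two overlap. A hypothetical subclass sketch (the class name, constructor arguments, and the assumption that the patched base class is `FusedMoE` are illustrative, not part of this diff):

```python
from typing import Optional

import torch

class SharedExpertsFusedMoE(FusedMoE):
    """Hypothetical MoE layer that exposes its gate and shared experts."""

    def __init__(self, gate: torch.nn.Module,
                 shared_experts: torch.nn.Module, **kwargs):
        super().__init__(**kwargs)
        self._gate = gate
        self._shared_experts = shared_experts

    @property
    def gate(self) -> Optional[torch.nn.Module]:
        # forward_impl recomputes router_logits with this module after it
        # has already started shared_experts_stream, enabling the overlap.
        return self._gate

    @property
    def shared_experts(self) -> Optional[torch.nn.Module]:
        return self._shared_experts
```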