Skip to content

Commit 70bfd91

Browse files
committed
[Performance] Run shared_experts on a separate CUDA stream (in parallel with the FusedMoE)
Signed-off-by: Alexander Matveev <[email protected]>
1 parent 0c824fc commit 70bfd91

File tree

4 files changed

+83
-22
lines changed

4 files changed

+83
-22
lines changed

examples/offline_inference/basic/basic.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@
1616

1717
def main():
1818
# Create an LLM.
19-
llm = LLM(model="facebook/opt-125m")
19+
llm = LLM(model="deepseek-ai/DeepSeek-R1-0528", tensor_parallel_size=8)
20+
# llm = LLM(
21+
# model="nvidia/DeepSeek-R1-FP4",
22+
# tensor_parallel_size=8,
23+
# quantization="modelopt_fp4",
24+
# )
25+
2026
# Generate texts from the prompts.
2127
# The output is a list of RequestOutput objects
2228
# that contain the prompt, generated text, and other information.

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,9 @@ def __init__(
10381038
expert_mapping: Optional[list[tuple[str, str, int, str]]] = None,
10391039
):
10401040
super().__init__()
1041+
1042+
self.shared_experts_stream = torch.cuda.Stream()
1043+
10411044
if params_dtype is None:
10421045
params_dtype = torch.get_default_dtype()
10431046
self.params_dtype = params_dtype
@@ -1275,6 +1278,10 @@ def __init__(
12751278
def shared_experts(self) -> Optional[torch.nn.Module]:
12761279
return None
12771280

1281+
@property
1282+
def gate(self) -> Optional[torch.nn.Module]:
1283+
return None
1284+
12781285
@property
12791286
def tp_size(self):
12801287
return self.moe_parallel_config.tp_size
@@ -2058,6 +2065,7 @@ def forward_impl_chunked(
20582065
self,
20592066
full_hidden_states: torch.Tensor,
20602067
full_router_logits: torch.Tensor,
2068+
has_separate_shared_experts: bool,
20612069
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
20622070
assert self.batched_hidden_states is not None
20632071
assert self.batched_router_logits is not None
@@ -2106,11 +2114,16 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
21062114

21072115
# If there are shared experts but we are not using a modular kernel,
21082116
# the shared experts must be called here
2109-
if (
2110-
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
2111-
and self.shared_experts is not None
2112-
):
2113-
shared_output = self.shared_experts(staged_hidden_states)
2117+
if has_separate_shared_experts:
2118+
assert self.shared_experts is not None
2119+
2120+
current_stream = torch.cuda.current_stream()
2121+
self.shared_experts_stream.wait_stream(current_stream)
2122+
with torch.cuda.stream(self.shared_experts_stream):
2123+
# Note that staged_hidden_states.clone() is necessary
2124+
# here to avoid conflict with the main stream
2125+
shared_output = self.shared_experts(staged_hidden_states.clone())
2126+
21142127
else:
21152128
shared_output = None
21162129

@@ -2137,9 +2150,12 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
21372150
logical_replica_count=self.logical_replica_count,
21382151
)
21392152

2140-
if shared_output is not None:
2153+
if has_separate_shared_experts:
21412154
assert not isinstance(final_hidden_states, tuple)
21422155
assert self.shared_experts is not None
2156+
2157+
current_stream.wait_stream(self.shared_experts_stream)
2158+
21432159
final_hidden_states = (
21442160
shared_output,
21452161
final_hidden_states,
@@ -2215,12 +2231,31 @@ def forward_impl(
22152231
self.dp_size > 1 and self.use_flashinfer_cutlass_kernels
22162232
)
22172233

2218-
if (
2234+
has_separate_shared_experts = (
2235+
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
2236+
and self.shared_experts is not None
2237+
)
2238+
2239+
use_chunked_impl = (
22192240
self.moe_parallel_config.use_pplx_kernels
22202241
or self.moe_parallel_config.use_deepep_ll_kernels
22212242
or _use_flashinfer_cutlass_kernels
2222-
):
2223-
return self.forward_impl_chunked(hidden_states, router_logits)
2243+
)
2244+
2245+
if has_separate_shared_experts and not use_chunked_impl:
2246+
# Start the separate shared experts stream here since we want
2247+
# to run in parallel with the router/gate (next op below)
2248+
current_stream = torch.cuda.current_stream()
2249+
self.shared_experts_stream.wait_stream(current_stream)
2250+
2251+
# If a router/gate is provided, apply it here
2252+
if self.gate is not None:
2253+
router_logits, _ = self.gate(hidden_states)
2254+
2255+
if use_chunked_impl:
2256+
return self.forward_impl_chunked(
2257+
hidden_states, router_logits, has_separate_shared_experts
2258+
)
22242259

22252260
do_naive_dispatch_combine: bool = (
22262261
self.dp_size > 1
@@ -2230,11 +2265,14 @@ def forward_impl(
22302265

22312266
# If there are shared experts but we are not using a modular kernel, the
22322267
# shared experts must be called here
2233-
if (
2234-
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
2235-
and self.shared_experts is not None
2236-
):
2237-
shared_output = self.shared_experts(hidden_states)
2268+
if has_separate_shared_experts:
2269+
assert self.shared_experts is not None
2270+
2271+
# Run shared experts in parallel on a separate stream
2272+
with torch.cuda.stream(self.shared_experts_stream):
2273+
# Note that hidden_states.clone() is necessary here to avoid
2274+
# conflict with the main stream
2275+
shared_output = self.shared_experts(hidden_states.clone())
22382276
else:
22392277
shared_output = None
22402278

@@ -2275,9 +2313,13 @@ def forward_impl(
22752313
logical_replica_count=self.logical_replica_count,
22762314
)
22772315

2278-
if shared_output is not None:
2316+
if has_separate_shared_experts:
22792317
assert not isinstance(final_hidden_states, tuple)
22802318
assert self.shared_experts is not None
2319+
2320+
# Wait for the parallel shared experts stream to finish here
2321+
current_stream.wait_stream(self.shared_experts_stream)
2322+
22812323
final_hidden_states = (
22822324
shared_output,
22832325
final_hidden_states,

vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,31 @@ class SharedFusedMoE(FusedMoE):
1919
def __init__(
2020
self,
2121
shared_experts: torch.nn.Module,
22+
gate: Optional[torch.nn.Module] = None,
2223
use_overlapped: bool = True,
2324
**kwargs,
2425
):
2526
super().__init__(**kwargs)
2627
self._shared_experts = shared_experts
28+
self._gate = gate
2729
self.use_overlapped = use_overlapped
2830

2931
@property
3032
def shared_experts(self) -> Optional[torch.nn.Module]:
3133
return self._shared_experts if self.use_overlapped else None
3234

35+
@property
36+
def gate(self) -> Optional[torch.nn.Module]:
37+
return self._gate if self.use_overlapped else None
38+
3339
def forward(
3440
self,
3541
hidden_states: torch.Tensor,
3642
router_logits: torch.Tensor,
3743
) -> tuple[torch.Tensor, torch.Tensor]:
3844
if not self.use_overlapped:
45+
assert self.gate is None
46+
3947
shared_out = self._shared_experts(hidden_states)
4048

4149
# Reduce outputs if necessary, since the MLP should

vllm/model_executor/models/deepseek_v2.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ def __init__(
241241

242242
self.experts = SharedFusedMoE(
243243
shared_experts=self.shared_experts,
244+
gate=self.gate,
244245
num_experts=config.n_routed_experts,
245246
top_k=config.num_experts_per_tok,
246247
hidden_size=config.hidden_size,
@@ -272,12 +273,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
272273
if self.is_sequence_parallel:
273274
hidden_states = sequence_parallel_chunk(hidden_states)
274275

275-
# router_logits: (num_tokens, n_experts)
276-
router_logits, _ = self.gate(hidden_states)
277-
278-
fused_moe_out = self.experts(
279-
hidden_states=hidden_states, router_logits=router_logits
280-
)
276+
if isinstance(self.experts, SharedFusedMoE):
277+
fused_moe_out = self.experts(
278+
hidden_states=hidden_states, router_logits=hidden_states
279+
)
280+
else:
281+
# router_logits: (num_tokens, n_experts)
282+
router_logits, _ = self.gate(hidden_states)
283+
fused_moe_out = self.experts(
284+
hidden_states=hidden_states, router_logits=router_logits
285+
)
281286

282287
if self.shared_experts is not None:
283288
shared_output, final_hidden_states = fused_moe_out

0 commit comments

Comments
 (0)