
Commit 18bf91e

wip

Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent 00f526f

File tree: 4 files changed, +17 -5 lines


examples/basic-ub.py

Lines changed: 4 additions & 4 deletions

@@ -40,11 +40,11 @@ def main():
         max_model_len=1024,
         #load_format="dummy",
         ###############
-        tensor_parallel_size=1,
-        #data_parallel_size=2,
-        enable_expert_parallel=False,
+        #tensor_parallel_size=1,
+        data_parallel_size=2,
+        enable_expert_parallel=True,
         ###############
-        enable_microbatching=True,
+        #enable_microbatching=True,
     )
     # Generate texts from the prompts.
     # The output is a list of RequestOutput objects
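This hunk flips the example from 1-way tensor parallelism to 2-way data parallelism with expert parallelism, and comments out the WIP microbatching flag. For context, a minimal sketch of the full example under these settings, assuming vLLM's offline `LLM` API; the model name is a hypothetical stand-in, and `enable_microbatching` exists only on this branch:

```python
# Minimal sketch, assuming vLLM's offline LLM API; model name is hypothetical
# and enable_microbatching is a WIP flag from this branch only.
from vllm import LLM, SamplingParams

def main():
    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite",  # hypothetical model choice
        max_model_len=1024,
        data_parallel_size=2,          # two DP replicas, as in the diff
        enable_expert_parallel=True,   # shard MoE experts across ranks
        # enable_microbatching=True,   # WIP flag, commented out in this commit
    )
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    for out in outputs:
        print(out.outputs[0].text)

if __name__ == "__main__":
    main()
```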

vllm/config.py

Lines changed: 1 addition & 1 deletion

@@ -4332,7 +4332,7 @@ def __post_init__(self):
             logger.warning_once(
                 "Piecewise compilation is not supported with "
                 "microbatching. Disabling piecewise compilation.")
-            self.compilation_config.level = CompilationLevel.DYNAMO_ONCE
+            self.compilation_config.level = CompilationLevel.NO_COMPILATION

         if self.model_config and self.model_config.use_mla and \
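The change escalates the fallback: instead of dropping from piecewise compilation to `DYNAMO_ONCE` when microbatching is on, compilation is now disabled entirely. A self-contained sketch of the guard this hunk lives in; the enum values mirror vLLM's `CompilationLevel`, but the enclosing condition name is an assumption, not verbatim source:

```python
# Standalone sketch; CompilationLevel values mirror vLLM's config.py,
# but enable_microbatching as the guard condition is an assumption.
import enum
import logging

logger = logging.getLogger(__name__)

class CompilationLevel(enum.IntEnum):
    NO_COMPILATION = 0
    DYNAMO_AS_IS = 1
    DYNAMO_ONCE = 2
    PIECEWISE = 3

def apply_microbatching_guard(compilation_config, enable_microbatching):
    # Piecewise compilation splits the graph into separately compiled
    # pieces; that does not compose with interleaved microbatches, so
    # the commit turns compilation off entirely rather than falling
    # back to DYNAMO_ONCE.
    if enable_microbatching:
        logger.warning("Piecewise compilation is not supported with "
                       "microbatching. Disabling piecewise compilation.")
        compilation_config.level = CompilationLevel.NO_COMPILATION

class Cfg:
    level = CompilationLevel.PIECEWISE

cfg = Cfg()
apply_microbatching_guard(cfg, enable_microbatching=True)
assert cfg.level == CompilationLevel.NO_COMPILATION
```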

vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py

Lines changed: 8 additions & 0 deletions

@@ -7,6 +7,7 @@
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input)
+from vllm.v1.worker.ubatching import get_current_ubatch_context, yield_impl


 # Note use: layer.get_all_to_all() to get an AllToAll instance
@@ -117,7 +118,11 @@ def dispatch(send: bool):
             do_send=send,
             do_recv=not send,
         )
+
+        # if ubatch_ctx is not None:
+        #     ubatch_ctx.gpu_stream_wait()
         dispatch(True)  # Send
+        yield_impl(gpu_wait=False)
         dispatch(False)  # Recv

         return expert_x, expert_x_scale, expert_num_tokens
@@ -155,5 +160,8 @@ def combine(send: bool):
             do_send=send,
             do_recv=not send,
         )
+        # if ubatch_ctx is not None:
+        #     ubatch_ctx.gpu_stream_wait()
         combine(True)
+        yield_impl(gpu_wait=False)
         combine(False)
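The point of this hunk is that each pplx all-to-all is already split into a send half and a recv half; inserting `yield_impl(gpu_wait=False)` between them lets one microbatch's communication stay in flight while the other microbatch runs. A toy model of that interleave, assuming nothing about the branch's real `ubatching` module (which presumably coordinates CUDA streams and threads); plain generators stand in for microbatch contexts:

```python
# Toy sketch of the send/yield/recv pattern; generators stand in for
# microbatch contexts, not the branch's actual ubatching mechanism.
def moe_layer(ubatch_id):
    print(f"ubatch {ubatch_id}: dispatch send")   # dispatch(True)
    yield                                          # yield_impl(gpu_wait=False)
    print(f"ubatch {ubatch_id}: dispatch recv")   # dispatch(False)
    print(f"ubatch {ubatch_id}: expert compute")
    print(f"ubatch {ubatch_id}: combine send")    # combine(True)
    yield                                          # yield_impl(gpu_wait=False)
    print(f"ubatch {ubatch_id}: combine recv")    # combine(False)

# Round-robin the two microbatches: while ubatch 0's all-to-all is in
# flight, ubatch 1 makes progress, and vice versa.
pending = [moe_layer(0), moe_layer(1)]
while pending:
    still_running = []
    for ub in pending:
        try:
            next(ub)
            still_running.append(ub)
        except StopIteration:
            pass
    pending = still_running
```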

vllm/model_executor/models/deepseek_v2.py

Lines changed: 4 additions & 0 deletions

@@ -49,6 +49,7 @@
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.v1.worker.ubatching import get_current_ubatch_context

 from .interfaces import SupportsPP
 from .utils import (PPMissingLayer, is_pp_missing_parameter,
@@ -656,6 +657,9 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
+        if (ubatch_ctx := get_current_ubatch_context()) is not None:
+            print("in forward, ubatch:", ubatch_ctx.id)
+
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
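Note the parentheses around the walrus assignment in the hunk above: without them, Python binds the result of the `is not None` comparison to the name, not the context object. A tiny runnable demonstration of the pitfall, using a stand-in for `get_current_ubatch_context`:

```python
# Pure-Python precedence demo; get_ctx stands in for
# get_current_ubatch_context().
def get_ctx():
    return "ctx"

# Wrong: `x := get_ctx() is not None` binds the boolean to x.
if x := get_ctx() is not None:
    print(x)            # True -- not the context object

# Right: parenthesize the assignment so x is the context itself.
if (x := get_ctx()) is not None:
    print(x)            # "ctx"
```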
