-
-
Notifications
You must be signed in to change notification settings - Fork 15.7k
[Feature] Support Pipeline Parallelism in torchrun SPMD offline inference for V1 #17827
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d7d6071
e4d6671
1c95626
157c747
0a4da1d
dca6e00
f4e29f3
9472df9
352404e
0546258
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,7 +20,8 @@ | |
| has_kv_transfer_group) | ||
| from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 | ||
| from vllm.distributed.parallel_state import ( | ||
| get_pp_group, graph_capture, prepare_communication_buffer_for_model) | ||
| get_pp_group, get_tp_group, graph_capture, | ||
| prepare_communication_buffer_for_model) | ||
| from vllm.forward_context import get_forward_context, set_forward_context | ||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding | ||
|
|
@@ -1168,13 +1169,32 @@ def execute_model( | |
| hidden_states, aux_hidden_states = model_output | ||
| else: | ||
| hidden_states = model_output | ||
|
|
||
| # Broadcast PP output for external_launcher (torchrun) | ||
| # to make sure we are synced across pp ranks | ||
| # TODO: Support overlapping micro-batches | ||
| # https://github.com/vllm-project/vllm/issues/18019 | ||
| broadcast_pp_output = \ | ||
| self.parallel_config.distributed_executor_backend \ | ||
| == "external_launcher" and len(get_pp_group().ranks) > 0 | ||
| if not get_pp_group().is_last_rank: | ||
| # For mid-pipeline stages, return the hidden states. | ||
| return hidden_states | ||
|
|
||
| sample_hidden_states = hidden_states[logits_indices] | ||
| logits = self.model.compute_logits(sample_hidden_states, None) | ||
| if not broadcast_pp_output: | ||
| return hidden_states | ||
| assert isinstance(hidden_states, IntermediateTensors) | ||
| get_pp_group().send_tensor_dict(hidden_states.tensors, | ||
| all_gather_group=get_tp_group()) | ||
| logits = None | ||
| else: | ||
| sample_hidden_states = hidden_states[logits_indices] | ||
| logits = self.model.compute_logits(sample_hidden_states, None) | ||
| if broadcast_pp_output: | ||
| model_output_broadcast_data = { | ||
| "logits": logits.contiguous(), | ||
| } if logits is not None else {} | ||
| model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you explain why we need this broadcast?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, added as comments. For now we enable this by syncing all ranks; we will improve it to reduce PP bubbles in a following PR. |
||
| model_output_broadcast_data, src=len(get_pp_group().ranks) - 1) | ||
| assert model_output_broadcast_data is not None | ||
| logits = model_output_broadcast_data["logits"] | ||
|
|
||
| # Apply structured output bitmasks if present | ||
| if scheduler_output.grammar_bitmask is not None: | ||
|
|
@@ -1192,6 +1212,7 @@ def execute_model( | |
| # creates a new tensor with separate storage from the original | ||
| # logits tensor. This means any in-place operations on bonus_logits | ||
| # won't affect the original logits tensor. | ||
| assert logits is not None | ||
| bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] | ||
| sampler_output = self.sampler( | ||
| logits=bonus_logits, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.