Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ steps:
# test with tp=2 and external_dp=2
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with tp=2 and pp=2
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
Expand Down
23 changes: 14 additions & 9 deletions examples/offline_inference/torchrun_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
see `tests/distributed/test_torchrun_example.py` for the unit test.
"""

import torch.distributed as dist

from vllm import LLM, SamplingParams

# Create prompts, the same across all ranks
Expand All @@ -27,23 +29,26 @@
# all ranks have the same random seed, so that sampling can be
# deterministic across ranks.
llm = LLM(
model="facebook/opt-125m",
model="meta-llama/Llama-3.1-8B",
tensor_parallel_size=2,
pipeline_parallel_size=2,
distributed_executor_backend="external_launcher",
seed=0,
max_model_len=32768,
seed=1,
)

outputs = llm.generate(prompts, sampling_params)

# all ranks will have the same outputs
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\n"
f"Generated text: {generated_text!r}")
if dist.get_rank() == 0:
print("-" * 50)
"""
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\n"
f"Generated text: {generated_text!r}\n")
print("-" * 50)
"""
Further tips:

1. to communicate control messages across all ranks, use the cpu group,
Expand Down
3 changes: 2 additions & 1 deletion tests/distributed/test_torchrun_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

# unit test for `examples/offline_inference/torchrun_example.py`

import os
import random

import torch.distributed as dist
Expand All @@ -25,6 +25,7 @@
# to test if all ranks agree on the same kv cache configuration.
llm = LLM(model="facebook/opt-125m",
tensor_parallel_size=2,
pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)),
distributed_executor_backend="external_launcher",
gpu_memory_utilization=random.uniform(0.7, 0.9),
swap_space=random.randint(1, 4),
Expand Down
1 change: 0 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1690,7 +1690,6 @@ class ParallelConfig:
"""Port of the data parallel master."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""

max_parallel_loading_workers: Optional[int] = None
"""Maximum number of parallel loading workers when loading model
sequentially in multiple batches. To avoid RAM OOM when using tensor
Expand Down
6 changes: 4 additions & 2 deletions vllm/distributed/device_communicators/custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,8 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:

def close(self):
if not self.disabled and self._ptr:
ops.dispose(self._ptr)
if ops is not None:
ops.dispose(self._ptr)
self._ptr = 0
self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
Expand Down Expand Up @@ -298,4 +299,5 @@ def free_shared_buffer(pointers: list[int],
rank: Optional[int] = 0) -> None:
if rank is None:
rank = dist.get_rank(group=group)
ops.free_shared_buffer(pointers[rank])
if ops is not None:
ops.free_shared_buffer(pointers[rank])
5 changes: 3 additions & 2 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,9 +1376,10 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
return False

if (self.pipeline_parallel_size > 1
and self.distributed_executor_backend not in ["ray", "mp"]):
and self.distributed_executor_backend
Comment thread
luccafong marked this conversation as resolved.
not in ("ray", "mp", "external_launcher")):
name = "Pipeline Parallelism without Ray distributed executor " \
"or multiprocessing executor"
"or multiprocessing executor or external launcher"
_raise_or_fallback(feature_name=name, recommend_to_remove=False)
return False

Expand Down
3 changes: 0 additions & 3 deletions vllm/executor/uniproc_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,6 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
def _init_executor(self) -> None:
"""Initialize the worker and load the model.
"""
assert self.vllm_config.parallel_config.pipeline_parallel_size == 1, \
("ExecutorWithExternalLauncher does not "
"support pipeline parallelism.")
assert self.vllm_config.scheduler_config.delay_factor == 0.0, \
("ExecutorWithExternalLauncher needs deterministic "
"execution, so it"
Expand Down
33 changes: 27 additions & 6 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (
get_pp_group, graph_capture, prepare_communication_buffer_for_model)
get_pp_group, get_tp_group, graph_capture,
prepare_communication_buffer_for_model)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
Expand Down Expand Up @@ -1168,13 +1169,32 @@ def execute_model(
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output

# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping micro-batches
# https://github.com/vllm-project/vllm/issues/18019
broadcast_pp_output = \
self.parallel_config.distributed_executor_backend \
== "external_launcher" and len(get_pp_group().ranks) > 0
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
return hidden_states

sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if not broadcast_pp_output:
return hidden_states
assert isinstance(hidden_states, IntermediateTensors)
get_pp_group().send_tensor_dict(hidden_states.tensors,
all_gather_group=get_tp_group())
logits = None
else:
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output:
model_output_broadcast_data = {
"logits": logits.contiguous(),
} if logits is not None else {}
model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why we need this broadcast?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, added as comments. For now we enable this by syncing all ranks; we will improve it to reduce pipeline-parallel bubbles in a follow-up PR.

model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
assert model_output_broadcast_data is not None
logits = model_output_broadcast_data["logits"]

# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
Expand All @@ -1192,6 +1212,7 @@ def execute_model(
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
Expand Down
6 changes: 3 additions & 3 deletions vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,13 +275,13 @@ def execute_model(

output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)

if not get_pp_group().is_last_rank:
parallel_config = self.vllm_config.parallel_config
if parallel_config.distributed_executor_backend != "external_launcher" \
and not get_pp_group().is_last_rank:
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
return None

assert isinstance(output, ModelRunnerOutput)
return output if self.is_driver_worker else None

Expand Down