Merged

Changes from 9 commits

1 change: 1 addition & 0 deletions doc/BUILD.bazel
@@ -299,6 +299,7 @@ py_test_run_all_subdirectory(
"source/serve/doc_code/stable_diffusion.py",
"source/serve/doc_code/object_detection.py",
"source/serve/doc_code/vllm_example.py",
"source/serve/doc_code/cross_node_parallelism_example.py",
"source/serve/doc_code/llm/llm_yaml_config_example.py",
"source/serve/doc_code/llm/qwen_example.py",
],
203 changes: 203 additions & 0 deletions doc/source/serve/doc_code/cross_node_parallelism_example.py
@@ -0,0 +1,203 @@
# flake8: noqa
"""
Cross-node parallelism examples for Ray Serve LLM.

Tensor parallelism (TP), pipeline parallelism (PP), and custom placement
group strategies for multi-node LLM deployments.
"""

# __cross_node_tp_example_start__
import vllm
Review comment — Bug: Unnecessary import in documentation example

The `import vllm` in cross_node_parallelism_example.py is unused. Its presence in the documentation examples could mislead users into thinking they need to import vllm directly, even though it's an internal dependency of Ray Serve LLM.

from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with tensor parallelism across 2 GPUs
# Tensor parallelism splits model weights across GPUs
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=2,
        )
    ),
    accelerator_type="L4",
    engine_kwargs=dict(
        tensor_parallel_size=2,
        distributed_executor_backend="ray",

Review comment (Contributor): users don't need to specify "ray" as the backend.

        max_model_len=8192,
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __cross_node_tp_example_end__

# __cross_node_pp_example_start__
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with pipeline parallelism across 2 GPUs
# Pipeline parallelism splits model layers across GPUs
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=1,
        )
    ),
    accelerator_type="L4",
    engine_kwargs=dict(
        pipeline_parallel_size=2,
        distributed_executor_backend="ray",

Review comment (Contributor): here (same as above — no need to specify the "ray" backend).

        max_model_len=8192,
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __cross_node_pp_example_end__

# __cross_node_tp_pp_example_start__
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with both tensor and pipeline parallelism
# This example uses 4 GPUs total (2 TP * 2 PP)
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=1,
        )
    ),
    accelerator_type="L4",
    engine_kwargs=dict(
        tensor_parallel_size=2,
        pipeline_parallel_size=2,
        distributed_executor_backend="ray",

Review comment (Contributor): here (same as above).

        max_model_len=8192,
        enable_chunked_prefill=True,
        max_num_batched_tokens=4096,
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __cross_node_tp_pp_example_end__

# __custom_placement_group_pack_example_start__
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with custom placement group using PACK strategy
# PACK tries to place workers on as few nodes as possible for locality
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=1,
        )
    ),
    accelerator_type="L4",
    engine_kwargs=dict(
        tensor_parallel_size=2,
        distributed_executor_backend="ray",

Review comment (Contributor): here (same as above).

        max_model_len=8192,
    ),
    placement_group_config=dict(
        bundles=[{"GPU": 1}] * 2,
        strategy="PACK",
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __custom_placement_group_pack_example_end__

# __custom_placement_group_spread_example_start__
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with custom placement group using SPREAD strategy
# SPREAD distributes workers across nodes for fault tolerance
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=1,
        )
    ),
    accelerator_type="L4",
    engine_kwargs=dict(
        tensor_parallel_size=4,
        distributed_executor_backend="ray",

Review comment (Contributor): here (same as above).

        max_model_len=8192,
    ),
    placement_group_config=dict(
        bundles=[{"GPU": 1}] * 4,
        strategy="SPREAD",
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __custom_placement_group_spread_example_end__

# __custom_placement_group_strict_pack_example_start__
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with custom placement group using STRICT_PACK strategy
# STRICT_PACK ensures all workers are placed on the same node
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=2,
        )
    ),
    accelerator_type="A100",
    engine_kwargs=dict(
        tensor_parallel_size=2,
        distributed_executor_backend="ray",

Review comment (Contributor): here (same as above).

Reply (Contributor, author): done

        max_model_len=8192,
    ),
    placement_group_config=dict(
        bundles=[{"GPU": 1}] * 2,
        strategy="STRICT_PACK",
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __custom_placement_group_strict_pack_example_end__
8 changes: 8 additions & 0 deletions doc/source/serve/llm/quick-start.md
@@ -270,3 +270,11 @@ serve run config.yaml

For monitoring and observability, see {doc}`Observability <user-guides/observability>`.

## Advanced usage patterns

For each usage pattern, Ray Serve LLM provides a server and client code snippet.

### Cross-node parallelism

Ray Serve LLM supports cross-node tensor parallelism (TP) and pipeline parallelism (PP), allowing you to distribute model inference across multiple GPUs and nodes. See {doc}`Cross-node parallelism <user-guides/cross-node-parallelism>` for a comprehensive guide on configuring and deploying models with cross-node parallelism.
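
As a quick orientation, the following minimal sketch mirrors the tensor parallelism example from that guide; the model ID, model source, and accelerator type are illustrative, and the full guide covers autoscaling and placement group options.

```python
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Split the model's weights across 2 GPUs with tensor parallelism.
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    accelerator_type="L4",
    engine_kwargs=dict(tensor_parallel_size=2),
)

app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
```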

101 changes: 101 additions & 0 deletions doc/source/serve/llm/user-guides/cross-node-parallelism.md
@@ -0,0 +1,101 @@
(cross-node-parallelism)=
# Cross-node parallelism

Ray Serve LLM supports cross-node tensor parallelism (TP) and pipeline parallelism (PP), allowing you to distribute model inference across multiple GPUs and nodes. This capability enables you to:

- Deploy models that don't fit on a single GPU or node.
- Scale model serving across your cluster's available resources.
- Leverage Ray's placement group strategies to control worker placement for performance or fault tolerance.

::::{note}
By default, Ray Serve LLM uses the `PACK` placement strategy, which tries to place workers on as few nodes as possible. If workers can't fit on a single node, they automatically spill to other nodes. This enables cross-node deployments when single-node resources are insufficient.
::::

## Tensor parallelism

Tensor parallelism splits model weights across multiple GPUs, with each GPU processing a portion of the model's tensors for each forward pass. This approach is useful for models that don't fit on a single GPU.

The following example shows how to configure tensor parallelism across 2 GPUs:

::::{tab-set}

:::{tab-item} Python
:sync: python

```{literalinclude} ../../doc_code/cross_node_parallelism_example.py
:language: python
:start-after: __cross_node_tp_example_start__
:end-before: __cross_node_tp_example_end__
```
:::

::::
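
Once the application is running, you can query it like any other Ray Serve LLM deployment. The following client sketch assumes the app above is serving locally on Ray Serve's default HTTP port (8000) through the OpenAI-compatible endpoint that `build_openai_app` exposes; the `api_key` value is a placeholder.

```python
from openai import OpenAI

# Assumes the Serve app above is running locally on the default port.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-key")

response = client.chat.completions.create(
    model="llama-3.1-8b",
    messages=[{"role": "user", "content": "What is tensor parallelism?"}],
)
print(response.choices[0].message.content)
```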

## Pipeline parallelism

Pipeline parallelism splits the model's layers across multiple GPUs, with each GPU processing a subset of the model's layers. This approach is useful for very large models where tensor parallelism alone isn't sufficient.

The following example shows how to configure pipeline parallelism across 2 GPUs:

::::{tab-set}

:::{tab-item} Python
:sync: python

```{literalinclude} ../../doc_code/cross_node_parallelism_example.py
:language: python
:start-after: __cross_node_pp_example_start__
:end-before: __cross_node_pp_example_end__
```
:::

::::

## Combined tensor and pipeline parallelism

For extremely large models, you can combine both tensor and pipeline parallelism. The total number of GPUs is the product of `tensor_parallel_size` and `pipeline_parallel_size`.

The following example shows how to configure a model with both TP and PP (4 GPUs total):

::::{tab-set}

:::{tab-item} Python
:sync: python

```{literalinclude} ../../doc_code/cross_node_parallelism_example.py
:language: python
:start-after: __cross_node_tp_pp_example_start__
:end-before: __cross_node_tp_pp_example_end__
```
:::

::::

## Custom placement groups

You can customize how Ray places vLLM engine workers across nodes through the `placement_group_config` parameter. This parameter accepts a dictionary with two keys: `bundles`, a list of resource dictionaries, and `strategy`, the placement strategy to use.

Ray Serve LLM uses the `PACK` strategy by default, which tries to place workers on as few nodes as possible. If workers can't fit on a single node, they automatically spill to other nodes. For more details on all available placement strategies, see {ref}`Ray Core's placement strategies documentation <pgroup-strategy>`.
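
The following example makes the default explicit by passing a `placement_group_config` that uses the `PACK` strategy with one GPU bundle per engine worker:

::::{tab-set}

:::{tab-item} Python
:sync: python

```{literalinclude} ../../doc_code/cross_node_parallelism_example.py
:language: python
:start-after: __custom_placement_group_pack_example_start__
:end-before: __custom_placement_group_pack_example_end__
```
:::

::::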

::::{note}
Data parallel deployments automatically override the placement strategy to `STRICT_PACK` because each replica must be co-located for correct data parallel behavior.
::::

While you can specify the degree of tensor and pipeline parallelism, the specific assignment of model ranks to GPUs is managed by the vLLM engine and can't be directly configured through the Ray Serve LLM API. Ray Serve automatically injects accelerator type labels into bundles and merges the first bundle with replica actor resources (CPU, GPU, memory).

The following example shows how to use the `SPREAD` strategy to distribute workers across multiple nodes for fault tolerance:

::::{tab-set}

:::{tab-item} Python
:sync: python

```{literalinclude} ../../doc_code/cross_node_parallelism_example.py
:language: python
:start-after: __custom_placement_group_spread_example_start__
:end-before: __custom_placement_group_spread_example_end__
```
:::

::::
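
The following example shows the `STRICT_PACK` strategy, which requires all of a replica's workers to be placed on the same node:

::::{tab-set}

:::{tab-item} Python
:sync: python

```{literalinclude} ../../doc_code/cross_node_parallelism_example.py
:language: python
:start-after: __custom_placement_group_strict_pack_example_start__
:end-before: __custom_placement_group_strict_pack_example_end__
```
:::

::::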

1 change: 1 addition & 0 deletions doc/source/serve/llm/user-guides/index.md
@@ -5,6 +5,7 @@ How-to guides for deploying and configuring Ray Serve LLM features.
```{toctree}
:maxdepth: 1

Cross-node parallelism <cross-node-parallelism>
Deployment Initialization <deployment-initialization>
Prefill/decode disaggregation <prefill-decode>
KV cache offloading <kv-cache-offloading>