[docs][serve][llm] examples and doc for cross-node TP/PP in Serve #57715
`doc_code/cross_node_parallelism_example.py` (new file):

```python
# flake8: noqa
"""
Cross-node parallelism examples for Ray Serve LLM.

TP / PP / custom placement group strategies
for multi-node LLM deployments.
"""

# __cross_node_tp_example_start__
import vllm
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with tensor parallelism across 2 GPUs
# Tensor parallelism splits model weights across GPUs
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=2,
        )
    ),
    accelerator_type="L4",
    engine_kwargs=dict(
        tensor_parallel_size=2,
        distributed_executor_backend="ray",
        max_model_len=8192,
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __cross_node_tp_example_end__


# __cross_node_pp_example_start__
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with pipeline parallelism across 2 GPUs
# Pipeline parallelism splits model layers across GPUs
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=1,
        )
    ),
    accelerator_type="L4",
    engine_kwargs=dict(
        pipeline_parallel_size=2,
        distributed_executor_backend="ray",
        max_model_len=8192,
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __cross_node_pp_example_end__


# __cross_node_tp_pp_example_start__
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with both tensor and pipeline parallelism
# This example uses 4 GPUs total (2 TP * 2 PP)
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=1,
        )
    ),
    accelerator_type="L4",
    engine_kwargs=dict(
        tensor_parallel_size=2,
        pipeline_parallel_size=2,
        distributed_executor_backend="ray",
        max_model_len=8192,
        enable_chunked_prefill=True,
        max_num_batched_tokens=4096,
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __cross_node_tp_pp_example_end__


# __custom_placement_group_pack_example_start__
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with custom placement group using PACK strategy
# PACK tries to place workers on as few nodes as possible for locality
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=1,
        )
    ),
    accelerator_type="L4",
    engine_kwargs=dict(
        tensor_parallel_size=2,
        distributed_executor_backend="ray",
        max_model_len=8192,
    ),
    placement_group_config=dict(
        bundles=[{"GPU": 1}] * 2,
        strategy="PACK",
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __custom_placement_group_pack_example_end__


# __custom_placement_group_spread_example_start__
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with custom placement group using SPREAD strategy
# SPREAD distributes workers across nodes for fault tolerance
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=1,
        )
    ),
    accelerator_type="L4",
    engine_kwargs=dict(
        tensor_parallel_size=4,
        distributed_executor_backend="ray",
        max_model_len=8192,
    ),
    placement_group_config=dict(
        bundles=[{"GPU": 1}] * 4,
        strategy="SPREAD",
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __custom_placement_group_spread_example_end__


# __custom_placement_group_strict_pack_example_start__
from ray import serve
from ray.serve.llm import LLMConfig, build_openai_app

# Configure a model with custom placement group using STRICT_PACK strategy
# STRICT_PACK ensures all workers are placed on the same node
llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    deployment_config=dict(
        autoscaling_config=dict(
            min_replicas=1,
            max_replicas=2,
        )
    ),
    accelerator_type="A100",
    engine_kwargs=dict(
        tensor_parallel_size=2,
        distributed_executor_backend="ray",
        max_model_len=8192,
    ),
    placement_group_config=dict(
        bundles=[{"GPU": 1}] * 2,
        strategy="STRICT_PACK",
    ),
)

# Deploy the application
app = build_openai_app({"llm_configs": [llm_config]})
serve.run(app, blocking=True)
# __custom_placement_group_strict_pack_example_end__
```
New documentation page:

(cross-node-parallelism)=
# Cross-node parallelism

Ray Serve LLM supports cross-node tensor parallelism (TP) and pipeline parallelism (PP), allowing you to distribute model inference across multiple GPUs and nodes. This capability enables you to:

- Deploy models that don't fit on a single GPU or node.
- Scale model serving across your cluster's available resources.
- Leverage Ray's placement group strategies to control worker placement for performance or fault tolerance.

::::{note}
By default, Ray Serve LLM uses the `PACK` placement strategy, which tries to place workers on as few nodes as possible. If workers can't fit on a single node, they automatically spill to other nodes. This enables cross-node deployments when single-node resources are insufficient.
::::

## Tensor parallelism

Tensor parallelism splits model weights across multiple GPUs, with each GPU processing a portion of the model's tensors for each forward pass. This approach is useful for models that don't fit on a single GPU.

The following example shows how to configure tensor parallelism across 2 GPUs:

::::{tab-set}

:::{tab-item} Python
:sync: python

```{literalinclude} ../../doc_code/cross_node_parallelism_example.py
:language: python
:start-after: __cross_node_tp_example_start__
:end-before: __cross_node_tp_example_end__
```
:::

::::

## Pipeline parallelism

Pipeline parallelism splits the model's layers across multiple GPUs, with each GPU processing a subset of the layers. This approach is useful for very large models where tensor parallelism alone isn't sufficient.

The following example shows how to configure pipeline parallelism across 2 GPUs:

::::{tab-set}

:::{tab-item} Python
:sync: python

```{literalinclude} ../../doc_code/cross_node_parallelism_example.py
:language: python
:start-after: __cross_node_pp_example_start__
:end-before: __cross_node_pp_example_end__
```
:::

::::

## Combined tensor and pipeline parallelism

For extremely large models, you can combine tensor and pipeline parallelism. The total number of GPUs per replica is the product of `tensor_parallel_size` and `pipeline_parallel_size`.

The following example shows how to configure a model with both TP and PP (4 GPUs total):

::::{tab-set}

:::{tab-item} Python
:sync: python

```{literalinclude} ../../doc_code/cross_node_parallelism_example.py
:language: python
:start-after: __cross_node_tp_pp_example_start__
:end-before: __cross_node_tp_pp_example_end__
```
:::

::::

## Custom placement groups

You can customize how Ray places vLLM engine workers across nodes through the `placement_group_config` parameter. This parameter accepts a dictionary with `bundles` (a list of resource dictionaries) and `strategy` (the placement strategy name).
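As a quick illustration of this parameter's shape, the following minimal sketch requests two single-GPU bundles for a tensor-parallel deployment. Deployment and autoscaling settings are omitted for brevity, and the model values are placeholders borrowed from the examples above:

```python
from ray.serve.llm import LLMConfig

llm_config = LLMConfig(
    model_loading_config=dict(
        model_id="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8B-Instruct",
    ),
    engine_kwargs=dict(
        tensor_parallel_size=2,
        distributed_executor_backend="ray",
    ),
    placement_group_config=dict(
        # Two single-GPU bundles, one for each tensor-parallel worker.
        bundles=[{"GPU": 1}] * 2,
        # A Ray placement group strategy name, for example "PACK" or "SPREAD".
        strategy="PACK",
    ),
)
```

Build and deploy the config with `build_openai_app` and `serve.run`, exactly as in the earlier examples.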
Ray Serve LLM uses the `PACK` strategy by default, which tries to place workers on as few nodes as possible. If workers can't fit on a single node, they automatically spill to other nodes. For more details on all available placement strategies, see {ref}`Ray Core's placement strategies documentation <pgroup-strategy>`.

::::{note}
Data parallel deployments automatically override the placement strategy to `STRICT_PACK` because each replica must be co-located for correct data parallel behavior.
::::

While you can specify the degree of tensor and pipeline parallelism, the specific assignment of model ranks to GPUs is managed by the vLLM engine and can't be directly configured through the Ray Serve LLM API. Ray Serve automatically injects accelerator type labels into bundles and merges the first bundle with replica actor resources (CPU, GPU, memory).

The following example shows how to use the `SPREAD` strategy to distribute workers across multiple nodes for fault tolerance:

::::{tab-set}

:::{tab-item} Python
:sync: python

```{literalinclude} ../../doc_code/cross_node_parallelism_example.py
:language: python
:start-after: __custom_placement_group_spread_example_start__
:end-before: __custom_placement_group_spread_example_end__
```
:::

::::
Bug: Unnecessary Import in Documentation Example

The `import vllm` in `cross_node_parallelism_example.py` is unused. Its presence in the documentation examples could mislead users into thinking they need to import `vllm` directly, even though it's an internal dependency of Ray Serve LLM.