
[RFC] LLM APIs for Ray Data and Ray Serve #50639

Open

richardliaw opened this issue Feb 16, 2025 · 10 comments

@richardliaw
Contributor

RFC: LLM APIs for Ray Data and Ray Serve

Summary

This RFC proposes new APIs for working effectively with Large Language Models (LLMs) in the Ray ecosystem, specifically integrations of Ray Data and Ray Serve with vLLM and OpenAI-compatible endpoints.

Motivation

As LLMs become increasingly central to modern AI infrastructure deployments, platforms require the ability to deploy and scale these models efficiently. Currently, Ray Data and Ray Serve have limited support for LLM deployments: users have to manually configure and manage the underlying LLM engine.

This proposal addresses these challenges by providing unified, production-ready APIs for both batch processing and serving of LLMs within Ray, under ray.data.llm and ray.serve.llm.

Ray Data LLM

ray.data.llm introduces several key components:

  1. build_llm_processor: Unified API for constructing processors
  2. Processors: User-facing API for specific LLM functionalities, including integration with vLLM and endpoint-based deployments
  3. ProcessorConfig: Configuration interface for Processors

Design Principles:

  • Integrate seamlessly with existing Ray Data APIs
  • One processor contains at most one LLM engine
  • Configurable but with sensible defaults for optimal throughput
import ray
from ray.data.llm import build_llm_processor, VLLMProcessorConfig

processor_config = VLLMProcessorConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",
)

processor = build_llm_processor(
    processor_config,
    preprocess=lambda row: dict(
        messages=row["question"],
        sampling_params=dict(
            temperature=0.3,
            max_tokens=250,
        )
    ),
    postprocess=lambda row: dict(
        answer=row["generated_text"]
    ),
    concurrency=4,
)

ds = ray.data.read_parquet(...)
ds = processor(ds)
ds.write_parquet(...)

You can also make calls to deployed models that expose an OpenAI-compatible API endpoint.

from ray.data.llm import HTTPRequestProcessorConfig

OPENAI_KEY = "..."
ds = ray.data.read_parquet("...")

ds = build_llm_processor(
    HTTPRequestProcessorConfig(
        url="https://api.openai.com/v1/chat/completions",
        header=f"Authorization: Bearer {OPENAI_KEY}",
        qps=1,
    ),
    preprocess=lambda row: dict(
        model="gpt-4o-mini",
        messages=row["messages"],
        sampling_params=dict(
            temperature=0.0,
            max_tokens=150,
        ),
    ),
    postprocess=lambda row: dict(
        resp=row["generated_text"]
    ),
    concurrency=8,
)(ds)

ds.write_parquet("...")

Ray Serve LLM

The new ray.serve.llm provides:

  1. VLLMDeployment: Manages VLLM engine deployment
  2. LLMModelRouter: OpenAI-compatible API router
  3. LLMConfig: Unified configuration for model deployment
  4. LoRA Support: Multi-adapter sharing with LRU caching

These features allow users to deploy multiple LLMs together with the familiar Ray Serve API, while providing compatibility with the OpenAI API.

from ray import serve
from ray.serve import DeploymentConfig, AutoscalingConfig
from ray.serve.llm import VLLMDeployment, LLMConfig, ModelLoadingConfig

llm_config = LLMConfig(
    model_loading_config=ModelLoadingConfig(
        served_model_name="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8b-instruct",
    ),
    deployment_config=DeploymentConfig(
        autoscaling_config=AutoscalingConfig(
            min_replicas=1,
            max_replicas=8,
        )
    ),
)

vllm_deployment = VLLMDeployment.options(**llm_config.get_serve_options()).bind(llm_config)
serve.run(vllm_deployment)

Below is a more comprehensive example that serves multiple models behind the OpenAI-compatible API with Ray Serve.

from ray import serve
from ray.serve import DeploymentConfig, AutoscalingConfig
from ray.serve.llm import LLMModelRouterDeployment, VLLMDeployment, LLMConfig, ModelLoadingConfig

# Configure multiple models
llm_config1 = LLMConfig(
    model_loading_config=ModelLoadingConfig(
        served_model_name="llama-3.1-8b",
        model_source="meta-llama/Llama-3.1-8b-instruct",
    ),
    deployment_config=DeploymentConfig(
        autoscaling_config=AutoscalingConfig(
            min_replicas=1,
            max_replicas=8,
        )
    ),
)

llm_config2 = LLMConfig(
    model_loading_config=ModelLoadingConfig(
        served_model_name="llama-3.2-3b",
        model_source="meta-llama/Llama-3.2-3b-instruct",
    ),
    deployment_config=DeploymentConfig(
        autoscaling_config=AutoscalingConfig(
            min_replicas=1,
            max_replicas=8,
        )
    ),
)

# Create deployments
vllm_deployment1 = VLLMDeployment.options(**llm_config1.get_serve_options()).bind(llm_config1)
vllm_deployment2 = VLLMDeployment.options(**llm_config2.get_serve_options()).bind(llm_config2)

# Create router deployment
llm_app = LLMModelRouterDeployment.options().bind([vllm_deployment1, vllm_deployment2])

# Deploy the application
serve.run(llm_app)

And you can now use an OpenAI API client to interact with the deployed models.

from openai import OpenAI

# Initialize client with your deployment endpoint
client = OpenAI(
    base_url="http://localhost:8000",
    api_key="fake-key"  # The API key is not validated
)

# Chat completion
chat_response = client.chat.completions.create(
    model="llama-3.1-8b",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"}
    ],
    temperature=0.7,
    max_tokens=150
)

# Text completion
completion_response = client.completions.create(
    model="llama-3.2-3b",
    prompt="Write a poem about",
    temperature=0.7,
    max_tokens=150
)

Advanced Example: RAG Pipeline with ray.data.llm

Here is a more complex example that demonstrates how to build a RAG pipeline using
the new ray.data.llm APIs.

import ray
from ray.data.llm import (
    build_llm_processor,
    VLLMProcessorConfig,
    HTTPRequestProcessorConfig,
)

# Embedding generation
embed_processor_config = VLLMProcessorConfig(
    model="google-bert/bert-base-uncased",
    task_type="embed",
)

# Vector DB querying
retrieve_processor_config = HTTPRequestProcessorConfig(
    url="http://vector-db-endpoint",
    header="...",
    qps=1,
)

# LLM processing
llm_processor_config = VLLMProcessorConfig(
    model="meta-llama/Llama-3.1-70B-Instruct",
    engine_kwargs=dict(pipeline_parallel_size=2),
)

# Pipeline construction
ds = ray.data.read_parquet("...")

# Generate embeddings
ds = build_llm_processor(
    embed_processor_config,
    preprocess=lambda row: dict(prompt=row["question"]),
    postprocess=lambda row: dict(embedding=row["embedding"]),
    concurrency=2,
)(ds)

# Query vector DB
ds = build_llm_processor(
    retrieve_processor_config,
    preprocess=lambda row: dict(body=row["embedding"]),
    postprocess=lambda row: dict(retrieved=row["text"]),
    concurrency=4,
)(ds)

# Generate answers
ds = build_llm_processor(
    llm_processor_config,
    preprocess=lambda row: dict(
        messages=[
            {"role": "system", "content": "..."},
            {"role": "user", "content": f"{row['question']}\n{row['retrieved']}"},
        ],
        sampling_params=dict(temperature=0.3, max_tokens=250)
    ),
    concurrency=4,
)(ds)
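
As in the earlier examples, the pipeline can then be materialized by writing the results out:

ds.write_parquet("...")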

Advanced Example: LoRA serving with ray.serve.llm

from ray import serve
from ray.serve.llm import VLLMDeployment, LLMConfig, ModelLoadingConfig, LoraConfig
from ray.serve import DeploymentConfig, AutoscalingConfig

# Configure the LLM with LoRA support
llm_config = LLMConfig(
    model_loading_config=ModelLoadingConfig(
        model_source="meta-llama/Llama-3.1-8b-instruct",
        served_model_name="llama-3.1-8b"
    ),
    lora_config=LoraConfig(
        # Path containing all LoRA adapters
        dynamic_lora_loading_path="s3://my-bucket/llama-loras/",
        # Maximum number of LoRA adapters that can share a single base model
        max_num_adapters_per_replica=4,
    ),
    deployment_config=DeploymentConfig(
        autoscaling_config=AutoscalingConfig(
            min_replicas=1,
            max_replicas=4
        )
    )
)

# Create and deploy the model
vllm_deployment = VLLMDeployment.options(**llm_config.get_serve_options()).bind(llm_config)
serve.run(vllm_deployment)
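
For illustration, a client request could then target a specific LoRA adapter through the OpenAI-compatible endpoint. The adapter-naming convention below is an assumption of this sketch (a hypothetical "served-model-name:adapter-name" pattern, with the adapter weights stored under the dynamic_lora_loading_path configured above):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000", api_key="fake-key")

# Hypothetical adapter id: "llama-3.1-8b:my-adapter" would select the LoRA weights
# stored at s3://my-bucket/llama-loras/my-adapter, loaded on demand and kept in an
# LRU cache on the replica alongside the shared base model.
response = client.chat.completions.create(
    model="llama-3.1-8b:my-adapter",
    messages=[{"role": "user", "content": "Summarize LoRA in one sentence."}],
)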

ray.serve.llm config-based API

There is also a configuration-based API for serving LLMs, where
the configurations can be declared separately from the application logic.

from pydantic import BaseModel, Field
from typing import List, Union
from ray.serve.llm import VLLMDeployment, LLMConfig, LLMModelRouterDeployment

def build_vllm_deployment(llm_config: LLMConfig):
    return VLLMDeployment.options(**llm_config.get_serve_options()).bind(llm_config)


class LLMServingArgs(BaseModel):
    llm_configs: List[Union[str, LLMConfig]] = Field(
        description="A list of LLMConfigs, or paths to LLMConfigs, to run.",
    )


def build_openai_app(llm_serving_args: LLMServingArgs):
    llm_configs = llm_serving_args.llm_configs
    llm_deployments = []
    for llm_config in llm_configs:
        if isinstance(llm_config, str):
            llm_config = LLMConfig.from_yaml(llm_config)
        llm_deployments.append(build_vllm_deployment(llm_config))
    return LLMModelRouterDeployment.bind(llm_deployments)
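
As a usage sketch, the same builder can be invoked programmatically; the config path below is hypothetical and stands in for a YAML file containing a serialized LLMConfig:

from ray import serve

# Hypothetical path; each entry may also be an inline LLMConfig object.
app = build_openai_app(
    LLMServingArgs(llm_configs=["configs/llama-3.1-8b.yaml"])
)
serve.run(app)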

Sample config.yaml

# Demonstrate inline llm configs in the Serve config
application:
  name: llm_app
  route_prefix: "/"
  import_path: ray.serve.llm:build_openai_app
  args:
    llm_configs:
    - model_loading_config:
        model_id: meta-llama/Meta-Llama-3.1-8B-Instruct
      accelerator_type: A10G
      tensor_parallelism:
        degree: 1
      deployment_config:
        autoscaling_config:
          min_replicas: 1
          max_replicas: 2

Future Work

  1. Support for additional LLM inference engines beyond vLLM
  2. Enhanced monitoring and observability
  3. Advanced batching and scheduling optimizations
  4. Additional processor types for specialized workflows

cc @comaniac @kouroshHakha @akshay-anyscale @gvspraveen

richardliaw pinned this issue Feb 16, 2025
@lizzzcai

This is a great RFC. I want to check whether minReplicas: 0 is supported in the ray.serve example above, and whether there is a plan to support Model Multiplexing for LLMs, i.e. serving N LLMs on M GPUs (N > M)? Model Multiplexing’s LRU-based caching and dynamic loading would allow efficient sharing of GPU resources while maintaining low latency for sparsely used models. Thanks.

@justinrmiller
Contributor

This looks excellent. When development begins I’d like to lend a hand.

@kouroshHakha
Contributor

This is a great RFC. I want to check if minReplicas: 0 is supported in the ray.serve example above and is there a plan to support Model Multiplexing for LLMs to serve N number of LLMs on M number of GPUs (N > M)? Model Multiplexing’s LRU-based caching and dynamic loading would allow efficient sharing of GPU resources while maintaining low latency for sparsely used models. Thanks.

@lizzzcai
Yes, for online inference you can set min_replicas: 0 and then leverage Ray's autoscaling to serve incoming dynamic traffic. That's a bonus of using Ray Serve, since this functionality already exists in the stack.

And for multiplexing LoRA adapters, that's indeed the plan: you can share the base model across all of them and only swap out the adapter weights when a new request comes in. Furthermore, supporting multiple base models is also part of the scope.

Say you have 8 GPUs. You can serve llama-3-8b on 1 GPU, qwen-8b on 1 GPU, qwen-32b on 2 GPUs with tp=2, and llama-70b on 4 GPUs with tp=4. Each base model can have multi-LoRA support, so you can serve an arbitrary number of their fine-tuned variants with the same resources.
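
To make that layout concrete, here is a rough sketch using the proposed ray.serve.llm API. The engine_kwargs / tensor_parallel_size spelling is an assumption (borrowed from the vLLM engine arguments and the VLLMProcessorConfig example above), and the model_source values are placeholders:

from ray import serve
from ray.serve.llm import LLMModelRouterDeployment, VLLMDeployment, LLMConfig, ModelLoadingConfig

# Four base models packed onto 8 GPUs: 1 + 1 + 2 (tp=2) + 4 (tp=4).
configs = [
    LLMConfig(
        model_loading_config=ModelLoadingConfig(
            served_model_name=name,
            model_source="...",  # fill in the Hugging Face model id
        ),
        engine_kwargs=dict(tensor_parallel_size=tp),  # assumed passthrough to the vLLM engine
    )
    for name, tp in [("llama-3-8b", 1), ("qwen-8b", 1), ("qwen-32b", 2), ("llama-70b", 4)]
]

deployments = [
    VLLMDeployment.options(**cfg.get_serve_options()).bind(cfg) for cfg in configs
]
serve.run(LLMModelRouterDeployment.options().bind(deployments))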

@richardliaw
Contributor Author

@lizzzcai @justinrmiller thanks a bunch for your comments -- please let us know if you have other feature requests / use cases we should design for.

@justinrmiller we've started development (for example #50494, #50270) -- and would love your help as there's tons to do. Maybe we can connect on Slack?

@Or-Levi

Or-Levi commented Feb 18, 2025

Very cool @justinrmiller!
Is this integrated with Portkey?
I see you listed this under future work: "Support for additional LLM inference engines beyond vLLM".
Will Gemini also be available?

@richardliaw
Contributor Author

@Or-Levi - are you interested in online or offline?

Gemini should be supported via the HttpRequestProcessor, if you're looking for offline
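
For the offline path, a sketch along these lines could point the HTTP processor at Gemini's OpenAI-compatible endpoint; the URL, header format, and model name here are assumptions, so check the provider docs for the exact values:

from ray.data.llm import build_llm_processor, HTTPRequestProcessorConfig

GEMINI_KEY = "..."

processor = build_llm_processor(
    HTTPRequestProcessorConfig(
        # Assumed OpenAI-compatibility endpoint; verify against Gemini's docs.
        url="https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
        header=f"Authorization: Bearer {GEMINI_KEY}",
        qps=1,
    ),
    preprocess=lambda row: dict(
        model="gemini-1.5-flash",  # assumed model name
        messages=row["messages"],
    ),
    postprocess=lambda row: dict(resp=row["generated_text"]),
    concurrency=4,
)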

richardliaw added a commit that referenced this issue Feb 19, 2025

Adds user guide and link-ins for Ray Data documentation; part of the #50639 thread of work, based on #50494.

@Or-Levi

Or-Levi commented Feb 19, 2025

Yes, offline. Got it

400Ping pushed a commit to 400Ping/ray that referenced this issue Feb 20, 2025

@richardliaw
Contributor Author

Hi all, we've just merged an initial version of Ray Data LLM onto master.

Documentation is here: https://docs.ray.io/en/master/data/working-with-llms.html

You can try it out by installing nightly: https://docs.ray.io/en/master/ray-overview/installation.html#daily-releases-nightlies

Please try it out and let us know if you have any feature requests / run into any issues.

@thatcort

Overall better vLLM integration would be great! A couple questions and some comments:
How are resources shared between the Ray Data and Ray Serve vLLM instances and between different Ray Data pipelines?
Are Ray Data vLLM calls done via REST or in-process?

GPUs are our main constraint, and models are large and take some time to load. If we have a given model deployed on a (set of) GPU(s), then I'd like that vLLM instance to be maxed out before the cluster tries to allocate more GPUs. Currently we run vLLM externally to Ray and make all calls via the REST API. We've also largely given up on GPU scaling due to cloud provisioning problems, so autoscaling is still more wish than reality.

@comaniac
Collaborator

Overall better vLLM integration would be great! A couple questions and some comments: How are resources shared between the Ray Data and Ray Serve vLLM instances and between different Ray Data pipelines? Are Ray Data vLLM calls done via REST or in-process?

We implemented the vLLM engine as a UDF of Ray Data's .map_batches(), so 1) the resources that vLLM can use are allocated by Ray Data, and 2) it does not launch an API server or use REST -- Ray Data creates a vLLM engine object directly and uses it in place. Pseudocode:

class UDF:
    def __init__(self, ...):
        # Each UDF replica creates one in-process vLLM engine (no HTTP server involved).
        self.llm = vllm.engine(...)

    async def __call__(self, batch):
        return await self.llm(batch)

dataset = dataset.map_batches(
    UDF,
    num_gpus=1, # One GPU per engine.
    concurrency=N, # Launch N engines.
)

GPUs are our main constraint, and models are large and take some time to load. If we have a given model deployed on a (set of) GPU(s), then I'd like that vLLM instance to be maxed out before the cluster tries to allocate more GPUs. Currently we run vLLM externally to Ray and make all calls via the REST API. We've also largely given up on GPU scaling due to cloud provisioning problems, so autoscaling is still more wish than reality.

We are able to handle that (not landed yet) in the following way:

dataset = dataset.map_batches(
    UDF,
    num_gpus=1, # One GPU per engine.
    concurrency=(1, N) # Launch N engines, but start processing once there's one engine ready.
)
