Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/offline_inference/rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, *args, **kwargs):

# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
Expand Down
162 changes: 162 additions & 0 deletions examples/offline_inference/rlhf_online_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.

The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model occupies GPU 0 for training, whereas a
tensor-parallel vLLM inference engine occupies GPU 1–2.

The example performs the following steps:

* Load the training model on GPU 0.
* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism
and Ray placement groups.
* Generate text from a list of prompts using the inference engine.
* Update the weights of the training model and broadcast the updated weights
to the inference engine by using a Ray collective RPC group. Note that
for demonstration purposes we simply zero out the weights.

For a production-ready implementation that supports multiple training and
inference replicas, see the OpenRLHF framework:
https://github.com/OpenRLHF/OpenRLHF

This example assumes a single-node cluster with three GPUs, but Ray
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""

import json
import os

import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from rlhf_utils import stateless_init_process_group
from torchao.core.config import config_to_dict
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
PerRow,
)
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams
from vllm.utils.network_utils import get_ip, get_open_port


class MyLLM(LLM):
"""Configure the vLLM worker for Ray placement group execution."""

def __init__(self, *args, **kwargs):
# Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
# so that vLLM can manage its own device placement within the worker.
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
super().__init__(*args, **kwargs)


# Load the OPT-125M model onto GPU 0 for the training workload.
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")

# Initialize Ray and set the visible devices. The vLLM engine will
# be placed on GPUs 1 and 2.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()

# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
placement_group=pg_inference,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=0,
)

# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.

# generate torchao quantization config for RL rollout
# see https://github.com/vllm-project/vllm/pull/23014 for instructions to
# use serialized config files instead of passing around json string
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

json_str = json.dumps(config_to_dict(config))

llm = ray.remote(
num_cpus=0,
num_gpus=0,
scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
model="facebook/opt-125m",
hf_overrides={"quantization_config_dict_json": json_str},
enforce_eager=True,
worker_extension_cls="rlhf_utils.WorkerExtension",
tensor_parallel_size=2,
distributed_executor_backend="ray",
)

# Generate text from the prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

sampling_params = SamplingParams(temperature=0)

outputs = ray.get(llm.generate.remote(prompts, sampling_params))

print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)

# Set up the communication channel between the training process and the
# inference engine.
master_address = get_ip()
master_port = get_open_port()

handle = llm.collective_rpc.remote(
"init_weight_update_group", args=(master_address, master_port, 1, 3)
)

model_update_group = stateless_init_process_group(
master_address, master_port, 0, 3, torch.device("cuda:0")
)
ray.get(handle)

# Simulate a training step by zeroing out all model weights.
# In a real RLHF training loop the weights would be updated using the gradient
# from an RL objective such as PPO on a reward model.
for name, p in train_model.named_parameters():
p.data.zero_()

# Synchronize the updated weights to the inference engine.
for name, p in train_model.named_parameters():
dtype_name = str(p.dtype).split(".")[-1]
handle = llm.collective_rpc.remote(
"update_weight", args=(name, dtype_name, p.shape)
)
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
ray.get(handle)

# Verify that the inference weights have been updated.
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))

# Generate text with the updated model. The output is expected to be nonsense
# because the weights are zero.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
46 changes: 11 additions & 35 deletions vllm/model_executor/model_loader/default_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
fastsafetensors_weights_iterator,
filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference,
get_quant_config,
maybe_download_from_modelscope,
multi_thread_pt_weights_iterator,
multi_thread_safetensors_weights_iterator,
Expand Down Expand Up @@ -273,42 +274,17 @@ def download_model(self, model_config: ModelConfig) -> None:
)

def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
if model_config.quantization == "torchao" and torchao_version_at_least(
"0.14.0"
):
self.load_config.safetensors_load_strategy = "torchao"
weights_to_load = {name for name, _ in model.named_parameters()}

# if we don't have `model.weight_metadata_and_attr_saved` defined and
# set to True, it means that this is either offline quantization case
# or the first run of online quantization
# see online_quantization.py for detailed notes
offline_quantization_or_first_run_of_online_quantization = not getattr(
model, "weight_metadata_and_attr_saved", False
)
if model_config.quantization == "torchao":
quant_config = get_quant_config(model_config, self.load_config)
if (
hasattr(quant_config, "is_checkpoint_torchao_serialized")
and quant_config.is_checkpoint_torchao_serialized
and torchao_version_at_least("0.14.0")
):
self.load_config.safetensors_load_strategy = "torchao"

if model_config.quantization is None:
# model is not quantized
loaded_weights = model.load_weights(
self.get_all_weights(model_config, model)
)
elif offline_quantization_or_first_run_of_online_quantization:
# case 1: offline quantized checkpoint
# case 2: Step I1 first run of weight loading with
# online quantization
# see online_quantization.py for detailed notes
loaded_weights = model.load_weights(
self.get_all_weights(model_config, model)
)
else:
# to avoid circular dependency
from vllm.model_executor.model_loader.online_quantization import (
load_weights_and_online_quantize,
)

# subsequent runs of weight loading with online
# quantization
loaded_weights = load_weights_and_online_quantize(self, model, model_config)
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(self.get_all_weights(model_config, model))

self.counter_after_loading_weights = time.perf_counter()
logger.info_once(
Expand Down
Loading