6 changes: 3 additions & 3 deletions README.md
@@ -195,9 +195,9 @@ cd PipelineRL

Create the environments with dependencies.
```bash
conda create -n pipeline-rl -y python=3.11
conda run --no-capture-output -n pipeline-rl pip install torch==2.6.0
conda run --no-capture-output -n pipeline-rl pip install -e . --no-build-isolation
conda create -n pipeline-rl -y python=3.12
conda run --no-capture-output -n pipeline-rl pip install -e .
conda run --no-capture-output -n pipeline-rl pip install flash-attn==2.8.3 --no-build-isolation
```

By default Pipeline-RL will use the file system as the medium for streaming the generated data to the trainer processes. This works on one node, but the files can get quite large. To use Redis instead you will need to install the Redis server in the same conda environment:
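Once the Redis server is installed and started, a quick connectivity sanity check can save debugging time. A minimal sketch, assuming the `redis` Python client is available in the same environment and the server listens on the default `localhost:6379` (adjust host/port for your setup); this is not the elided install command from the README, just an illustration:

```python
import redis

# Assumes the Redis server installed in the pipeline-rl environment is already running
# on the default port; adjust host/port if your deployment differs.
client = redis.Redis(host="localhost", port=6379)
client.ping()  # raises redis.exceptions.ConnectionError if the server is unreachable
print("Redis is reachable")
```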
11 changes: 7 additions & 4 deletions conf/base.yaml
@@ -57,14 +57,11 @@ test_llm:
top_k: 50

vllm_config:
use_v1: false
use_v1: true
quantization: null # or bf16_last_layer_fp32
vllm_kwargs:
dtype: bfloat16
gpu-memory-utilization: 0.9
num-scheduler-steps: 1
disable-log-requests: ""
disable-frontend-multiprocessing: ""
max-num-seqs: ${actor.llm_max_rollouts}
max-num-batched-tokens: 1024
enable-chunked-prefill: ""
@@ -73,6 +70,12 @@ vllm_config:
pipeline-parallel-size: 1
generation-config: vllm
max_model_len: 10000
# V1 specific settings
# logprobs-mode: processed_logprobs
# V0 specific settings
# num-scheduler-steps: 1
# disable-log-requests: ""
# disable-frontend-multiprocessing: ""

world:
replicas: 1
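A side note on the `vllm_kwargs` block: the keys are written in CLI-flag style (dashed names, with an empty string marking a boolean switch). A minimal sketch of how such a mapping is commonly flattened into argv-style arguments for vLLM's argument parser — an illustration only, not necessarily how PipelineRL performs the conversion:

```python
from typing import Any


def kwargs_to_argv(vllm_kwargs: dict[str, Any]) -> list[str]:
    """Flatten a config mapping into CLI-style arguments.

    Keys with an empty-string value become bare switches (e.g. --enable-chunked-prefill);
    everything else becomes a --key value pair.
    """
    argv: list[str] = []
    for key, value in vllm_kwargs.items():
        flag = f"--{key}"
        if value == "":
            argv.append(flag)  # boolean switch
        else:
            argv.extend([flag, str(value)])
    return argv


# Example with a subset of the settings above:
print(kwargs_to_argv({
    "dtype": "bfloat16",
    "gpu-memory-utilization": 0.9,
    "enable-chunked-prefill": "",
}))
# ['--dtype', 'bfloat16', '--gpu-memory-utilization', '0.9', '--enable-chunked-prefill']
```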
2 changes: 1 addition & 1 deletion pipelinerl/finetune/checkpoints.py
@@ -179,7 +179,7 @@ def load_model(args, model_class, current_dir):
is_ds_zero_3 = get_accelerator().state.deepspeed_plugin.zero_stage == 3 # type: ignore

if args.load_as_bf16:
loading_args["torch_dtype"] = torch.bfloat16
loading_args["dtype"] = torch.bfloat16
if args.auto_device_map:
loading_args["device_map"] = "auto"
model_cls = get_auto_model_class(model_class)
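The `torch_dtype` → `dtype` rename here (and in the LoRA merge below) tracks recent `transformers` releases, where `from_pretrained` accepts `dtype=` and the older `torch_dtype=` keyword is deprecated. A minimal sketch under that assumption; the model path is a placeholder:

```python
import torch
from transformers import AutoModelForCausalLM

model_name_or_path = "..."  # any HF model id or local checkpoint

# Newer transformers releases accept `dtype=`; the older spelling was `torch_dtype=`.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
```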
2 changes: 1 addition & 1 deletion pipelinerl/finetune/lora.py
@@ -147,7 +147,7 @@ def merge_lora(lora_model_path):
assert os.path.exists(lora_model_config), f"{lora_model_config} does not exists"

logger.info(f"Merge lora checkpoint {lora_model_path}")
model = lora_load_and_merge(lora_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
model = lora_load_and_merge(lora_model_path, dtype=torch.bfloat16, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(lora_model_path)

tmp_dir = f"{lora_model_path}_merged"
23 changes: 13 additions & 10 deletions pipelinerl/finetune_loop.py
@@ -27,8 +27,9 @@
from ring_flash_attn import substitute_hf_flash_attn, update_ring_flash_attn_params

from pipelinerl.finetune.value_model import AutoModelForCausalLMWithValueHead
import pipelinerl.torch_utils
from pipelinerl import torch_utils
from pipelinerl.finetune.types import PipelineBatchEncoding

from pipelinerl.finetune.checkpoints import (
load_model,
load_tokenizer,
@@ -212,7 +213,8 @@ def send_weight_update(
for name, parameter in named_parameters.items():
with deepspeed.zero.GatheredParameters([parameter]):
if get_accelerator().is_main_process:
dist.broadcast(parameter.data, src=0, group=self.actor_update_group)
# Use PyNcclCommunicator's broadcast method instead of torch.distributed
self.actor_update_group.broadcast(parameter.data, src=0, stream=torch.cuda.current_stream())
if get_accelerator().is_main_process:
logger.info("Wait for HTTP requests")
for future in futures: # type: ignore
@@ -254,8 +256,8 @@ def send_weight_update(
futures = self.request_weight_updates(messages)
logger.info(f"Published weight update request for version {version}")
for _, parameter in named_parameters.items():
dist.broadcast(parameter.data, src=0, group=self.actor_update_group)
dist.barrier(self.actor_update_group)
# Use PyNcclCommunicator's broadcast method instead of torch.distributed
self.actor_update_group.broadcast(parameter.data, src=0, stream=torch.cuda.current_stream())
for future in futures:
future.result()
logger.info("Finished broadcasting weights")
@@ -408,13 +410,15 @@ def run_finetuning_loop(
get_accelerator().wait_for_everyone()

if get_accelerator().is_main_process and args.send_weight_updates:
logger.info("Initializing actor process group")
actor_update_group = pipelinerl.torch_utils.init_extra_process_group(
group_name="actor",
backend="nccl",
current_device = get_accelerator().device
torch.cuda.set_device(current_device)
logger.info("Initializing actor process group using StatelessProcessGroup")
logger.info(f"Set CUDA device to {current_device} for actor process group (rank 0)")
actor_update_group = torch_utils.stateless_init_process_group(
init_method=cfg.me.weight_update_group_init_method,
rank=0,
world_size=cfg.me.weight_update_group_world_size,
device=current_device,
)
logger.info("Actor process group initialized")
else:
@@ -493,8 +497,7 @@
finally:
if weight_update_manager is not None:
weight_update_manager.shutdown()
if actor_update_group:
dist.destroy_process_group(actor_update_group)
# PyNcclCommunicator doesn't need explicit destroy like torch.distributed process groups


def rl_finetuning_worker(
45 changes: 45 additions & 0 deletions pipelinerl/torch_utils.py
@@ -1,14 +1,51 @@
import logging
from datetime import timedelta
from typing import Any, Optional, Union
from urllib.parse import urlparse

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
ProcessGroupNCCL,
Store,
_new_process_group_helper,
_world,
default_pg_timeout,
rendezvous,
)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup

logger = logging.getLogger(__name__)


def stateless_init_process_group(init_method, rank, world_size, device):
"""
vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed.
It is recommended to create `StatelessProcessGroup`, and then initialize
the data-plane communication (NCCL) between external (train processes)
and vLLM workers.

Args:
init_method: TCP init method string (e.g., "tcp://localhost:9000")
rank: The rank of this process in the group
world_size: Total number of processes in the group
device: The CUDA device to use for NCCL communication
"""
# Parse master_address and master_port from init_method (e.g., "tcp://localhost:9000")
parsed = urlparse(init_method)
master_address = parsed.hostname or "localhost"
master_port = parsed.port or 9000
logger.debug(f"Parsed master_address: {master_address}, master_port: {master_port}")

pg = StatelessProcessGroup.create(
host=master_address, port=master_port, rank=rank, world_size=world_size
)
pynccl = PyNcclCommunicator(pg, device=device)
return pynccl


# Copy from pytorch to allow creating multiple main groups.
@@ -49,6 +86,13 @@ def init_extra_process_group(
# different systems (e.g. RPC) in case the store is multi-tenant.
store = PrefixStore(group_name, store)

# Create NCCL-specific options if using NCCL backend
logger.info(f"[{group_name}] Backend: {backend}, str(backend): {str(backend)}")
if pg_options is None and str(backend) == "nccl":
pg_options = ProcessGroupNCCL.Options()
pg_options.is_high_priority_stream = False
logger.info(f"[{group_name}] Created NCCL options: {pg_options}")

pg, _ = _new_process_group_helper(
world_size,
rank,
@@ -59,6 +103,7 @@ def init_extra_process_group(
backend_options=pg_options,
timeout=timeout,
)
logger.info(f"[{group_name}] Process group created successfully")

_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

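For orientation, a minimal sketch of how the new helper is meant to be used on both ends of the weight-update channel: the trainer joins as rank 0 and broadcasts a tensor, and each vLLM worker joins with its own rank and receives it into a preallocated buffer. This mirrors the calls in `finetune_loop.py` above and `vllm1.py` below; the init method, world size, and tensor shape are stand-in values:

```python
import torch

from pipelinerl.torch_utils import stateless_init_process_group

INIT_METHOD = "tcp://localhost:9000"  # stand-in for cfg.me.weight_update_group_init_method
WORLD_SIZE = 2                        # trainer + one vLLM worker, for illustration


def trainer_side() -> None:
    # Rank 0: the finetuning process that owns the up-to-date weights.
    comm = stateless_init_process_group(
        INIT_METHOD, rank=0, world_size=WORLD_SIZE, device=torch.device("cuda:0")
    )
    weights = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda:0")
    comm.broadcast(weights, src=0, stream=torch.cuda.current_stream())


def worker_side(rank: int) -> None:
    # Rank >= 1: a vLLM worker process (typically on another GPU or host).
    comm = stateless_init_process_group(
        INIT_METHOD, rank=rank, world_size=WORLD_SIZE, device=torch.device("cuda:0")
    )
    buffer = torch.empty(4096, 4096, dtype=torch.bfloat16, device="cuda:0")
    comm.broadcast(buffer, src=0, stream=torch.cuda.current_stream())
    # `buffer` now holds the trainer's weights and can be loaded into the model.
```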
34 changes: 33 additions & 1 deletion pipelinerl/vllm0.py
@@ -1,3 +1,36 @@
"""
DEPRECATED - Kept only for backward compatibility with older vLLM versions.

This module provides a custom vLLM inference server with dynamic weight updates using the legacy V0 engine architecture.

Compatibility:
- vLLM versions <= 0.8.x only
- The V0 engine was removed in vLLM 0.11.0
- Use vllm1.py instead
"""
import warnings
from packaging import version as version_parser
import vllm

# Check vLLM version compatibility
vllm_version = version_parser.parse(vllm.__version__)

if vllm_version >= version_parser.parse("0.9.0"):
raise ImportError(
f"pipelinerl.vllm0 is not compatible with vLLM {vllm.__version__}. "
"This module only works with vLLM <= 0.8.x. "
"Please use pipelinerl.vllm1 for vLLM >= 0.11.0 instead."
)

# Only show deprecation warning for compatible versions
warnings.warn(
"pipelinerl.vllm0 is DEPRECATED and will be removed in a future version. "
"This module only works with vLLM <= 0.8.x. "
"Please use pipelinerl.vllm1 as it is actively maintained.",
DeprecationWarning,
stacklevel=2,
)

import asyncio
import json
import logging
@@ -14,7 +47,6 @@
)
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.openai.api_server import (
run_server,
create_server_socket,
build_app,
init_app_state,
59 changes: 40 additions & 19 deletions pipelinerl/vllm1.py
@@ -1,15 +1,16 @@
import asyncio
import logging
import signal
import torch
import uvloop
from vllm.utils import FlexibleArgumentParser, set_ulimit
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.system_utils import set_ulimit
from vllm.entrypoints.openai.cli_args import (
make_arg_parser,
validate_parsed_serve_args,
)
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.openai.api_server import (
run_server,
create_server_socket,
build_app,
init_app_state,
@@ -26,8 +27,8 @@

from pipelinerl.finetune_loop import WeightUpdateRequest
from pipelinerl.vllm_quantization import string_to_dtype # reuse mapping
from pipelinerl.torch_utils import stateless_init_process_group
from typing import Any, Protocol, runtime_checkable
import pipelinerl.torch_utils
import pipelinerl.vllm_quantization # Register bf16_last_layer_fp32 quantization config

logger = logging.getLogger(__name__)
@@ -46,7 +47,7 @@ class LikeWorker(Protocol):
rank: int
local_rank: int
device: torch.device
model_runner: GPUModelRunner
model_runner: GPUModelRunner
pg_rank: int
process_group: Any
model_config: ModelConfig
@@ -72,34 +73,47 @@ def init_actor_update_group(
prefix
+ f"Weight update group init method: {weight_update_group_init_method}, world size: {weight_update_group_world_size}"
)
self.process_group = pipelinerl.torch_utils.init_extra_process_group(
group_name="actor",
backend="nccl",

# Use vLLM's StatelessProcessGroup instead of torch.distributed
self.model_update_group = stateless_init_process_group(
init_method=weight_update_group_init_method,
rank=self.pg_rank,
world_size=weight_update_group_world_size,
device=self.device,
)
logger.info(prefix + "Actor update process group initialized")

def receive_weight_update(self: LikeWorker, request: WeightUpdateRequest):
def receive_weight_update(self: LikeWorker, request_json: str):
request = WeightUpdateRequest.model_validate_json(request_json)
torch.cuda.synchronize(self.device)
logger.info("Start receiving weight update")
expected_dtypes = (torch.bfloat16, torch.float32, torch.float16)

for info in request.parameters_info:
target_dtype = string_to_dtype(info.dtype)
if target_dtype not in expected_dtypes:
logger.warning(f"Unexpected dtype for {info.name}: {info.dtype}")
buffer = torch.empty(tuple(info.shape), dtype=target_dtype, device=self.device)
torch.distributed.broadcast(buffer, src=0, group=self.process_group)
loaded_params = self.model_runner.model.load_weights(weights=[(info.name, buffer)]) # type: ignore
self.model_update_group.broadcast(buffer, src=0, stream=torch.cuda.current_stream())
loaded_params = self.model_runner.model.load_weights(weights=[(info.name, buffer)]) # type: ignore
if len(loaded_params) != 1:
raise ValueError(f"model {info.name} not found in model state dict")

pipelinerl.vllm_quantization.invalidate_fp32_cache()
logger.info("Weight update received")

def close_communicator(self):
"""Closes the communicator when weight synchronization is no longer needed."""
if hasattr(self, "model_update_group") and self.model_update_group is not None:
del self.model_update_group
self.model_update_group = None
logger.info("Weight update communicator closed")


class WeightUpdateManager:
def __init__(self, args, engine_client: AsyncMPClient):
def __init__(self, args, engine: AsyncLLM, engine_client: AsyncMPClient):
self.args = args
self.engine = engine
self.engine_client = engine_client

async def input_process_groups(self):
@@ -114,11 +128,16 @@ async def input_process_groups(self):
)

async def receive_weight_update(self, request: WeightUpdateRequest):
logger.info("Starting weight update...")
await self.engine_client.collective_rpc_async(
"receive_weight_update", args=(request,)
"receive_weight_update", args=(request.model_dump_json(),)
)
logger.info("Weight update processed")

async def close_communicator(self):
"""Closes the communicator when weight synchronization is no longer needed."""
await self.engine_client.collective_rpc_async("close_communicator")


async def run_server(args, **uvicorn_kwargs) -> None:
# COPIED FROM vllm/entrypoints/openai/api_server.py, vllm version 0.6.6.post1
@@ -157,11 +176,11 @@ def signal_handler(*_) -> None:
vllm_config=engine_config,
usage_context=UsageContext.OPENAI_API_SERVER,
disable_log_stats=engine_args.disable_log_stats,
disable_log_requests=engine_args.disable_log_requests,
enable_log_requests=engine_args.enable_log_requests,
)
assert isinstance(engine.engine_core, AsyncMPClient)

weight_update_manager = WeightUpdateManager(args, engine.engine_core)
weight_update_manager = WeightUpdateManager(args, engine, engine.engine_core)
if not args.disable_weight_updates:
await weight_update_manager.input_process_groups()

@@ -172,10 +191,12 @@ def signal_handler(*_) -> None:

@app.post("/receive_weight_update")
async def _receive_weight_update(request: WeightUpdateRequest):
await weight_update_manager.receive_weight_update(request)
# Fire-and-forget: return immediately, weight update happens in background
logger.info("Received weight update request (fire-and-forget)")
asyncio.create_task(weight_update_manager.receive_weight_update(request))
return {"status": "ok"}

await init_app_state(engine, engine_config, app.state, args)
await init_app_state(engine, app.state, args)
shutdown_task = await serve_http(
app,
sock,
@@ -194,11 +215,11 @@ async def _receive_weight_update(request: WeightUpdateRequest):
# NB: Await server shutdown only after the backend context is exited
await shutdown_task

# Cleanup
if not args.disable_weight_updates:
await weight_update_manager.close_communicator()
sock.close()

# TODO: proper cleanup
# dist.destroy_process_group(actor_update_group)


def run_llm():
parser = FlexibleArgumentParser(description="vLLM OpenAI-Compatible RESTful API server.")
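One detail worth spelling out: `receive_weight_update` now takes the request as a JSON string rather than a `WeightUpdateRequest` object, so the payload handed to `collective_rpc_async` stays a plain string and is re-hydrated on the worker side with `model_validate_json`. A minimal sketch of that round trip, assuming `WeightUpdateRequest` is a pydantic v2 model; the toy classes and field values below are placeholders, not the project's actual definitions:

```python
from pydantic import BaseModel


class ParamInfo(BaseModel):
    # Hypothetical stand-in for the entries in WeightUpdateRequest.parameters_info
    name: str
    dtype: str
    shape: list[int]


class ToyWeightUpdateRequest(BaseModel):
    version: int
    parameters_info: list[ParamInfo]


request = ToyWeightUpdateRequest(
    version=7,
    parameters_info=[ParamInfo(name="lm_head.weight", dtype="bfloat16", shape=[151936, 4096])],
)

# Caller side: serialize before passing the payload through collective_rpc_async(...)
payload = request.model_dump_json()

# Worker side: re-hydrate the typed object from the JSON string
restored = ToyWeightUpdateRequest.model_validate_json(payload)
assert restored == request
```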