diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py index 303d9ed9045..0563f7d1462 100644 --- a/verl/single_controller/base/decorator.py +++ b/verl/single_controller/base/decorator.py @@ -48,6 +48,7 @@ def init_predefined_dispatch_mode(): Dispatch.register("DP_COMPUTE_PROTO") Dispatch.register("DP_COMPUTE_PROTO_WITH_FUNC") Dispatch.register("DP_COMPUTE_METRIC") + Dispatch.register("DP_DISPATCH") # This is a special dispatch mode for vllm ExternalRayDistributedExecutor Dispatch.register("DIRECT_ROLLOUT_METHOD") @@ -135,6 +136,12 @@ def dispatch_all_to_all(worker_group, *args, **kwargs): def collect_all_to_all(worker_group, output): return output +def dispatch_dp(worker_group, *args, **kwargs): + return args, kwargs + +def collect_dp(worker_group, output): + return output + def dispatch_megatron_compute(worker_group, *args, **kwargs): """ @@ -415,6 +422,7 @@ def collect_dp_compute_data_proto(worker_group, output): "dispatch_fn": dummy_direct_rollout_call, "collect_fn": dummy_direct_rollout_call, }, + Dispatch.DP_DISPATCH: {"dispatch_fn": dispatch_dp, "collect_fn": collect_dp}, } diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml index fc3af80d4f5..21a1776c1d2 100644 --- a/verl/trainer/config/rollout/rollout.yaml +++ b/verl/trainer/config/rollout/rollout.yaml @@ -13,6 +13,11 @@ top_k: -1 # Top-p sampling parameter. Default 1.0. top_p: 1 +# over sampling batch size +over_sampling_batch_size: 0 +rollout_batch_size: 0 +partial_rollout: false + # typically the same as data max prompt length # same as data.max_prompt_length if it exists prompt_length: ${oc.select:data.max_prompt_length,512} diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index c5b76310546..f9d349c6cd7 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -60,6 +60,7 @@ from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance from verl.utils.torch_functional import masked_mean from verl.utils.tracking import ValidationGenerationsLogger +from verl.workers.rollout.rollout_manager import RolloutManager WorkerType = type[Worker] @@ -732,22 +733,8 @@ def _validate(self): } print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") - # pad to be divisible by dp_size - size_divisor = ( - self.actor_rollout_wg.world_size - if not self.async_rollout_mode - else self.config.actor_rollout_ref.rollout.agent.num_workers - ) - test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor) - if not self.async_rollout_mode: - test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) - else: - test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) - - # unpad - test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + test_output_gen_batch = self.generate_sequences(test_gen_batch) - print("validation generation end") # Store generated outputs output_ids = test_output_gen_batch.batch["responses"] @@ -830,6 +817,10 @@ def init_workers(self): self.resource_pool_manager.create_resource_pool() self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + self.rollout_manager = RolloutManager.options(num_cpus=1, num_gpus=0).remote(config=self.config) + self.sglang_router_ip, self.sglang_router_port = ray.get( + self.rollout_manager.get_sglang_router_ip_and_port.remote() + ) # 
create actor and rollout if self.hybrid_engine: @@ -839,6 +830,8 @@ def init_workers(self): config=self.config.actor_rollout_ref, role="actor_rollout", profile_option=self.config.trainer.npu_profile.options, + sglang_router_ip=self.sglang_router_ip, + sglang_router_port=self.sglang_router_port, ) self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls else: @@ -1078,6 +1071,19 @@ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqle ) metrics.update(global_balance_stats) + def rollout(self) -> DataProto: + self.actor_rollout_wg.prepare_for_generate() + batch = ray.get(self.rollout_manager.rollout.remote()) + self.actor_rollout_wg.finish_generate() + return batch + + def generate_sequences(self, batch: DataProto) -> DataProto: + # For compatibility with verl's original generate_sequences + self.actor_rollout_wg.prepare_for_generate() + batch = ray.get(self.rollout_manager.generate_sequences.remote(batch)) + self.actor_rollout_wg.finish_generate() + return batch + def fit(self): """ The training loop of PPO. @@ -1103,13 +1109,13 @@ def fit(self): # perform validation before training # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): - val_metrics = self._validate() - assert val_metrics, f"{val_metrics=}" - pprint(f"Initial validation metrics: {val_metrics}") - logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get("val_only", False): - return + # if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + # val_metrics = self._validate() + # assert val_metrics, f"{val_metrics=}" + # pprint(f"Initial validation metrics: {val_metrics}") + # logger.log(data=val_metrics, step=self.global_steps) + # if self.config.trainer.get("val_only", False): + # return # add tqdm progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") @@ -1118,9 +1124,10 @@ def fit(self): self.global_steps += 1 last_val_metrics = None self.max_steps_duration = 0 + per_epoch_iters = ray.get(self.rollout_manager.get_num_rollout_per_epoch.remote()) for epoch in range(self.config.trainer.total_epochs): - for batch_dict in self.train_dataloader: + for batch_dict in range(per_epoch_iters): metrics = {} timing_raw = {} @@ -1132,45 +1139,14 @@ def fit(self): with marked_timer("start_profile", timing_raw): self._start_profiling(do_profile) - batch: DataProto = DataProto.from_single_dict(batch_dict) - - # pop those keys for generation - batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] - non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] - if "multi_modal_data" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("multi_modal_data") - if "raw_prompt" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("raw_prompt") - if "tools_kwargs" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("tools_kwargs") - if "interaction_kwargs" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("interaction_kwargs") - if "index" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("index") - if "agent_name" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("agent_name") - - gen_batch = batch.pop( - batch_keys=batch_keys_to_pop, - non_tensor_batch_keys=non_tensor_batch_keys_to_pop, - ) - - # pass global_steps to trace - gen_batch.meta_info["global_steps"] = self.global_steps - gen_batch = 
gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) is_last_step = self.global_steps >= self.total_training_steps with marked_timer("step", timing_raw): # generate a batch with marked_timer("gen", timing_raw, color="red"): - if not self.async_rollout_mode: - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) - else: - gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) - timing_raw.update(gen_batch_output.meta_info["timing"]) - gen_batch_output.meta_info.pop("timing", None) - + batch = self.rollout() + gen_batch = batch if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: with marked_timer("gen_max", timing_raw, color="purple"): gen_baseline_batch = deepcopy(gen_batch) @@ -1189,13 +1165,6 @@ def fit(self): del gen_baseline_batch, gen_baseline_output - batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object - ) - # repeat to align with repeated responses in rollout - batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) - batch = batch.union(gen_batch_output) - if "response_mask" not in batch.batch.keys(): batch.batch["response_mask"] = compute_response_mask(batch) # Balance the number of valid tokens across DP ranks. diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 475b0a51783..74601d14015 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -87,7 +87,12 @@ class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension): def __init__(self, config: DictConfig, role: str, **kwargs): MegatronWorker.__init__(self) + from loguru import logger as log + self.config = config + self.sglang_router_ip = kwargs["sglang_router_ip"] + self.sglang_router_port = kwargs["sglang_router_port"] + log.info(f"sglang_router_host: {self.sglang_router_ip}, sglang_router_port: {self.sglang_router_port}") # NOTE(sgm): We utilize colocate WorkerGroup by default. # As a result, Workers for different model share the same process. 
@@ -110,7 +115,7 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
                 tensor_model_parallel_size=self.config.actor.megatron.tensor_model_parallel_size,
                 pipeline_model_parallel_size=self.config.actor.megatron.pipeline_model_parallel_size,
                 virtual_pipeline_model_parallel_size=self.config.actor.megatron.virtual_pipeline_model_parallel_size,
-                pipeline_model_parallel_split_rank=None,
+                # pipeline_model_parallel_split_rank=None,
                 use_sharp=False,
                 context_parallel_size=self.config.actor.megatron.context_parallel_size,
                 expert_model_parallel_size=self.config.actor.megatron.expert_model_parallel_size,
@@ -358,6 +363,8 @@ def _build_rollout(self, trust_remote_code=False):
                 model_hf_config=self.actor_model_config,
                 trust_remote_code=trust_remote_code,
                 device_mesh=rollout_device_mesh,
+                sglang_router_ip=self.sglang_router_ip,
+                sglang_router_port=self.sglang_router_port,
             )
 
             log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=None)
@@ -576,6 +583,19 @@ def generate_sequences(self, prompts: DataProto):
         get_torch_device().empty_cache()
         return output
 
+    @register(dispatch_mode=Dispatch.DP_DISPATCH)
+    @GPUMemoryLogger(role="prepare_for_generate", logger=logger)
+    @DistProfiler.annotate(color="olive")
+    def prepare_for_generate(self):
+        self.sharding_manager.prepare_for_generate()
+
+    @register(dispatch_mode=Dispatch.DP_DISPATCH)
+    @GPUMemoryLogger(role="finish_generate", logger=logger)
+    @DistProfiler.annotate(color="olive")
+    def finish_generate(self):
+        self.sharding_manager.finish_generate()
+        get_torch_device().empty_cache()
+
     @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
     @GPUMemoryLogger(role="compute_ref_log_prob", logger=logger)
     @DistProfiler.annotate(color="olive")
diff --git a/verl/workers/rollout/buffer.py b/verl/workers/rollout/buffer.py
new file mode 100644
index 00000000000..116f246d588
--- /dev/null
+++ b/verl/workers/rollout/buffer.py
@@ -0,0 +1,134 @@
+import copy
+import os
+import uuid
+from dataclasses import dataclass, field
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+from uuid import uuid4
+
+import ray
+import torch
+from torch.utils.data import Dataset
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from verl.utils import hf_processor, hf_tokenizer
+from verl.utils.dataset.rl_dataset import RLHFDataset
+
+
+class Status(Enum):
+    PENDING = "pending"
+    COMPLETED = "completed"
+    TRUNCATED = "truncated"
+    ABORTED = "aborted"
+
+
+class Buffer:
+    def __init__(self, config):
+        # init_wandb_secondary(args, wandb_run_id)
+        self.config = config
+
+        # Data-source bookkeeping
+        self.epoch_id = 0
+        self.sample_index = 0
+        self.sample_offset = 0
+        # TODO remove this
+        self.metadata = {}
+
+        # Initialize tokenizer and processor
+        local_path = self.config.actor_rollout_ref.model.path
+        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True)
+        # Used for multimodal LLM, could be None
+        self.processor = hf_processor(local_path, trust_remote_code=True, use_fast=True)
+
+        # Load the RLHF dataset into memory
+        rldataset = RLHFDataset(
+            data_files=self.config.data.train_files,
+            tokenizer=self.tokenizer,
+            processor=self.processor,
+            config=self.config.data,
+        )
+        self.dataset = []
+        for item in tqdm(rldataset, desc="Loading RLHF dataset", total=len(rldataset)):
+            self.dataset.append(item)
+
+        self.n_samples_per_prompt = self.config.actor_rollout_ref.rollout.n
+
+        self.buffer: List[List[Dict]] = []
+
+    def get_num_rollout_per_epoch(self):
+        return len(self.dataset) // self.config.actor_rollout_ref.rollout.rollout_batch_size
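+
+    # How this buffer is typically driven (a sketch; see RolloutManager.rollout_async):
+    #
+    #     buffer = Buffer(config)
+    #     groups = buffer.get_samples(over_sampling_batch_size)   # drain buffer first, then dataset
+    #     ... generate responses for every sample in each group ...
+    #     buffer.add_samples(unfinished_groups)                   # re-queue partial rollouts
+    #
+    # Each group holds n_samples_per_prompt deep copies of one prompt, all sharing a uid.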
+
+    def _get_samples_from_data_source(self, num_samples: int) -> List[List[Dict]]:
+        """Fetch samples from the data source (merges the original RolloutDataSource.get_samples logic)."""
+        samples = []
+        # TODO unify the two branches
+        if self.sample_offset + num_samples <= len(self.dataset):
+            prompt_samples = self.dataset[self.sample_offset : self.sample_offset + num_samples]
+            self.sample_offset += num_samples
+        else:
+            prompt_samples = self.dataset[self.sample_offset :]
+            num_samples -= len(prompt_samples)
+            # self.epoch_id += 1
+            # if self.args.rollout_shuffle:
+            #     self.dataset.shuffle(self.epoch_id)
+            prompt_samples += self.dataset[:num_samples]
+            self.sample_offset = num_samples
+            # self.sample_offset = 0
+
+        for prompt_sample in prompt_samples:
+            group = []
+            prompt_sample["status"] = Status.PENDING
+            prompt_sample["response_length"] = 0
+            prompt_sample["response"] = []
+            prompt_sample["uid"] = str(uuid.uuid4())
+
+            for _ in range(self.n_samples_per_prompt):
+                sample = copy.deepcopy(prompt_sample)
+                group.append(sample)
+            samples.append(group)
+
+        return samples
+
+    # TODO simplify remaining logic
+    def get_samples(self, num_samples: int) -> List[List[Dict]]:
+        """
+        Return num_samples prompt groups, draining the buffer before reading from the dataset.
+        """
+
+        samples = self._get_samples_from_buffer(num_samples)
+        num_samples -= len(samples)
+
+        if num_samples == 0:
+            return samples
+
+        samples += self._get_samples_from_data_source(num_samples=num_samples)
+        return samples
+
+    def _get_samples_from_buffer(self, num_samples: int) -> List[List[Dict]]:
+        if len(self.buffer) == 0 or num_samples == 0:
+            return []
+        num_to_pop = min(len(self.buffer), num_samples)
+        samples = self.buffer[:num_to_pop]
+        del self.buffer[:num_to_pop]
+        return samples
+
+    def add_samples(self, samples: List[List[Dict]]):
+        """
+        Add sample groups to the buffer.
+        """
+        if not samples:
+            return
+        assert isinstance(samples, list), f"samples must be a list, got {type(samples)}"
+        assert isinstance(samples[0], list), f"the elements of samples must be list, got {type(samples[0])}"
+        for i in range(len(samples)):
+            assert len(samples[i]) == self.n_samples_per_prompt, (
+                f"the length of the elements of samples must be equal to n_samples_per_prompt, got {len(samples[i])} != {self.n_samples_per_prompt}"
+            )
+            group = samples[i]  # type: ignore
+            self.buffer.append(group)
+
+    def get_buffer_length(self):
+        return len(self.buffer)
diff --git a/verl/workers/rollout/http_utils.py b/verl/workers/rollout/http_utils.py
new file mode 100644
index 00000000000..97c10046362
--- /dev/null
+++ b/verl/workers/rollout/http_utils.py
@@ -0,0 +1,106 @@
+import asyncio
+import multiprocessing
+import random
+import socket
+
+import httpx
+
+
+def find_available_port(base_port: int):
+    port = base_port + random.randint(100, 1000)
+    while True:
+        if is_port_available(port):
+            return port
+        if port < 60000:
+            port += 42
+        else:
+            port -= 43
+
+
+def is_port_available(port):
+    """Return whether a port is available."""
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        try:
+            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+            s.bind(("", port))
+            s.listen(1)
+            return True
+        except socket.error:
+            return False
+        except OverflowError:
+            return False
+
+
+def get_host_info():
+    hostname = socket.gethostname()
+
+    local_ip = socket.gethostbyname(hostname)
+
+    return hostname, local_ip
+
+
+def run_router(args):
+    try:
+        from sglang_router.launch_router import launch_router
+
+        router = launch_router(args)
+        if router is None:
+            return 1
+        return 0
+    except Exception as e:
+        print(e)
+        return 1
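+
+
+# Illustrative sketch of how these helpers can be combined to run the sglang
+# router in a side process (the actual wiring presumably lives in
+# sglang_rollout._start_router; `router_args` stands in for whatever argument
+# object sglang_router.launch_router expects):
+#
+#     port = find_available_port(30000)
+#     proc = multiprocessing.Process(target=run_router, args=(router_args,))
+#     proc.start()
+#     ...
+#     terminate_process(proc)  # defined below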
+
+
+def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> None:
+    """Terminate a process gracefully, with forced kill as fallback.
+
+    Args:
+        process: The process to terminate
+        timeout: Seconds to wait for graceful termination before forcing kill
+    """
+    if not process.is_alive():
+        return
+
+    process.terminate()
+    process.join(timeout=timeout)
+    if process.is_alive():
+        process.kill()
+        process.join()
+
+
+async def post(url, payload, use_http2=False, max_retries=60):
+    # never timeout
+    timeout = httpx.Timeout(None)
+    retry_count = 0
+    while retry_count < max_retries:
+        try:
+            async with httpx.AsyncClient(http1=not use_http2, http2=use_http2, timeout=timeout) as client:
+                response = await client.post(url, json=payload or {})
+                response.raise_for_status()
+                try:
+                    output = response.json()
+                except Exception:
+                    output = response.text
+        except Exception as e:
+            retry_count += 1
+            print(f"Error: {e}, retrying... (attempt {retry_count}/{max_retries})")
+            if retry_count >= max_retries:
+                print(f"Max retries ({max_retries}) reached, failing...")
+                raise e
+            await asyncio.sleep(1)
+            continue
+        break
+
+    return output
+
+
+async def get(url, use_http2=False):
+    # never timeout
+    timeout = httpx.Timeout(None)
+    async with httpx.AsyncClient(http1=not use_http2, http2=use_http2, timeout=timeout) as client:
+        response = await client.get(url)
+        response.raise_for_status()
+        output = response.json()
+    return output
diff --git a/verl/workers/rollout/rollout_manager.py b/verl/workers/rollout/rollout_manager.py
new file mode 100644
index 00000000000..ca03325dad6
--- /dev/null
+++ b/verl/workers/rollout/rollout_manager.py
@@ -0,0 +1,523 @@
+# This file is adapted from multiple sources:
+# THUDM/slime project
+# Original source: https://github.com/THUDM/slime/blob/main/slime/rollout/sglang_rollout.py
+# Original source: https://github.com/THUDM/slime/blob/main/slime/ray/rollout_data_source.py
+# Original source: https://github.com/THUDM/slime/blob/main/slime/ray/rollout.py
+# Copyright 2025 Zhipu AI
+# Licensed under the Apache License, Version 2.0
+
+import asyncio
+import threading
+from collections import defaultdict
+from typing import Dict, Union
+
+import numpy as np
+import ray
+import torch
+from sglang.srt.sampling.sampling_params import SamplingParams
+from tensordict import TensorDict
+from torch.nn.utils.rnn import pad_sequence
+from torchdata.stateful_dataloader import StatefulDataLoader
+from tqdm import tqdm
+
+from verl import DataProto
+from verl.utils import hf_processor, hf_tokenizer
+from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length
+from verl.workers.rollout.buffer import Buffer, Status
+from verl.workers.rollout.http_utils import get, post
+from verl.workers.rollout.sglang_rollout.sglang_rollout import _post_process_outputs, _pre_process_inputs, _start_router
+
+
+# from slime.utils.types import Dict
+class AsyncLoopThread:
+    def __init__(self):
+        self.loop = asyncio.new_event_loop()
+        self._thread = threading.Thread(target=self._start_loop, daemon=True)
+        self._thread.start()
+
+    def _start_loop(self):
+        asyncio.set_event_loop(self.loop)
+        self.loop.run_forever()
+
+    def run(self, coro):
+        # Schedule a coroutine onto the loop and block until it's done
+        return asyncio.run_coroutine_threadsafe(coro, self.loop).result()
+
+
+# Create one global instance
+async_loop = None
+
+
+def get_async_loop():
+    global async_loop
+    if async_loop is None:
+        async_loop = AsyncLoopThread()
+    return async_loop
+
+
+def run(coro):
+    """Run a coroutine in the background event loop."""
+    return get_async_loop().run(coro)
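+
+
+# `run()` is how the synchronous Ray actor methods below drive their async HTTP
+# work; for example, RolloutManager.rollout() does
+#
+#     rollout_result, aborted_samples = run(self.rollout_async())
+#
+# which blocks the caller until the coroutine finishes on the background loop.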
+
+
+@ray.remote
+class RolloutManager:
+    def __init__(self, config):
+        self.config = config.actor_rollout_ref.rollout
+        # Set to True to point at an externally launched router (debug only);
+        # otherwise a fresh sglang router is started here.
+        self.debug = False
+        if self.debug:
+            # self.sglang_router_ip, self.sglang_router_port = _start_router()
+            self.sglang_router_ip = "127.0.0.1"
+            self.sglang_router_port = 30000
+        else:
+            self.sglang_router_ip, self.sglang_router_port = _start_router()
+
+        self.data_buffer = Buffer(config)
+        self.partial_rollout = config.actor_rollout_ref.rollout.partial_rollout
+        self.use_http2 = False
+        self.n_samples_per_prompt = config.actor_rollout_ref.rollout.n
+        self.over_sampling_batch_size = config.actor_rollout_ref.rollout.over_sampling_batch_size
+        self.rollout_batch_size = config.actor_rollout_ref.rollout.rollout_batch_size
+        local_path = config.actor_rollout_ref.model.path
+
+        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True)
+        self.processor = hf_processor(local_path, trust_remote_code=True, use_fast=True)
+        self.init_sampling_params()
+
+        # Concurrency limit and generation defaults
+        self.semaphore = asyncio.Semaphore(100)  # Limit concurrent requests
+        self.group_rm = False  # Default group reward model setting
+        self.use_token_output = True  # Default token output setting
+        self.reset()
+
+    def init_sampling_params(self):
+        self.sampling_params = dict(
+            max_new_tokens=self.config.response_length,
+            temperature=1.0,
+            top_p=1.0,
+            top_k=-1,
+        )
+
+    def get_num_rollout_per_epoch(self):
+        return self.data_buffer.get_num_rollout_per_epoch()
+
+    def get_sglang_router_ip_and_port(self):
+        return self.sglang_router_ip, self.sglang_router_port
+
+    def reset(self):
+        self.remaining_batch_size = 0
+        self.pendings = set()
+        self.aborted = False
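+
+    # Generation flow (see rollout_async below): groups of prompts are submitted
+    # over_sampling_batch_size at a time until rollout_batch_size finished groups
+    # have been collected; any still-running requests are then aborted, and with
+    # partial_rollout=True the partially generated responses are pushed back into
+    # the data buffer so they can be resumed in the next rollout.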
+
+    def submit_generate_tasks(self, samples_group: list[list[Dict]]):
+        for group in samples_group:
+            self.pendings.add(
+                asyncio.create_task(
+                    # submit a group of samples as a single task.
+                    self.generate_group(
+                        group,
+                        evaluation=False,
+                    )
+                )
+            )
+        self.remaining_batch_size += len(samples_group)
+
+    def rollout(self):
+        rollout_result, aborted_samples = run(self.rollout_async())
+        self.data_buffer.add_samples(aborted_samples)
+        return self._convert_samples_to_data_proto(rollout_result)
+
+    async def rollout_async(self) -> list[list[Dict]]:
+        data = []
+        pbar = tqdm(total=self.rollout_batch_size * self.n_samples_per_prompt, desc="Rollout generation")
+
+        while len(data) < self.rollout_batch_size:
+            while self.remaining_batch_size < self.rollout_batch_size:
+                # fetch samples from the buffer and submit the generation requests
+                samples = self.data_buffer.get_samples(self.over_sampling_batch_size)
+
+                self.submit_generate_tasks(samples)
+
+            # wait for at least one group to finish generating
+            done, self.pendings = await asyncio.wait(self.pendings, return_when=asyncio.FIRST_COMPLETED)
+            for task in done:
+                group: list[Dict] = task.result()
+
+                assert len(group) == self.n_samples_per_prompt
+                # filter-related logic was removed here
+
+                # append the finished group to data
+                if len(data) < self.rollout_batch_size:
+                    data.append(group)
+                    pbar.update(self.n_samples_per_prompt)
+
+        pbar.close()
+        print(f"Rollout generation finished, got {len(data)} samples", flush=True)
+
+        aborted_samples = await self.abort()
+        # aborted_samples = []
+        print(f"Aborted {len(aborted_samples)} samples", flush=True)
+
+        assert len(data) == self.rollout_batch_size, f"Got {len(data)} samples, expected {self.rollout_batch_size}"
+        from loguru import logger as log
+
+        log.info(f"Rollout generation finished, got {len(data)} samples")
+        # data = sorted(data, key=lambda group: group[0].index)
+
+        # reset the global rollout state
+        self.reset()
+        return data, aborted_samples
+
+    def _convert_samples_to_data_proto(self, samples: list[list[Dict]]):
+        """
+        Convert inference generated samples to training data.
+        """
+        output_req_list = sum(samples, [])
+
+        # collect the tensors from all requests in one batch
+        idx = torch.cat([req["input_ids"].unsqueeze(0) for req in output_req_list], dim=0)
+        attention_mask = torch.cat([req["attention_mask"].unsqueeze(0) for req in output_req_list], dim=0)
+        position_ids = torch.cat([req["position_ids"].unsqueeze(0) for req in output_req_list], dim=0)
+        response = pad_sequence(
+            [torch.tensor(req["response"]) for req in output_req_list],
+            batch_first=True,
+            padding_value=self.tokenizer.pad_token_id,
+        )
+
+        if response.shape[-1] < self.config.response_length:
+            response = pad_sequence_to_length(response, self.config.response_length, self.tokenizer.pad_token_id)
+
+        batch_size = idx.size(0)
+        seq = torch.cat([idx, response], dim=-1)
+
+        response_length = response.size(1)
+        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
+        delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)
+        if position_ids.dim() == 3:  # qwen2vl mrope
+            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
+
+        # TODO(sgm): fix position_ids on right_pad
+        # prompt: left pad + response: right pad
+        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
+        # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
+        response_position_ids = position_ids[..., -1:] + delta_position_id
+        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
+        response_attention_mask = get_response_mask(
+            response_id=response, eos_token=self.tokenizer.eos_token_id, dtype=attention_mask.dtype
+        )
+        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
+        # uids = np.array([req["uid"] for req in output_req_list], dtype=object)
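+
+        # The resulting batch mirrors the layout produced by verl's original
+        # generate_sequences: left-padded prompts, right-padded responses, and
+        # input_ids/attention_mask/position_ids covering the full concatenated
+        # sequence. Everything non-tensor on each request (uid, status, and the
+        # remaining dataset fields) is carried through via non_tensor_batch.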
+        # Construct the batch data
+        batch = TensorDict(
+            {
+                "prompts": idx,
+                "responses": response,
+                "input_ids": seq,  # here input_ids become the whole sentences
+                "attention_mask": attention_mask,
+                "position_ids": position_ids,
+            },
+            batch_size=len(output_req_list),
+        )
+
+        non_tensors = defaultdict(list)
+        for req in output_req_list:
+            for key, value in req.items():
+                if isinstance(value, torch.Tensor):
+                    pass
+                else:
+                    non_tensors[key].append(value)
+        for key, val in non_tensors.items():
+            non_tensors[key] = np.array(val, dtype=object)
+
+        return DataProto(
+            batch=batch,
+            non_tensor_batch=non_tensors,
+        )
+
+    async def abort(self):
+        aborted_samples = []
+
+        assert not self.aborted
+        self.aborted = True
+        response = await get(
+            f"http://{self.sglang_router_ip}:{self.sglang_router_port}/list_workers", use_http2=self.use_http2
+        )
+
+        # abort all the requests
+        for url in response["urls"]:
+            print(f"Abort request for {url}", flush=True)
+            await post(f"{url}/abort_request", {"abort_all": True}, use_http2=False)
+
+        # make sure all the pending tasks are finished
+        count = 0
+        while self.pendings:
+            done, self.pendings = await asyncio.wait(self.pendings, return_when=asyncio.FIRST_COMPLETED)
+
+            if not self.partial_rollout:
+                continue
+
+            # for partial rollout, collect the partial samples into the data buffer
+            for task in done:
+                group = task.result()
+
+                # for sample in group:
+                #     if sample.response and "start_rollout_id" not in sample.metadata:
+                #         sample.metadata["start_rollout_id"] = rollout_id
+                aborted_samples += [group]
+                count += len(group)
+
+        if self.partial_rollout:
+            print(f"Collected {count} partial samples into the data buffer", flush=True)
+
+        return aborted_samples
+
+    async def generate(self, sample: dict) -> Dict:
+        url = f"http://{self.sglang_router_ip}:{self.sglang_router_port}/generate"
+
+        assert sample["status"] == Status.PENDING or sample["status"] == Status.ABORTED, (
+            f"Sample status is {sample['status']}"
+        )
+
+        new_sampling_params = self.sampling_params.copy()
+
+        if len(sample["response"]) > 0:
+            new_sampling_params["max_new_tokens"] -= len(sample["response"])
+
+        assert new_sampling_params["max_new_tokens"] >= 0, (
+            f"max_new_tokens: {new_sampling_params['max_new_tokens']} should not be less than 0"
+        )
+        if new_sampling_params["max_new_tokens"] == 0:
+            sample["status"] = Status.TRUNCATED
+            return sample
+
+        # Prepare payload - shared structure
+        payload = {
+            "sampling_params": new_sampling_params,
+            "return_logprob": self.use_token_output,
+        }
+
+        if self.use_token_output:
+            if len(sample["response"]) > 0:
+                input_token_ids = sample["raw_prompt_ids"] + sample["response"]
+            else:
+                input_token_ids = sample["raw_prompt_ids"]
+            payload["input_ids"] = input_token_ids
+        else:
+            # String-based mode: original implementation
+            input_text = sample["raw_prompt_ids"] + sample["response"]
+            payload["text"] = input_text
+
+        output = await post(url, payload, use_http2=self.use_http2)
+
+        if "output_token_logprobs" in output["meta_info"]:
+            new_response_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]]
+            # Update sample with tokens directly
+            sample["response_length"] += len(new_response_tokens)
+            sample["response"] += new_response_tokens
+
+        match output["meta_info"]["finish_reason"]["type"]:
+            case "length":
+                sample["status"] = Status.TRUNCATED
+            case "abort":
+                sample["status"] = Status.ABORTED
+            case "stop":
+                sample["status"] = Status.COMPLETED
+
+        return sample
+
+    
async def generate_and_rm(self, sample: dict, evaluation=False) -> dict: + # For samples with existing response, check if they're complete + if sample["status"] == Status.COMPLETED or sample["status"] == Status.TRUNCATED: + assert sample["response"] is not None + # if not self.group_rm: + # assert sample.reward is not None + return sample + + # generate + async with self.semaphore: + if self.aborted: + sample["status"] = Status.ABORTED + return sample + + sample = await self.generate(sample) + + if sample["status"] == Status.ABORTED: + return sample + + return sample + + async def generate_group(self, group: list[Dict], evaluation=False) -> list[Dict]: + if self.aborted: + return group + + group = await asyncio.gather(*[self.generate_and_rm(sample, evaluation=evaluation) for sample in group]) + + return group + + def generate_sequences(self, prompts: DataProto): + return run(self.async_generate_sequences(prompts)) + + async def async_generate_sequences(self, prompts: DataProto): + """For compatibility with the original generate_sequences in verl""" + idx = prompts.batch["input_ids"] + attention_mask = prompts.batch["attention_mask"] + position_ids = prompts.batch["position_ids"] + eos_token_id = prompts.meta_info["eos_token_id"] + batch_size = idx.size(0) + + # Extract non-tensor data + non_tensor_batch = prompts.non_tensor_batch + if "raw_prompt_ids" not in non_tensor_batch: + non_tensor_batch["raw_prompt_ids"] = np.array( + [_pre_process_inputs(self.pad_token_id, idx[i]).tolist() for i in range(batch_size)], + dtype=object, + ) + + if "multi_modal_data" in non_tensor_batch: + sglang_inputs = [] + for raw_prompt_ids, multi_modal_data in zip( + non_tensor_batch.pop("raw_prompt_ids"), + non_tensor_batch.pop("multi_modal_data"), + strict=True, + ): + sglang_inputs.append( + { + "prompt_token_ids": raw_prompt_ids, + "multi_modal_data": multi_modal_data, + "image_data": ( + multi_modal_data.get("image", None) if isinstance(multi_modal_data, dict) else None + ), + } + ) + else: + sglang_inputs = [ + {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") + ] + + # Ensure token IDs are lists or numpy arrays + for input_data in sglang_inputs: + if isinstance(input_data["prompt_token_ids"], np.ndarray): + input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist() + elif not isinstance(input_data["prompt_token_ids"], list): + raise TypeError( + f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}" + ) + idx_list = [input_data["prompt_token_ids"] for input_data in sglang_inputs] + image_list = [input_data.get("image_data", None) for input_data in sglang_inputs] + + do_sample = prompts.meta_info.get("do_sample", True) + is_validate = prompts.meta_info.get("validate", False) + + # Create request-level sampling parameters + request_sampling_params = self.sampling_params.copy() + if not do_sample: + request_sampling_params.update( + { + "n": 1, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "repetition_penalty": 1.0, + "temperature": 0, + "top_p": 1, + "top_k": -1, + "ignore_eos": False, + "min_new_tokens": 0, + "max_new_tokens": self.config.response_length, + "skip_special_tokens": True, + "spaces_between_special_tokens": True, + } + ) + elif is_validate: + request_sampling_params.update( + { + "top_k": self.config.val_kwargs.top_k, + "top_p": self.config.val_kwargs.top_p, + "temperature": self.config.val_kwargs.temperature, + "n": 1, # if validate, already repeat in ray_trainer + } + ) + url = 
f"http://{self.sglang_router_ip}:{self.sglang_router_port}/generate"
+
+        # fire all generation requests concurrently
+        tasks = [
+            post(
+                url,
+                {
+                    "sampling_params": request_sampling_params,
+                    "return_logprob": True,
+                    "input_ids": idx,
+                    "image_data": image_data,
+                },
+                use_http2=self.use_http2,
+            )
+            for idx, image_data in zip(idx_list, image_list)
+        ]
+
+        outputs = []
+        pbar = tqdm(total=len(tasks), desc="Run eval Rollout generation")
+        for coro in asyncio.as_completed(tasks):
+            output = await coro
+            outputs.append(output)
+            pbar.update(1)
+        pbar.close()
+
+        out = _post_process_outputs(self.tokenizer, outputs)
+        response = out[0].to(idx.device)
+        rollout_log_probs = None
+        if self.config.calculate_log_probs:
+            rollout_log_probs = out[1].to(idx.device)
+
+        if response.shape[1] < self.config.response_length:
+            response = pad_sequence_to_length(response, self.config.response_length, self.tokenizer.pad_token_id)
+            if self.config.calculate_log_probs:
+                rollout_log_probs = pad_sequence_to_length(
+                    rollout_log_probs, self.config.response_length, self.tokenizer.pad_token_id
+                )
+
+        seq = torch.cat([idx, response], dim=-1)
+
+        response_length = response.size(1)
+        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
+        delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)
+        if position_ids.dim() == 3:  # qwen2vl mrope
+            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
+
+        # TODO(sgm): fix position_ids on right_pad
+        # prompt: left pad + response: right pad
+        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
+        # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
+        response_position_ids = position_ids[..., -1:] + delta_position_id
+        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
+        response_attention_mask = get_response_mask(
+            response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype
+        )
+        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
+
+        # all the tp ranks should contain the same data here. data in all ranks are valid
+        batch = TensorDict(
+            {
+                "prompts": idx,
+                "responses": response,
+                "input_ids": seq,  # here input_ids become the whole sentences
+                "attention_mask": attention_mask,
+                "position_ids": position_ids,
+            },
+            batch_size=batch_size,
+        )
+        if self.config.calculate_log_probs:
+            # we will recompute old log prob with actor
+            batch["rollout_log_probs"] = rollout_log_probs
+
+        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
diff --git a/verl/workers/rollout/sglang_rollout/http_server_engine.py b/verl/workers/rollout/sglang_rollout/http_server_engine.py
new file mode 100644
index 00000000000..9b40bd3e445
--- /dev/null
+++ b/verl/workers/rollout/sglang_rollout/http_server_engine.py
@@ -0,0 +1,1039 @@
+# Copyright 2025 Zhipu AI
+# Copyright 2023-2024 SGLang Team
+# Copyright 2025 ModelBest Inc. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# +# This file is adapted from multiple sources: +# 1. THUDM/slime project +# Original source: https://github.com/THUDM/slime/blob/main/slime/backends/sglang_utils/http_server_engine.py +# Copyright 2025 Zhipu AI +# Licensed under the Apache License, Version 2.0 +# 2. SGLang project +# Original source: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server_engine.py +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 +# +# Modifications made by Zhipu AI and ModelBest Inc. include but are not limited to: +# - Enhanced error handling and retry logic +# - Added async support with connection pooling +# - Extended functionality for distributed weight updates +# - Improved logging and monitoring capabilities +# - Additional configuration options and optimizations + +"""HTTP Server Engine Adapter for SGLang. + +This module provides HTTP-based adapters for SGLang engines, allowing communication +with SGLang servers through HTTP requests instead of direct engine calls. + +Classes: + HttpServerEngineAdapter: Synchronous HTTP adapter for SGLang engines + AsyncHttpServerEngineAdapter: Asynchronous HTTP adapter for SGLang engines + +Functions: + launch_server_process: Launch and initialize an SGLang HTTP server process +""" + +import asyncio +import logging +import multiprocessing +import os +import time +from contextlib import asynccontextmanager +from typing import Any, Callable, Dict, List, Optional, Tuple + +import aiohttp +import requests +from sglang.srt.entrypoints.EngineBase import EngineBase +from sglang.srt.entrypoints.http_server import launch_server +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import MultiprocessingSerializer, kill_process_tree + +# Configure logger +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +# Default configuration constants +DEFAULT_TIMEOUT = 600.0 +DEFAULT_MAX_RETRIES = 3 +DEFAULT_RETRY_DELAY = 2.0 +DEFAULT_MAX_CONNECTIONS = 4000 + + +def launch_server_process(server_args: ServerArgs, timeout: float = DEFAULT_TIMEOUT) -> multiprocessing.Process: + """Launch an SGLang HTTP server process and wait for it to be ready. + + This function starts a new process running an SGLang HTTP server, then waits + for the server to become ready by polling its health endpoints. It ensures + the server is fully operational before returning. + + Args: + server_args (ServerArgs): Server configuration arguments including host, port, and other settings + timeout (float, optional): Timeout for individual HTTP requests during health checks. + Defaults to DEFAULT_TIMEOUT. + + Returns: + multiprocessing.Process: The launched multiprocessing.Process instance + + Raises: + RuntimeError: If the server process terminates unexpectedly during startup or cache flush + TimeoutError: If server fails to become ready within reasonable time (300 seconds) + requests.RequestException: If health check requests fail repeatedly + + Note: + This function will return immediately for non-master nodes (node_rank != 0), + but the process will still be started and returned. 
+ """ + p = multiprocessing.Process(target=launch_server, args=(server_args,)) + p.start() + + if server_args.node_rank != 0: + return p + + base_url = server_args.url() + headers = { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {server_args.api_key}", + } + + # Health check with overall timeout + start_time = time.time() + max_wait_time = 500.0 # 5 minutes max wait + + with requests.Session() as session: + while time.time() - start_time < max_wait_time: + if not p.is_alive(): + raise RuntimeError("Server process terminated unexpectedly during startup") + + try: + response = session.get(f"{base_url}/health_generate", headers=headers, timeout=timeout) + if response.status_code == 200: + break + except requests.RequestException as e: + logger.debug(f"Health check failed: {e}") + + time.sleep(2) + else: + p.terminate() + logger.error(f"Server in {base_url} failed to become healthy within timeout period") + raise TimeoutError("Server failed to become healthy within timeout period") + + # Ensure cache is ready + while time.time() - start_time < max_wait_time: + if not p.is_alive(): + raise RuntimeError("Server process terminated unexpectedly during cache flush") + + try: + response = session.get(f"{base_url}/flush_cache", headers=headers, timeout=timeout) + if response.status_code == 200: + break + except requests.RequestException as e: + logger.debug(f"Cache flush check failed: {e}") + + time.sleep(2) + else: + p.terminate() + raise TimeoutError("Server cache flush failed within timeout period") + + return p + + +class HttpServerEngineAdapter(EngineBase): + """HTTP-based adapter for SGLang engines. + + This adapter allows interaction with SGLang engines through HTTP requests + instead of direct engine calls. It launches an HTTP server process and + provides methods to communicate with it via REST API calls. + + You can use this class to launch a server from a HttpServerEngineAdapter instance. + We recommend using this class only when you need to use http server. + Otherwise, you can use Engine directly. + + Attributes: + router_ip (Optional[str]): IP address of the router for worker registration + router_port (Optional[int]): Port of the router for worker registration + server_args (ServerArgs): Server configuration arguments + node_rank (int): Rank of this node in distributed setup + process (multiprocessing.Process): The launched server process + timeout (float): HTTP request timeout in seconds + max_retries (int): Maximum number of retry attempts for failed requests + retry_delay (float): Base delay between retries in seconds + """ + + def __init__( + self, + router_ip: Optional[str] = None, + router_port: Optional[int] = None, + timeout: float = DEFAULT_TIMEOUT, + max_retries: int = DEFAULT_MAX_RETRIES, + retry_delay: float = DEFAULT_RETRY_DELAY, + **kwargs: Any, + ) -> None: + """Initialize the HTTP server engine adapter. + + Args: + router_ip (Optional[str], optional): IP address of router for worker registration. + Defaults to None. + router_port (Optional[int], optional): Port of router for worker registration. + Defaults to None. + timeout (float, optional): HTTP request timeout in seconds. + Defaults to DEFAULT_TIMEOUT. + max_retries (int, optional): Maximum number of retry attempts for failed requests. + Defaults to DEFAULT_MAX_RETRIES. + retry_delay (float, optional): Base delay between retries in seconds. + Defaults to DEFAULT_RETRY_DELAY. 
+ **kwargs (Any): Additional arguments passed to ServerArgs + + Note: + If both router_ip and router_port are provided and this is the master node + (node_rank == 0), the adapter will automatically register with the router. + """ + self.router_ip: Optional[str] = router_ip + self.router_port: Optional[int] = router_port + self.timeout: float = timeout + self.max_retries: int = max_retries + self.retry_delay: float = retry_delay + self.server_args: ServerArgs = ServerArgs(**kwargs) + self.node_rank: int = self.server_args.node_rank + + logger.info(f"Launch HttpServerEngineAdapter at: {self.server_args.host}:{self.server_args.port}") + from loguru import logger as log + + log.info(f"self.server_args: {self.server_args}") + self.process: multiprocessing.Process = launch_server_process(self.server_args, self.timeout) + + if self.node_rank == 0 and self.router_ip and self.router_port: + self._register_with_router() + + def _register_with_router(self) -> None: + """Register worker with router with error handling. + + This method attempts to register the current worker with a router service. + If registration fails, it logs an error but does not raise an exception, + allowing the server to continue operating without router integration. + + Raises: + Does not raise exceptions - all errors are logged and handled gracefully. + """ + try: + url = f"http://{self.router_ip}:{self.router_port}/add_worker" + params = {"url": f"http://{self.server_args.host}:{self.server_args.port}"} + response = requests.post(url, params=params, timeout=self.timeout) + response.raise_for_status() + logger.info("Successfully registered with router") + except Exception as e: + logger.error(f"Failed to register with router: {e}") + # Don't raise here - server can still work without router + + def _make_request( + self, + endpoint: str, + payload: Optional[Dict[str, Any]] = None, + method: str = "POST", + timeout: float = DEFAULT_TIMEOUT, + only_master: bool = True, + ) -> Dict[str, Any]: + """Make a HTTP request with retry logic and consistent error handling. + + Args: + endpoint (str): The API endpoint to call (without leading slash) + payload (Optional[Dict[str, Any]], optional): The JSON payload to send. + Defaults to empty dict if None. + method (str, optional): HTTP method to use. Defaults to "POST". 
+ + Returns: + Dict[str, Any]: The JSON response from the server + + Raises: + requests.HTTPError: If the HTTP request fails with a client/server error + RuntimeError: If all retry attempts are exhausted + + Note: + - For non-master nodes (node_rank != 0), returns empty dict immediately + - Uses exponential backoff for retries + - Logs warnings for timeout and connection errors, errors for HTTP errors + """ + if only_master and self.node_rank != 0: + return {} + + url = f"http://{self.server_args.host}:{self.server_args.port}/{endpoint}" + + for attempt in range(self.max_retries + 1): + try: + if method.upper() == "GET": + response = requests.get(url, timeout=self.timeout) + else: + response = requests.post(url, json=payload or {}, timeout=self.timeout) + + response.raise_for_status() + return response.json() + + except requests.exceptions.Timeout: + logger.warning(f"Request to {endpoint} timed out (attempt {attempt + 1})") + except requests.exceptions.ConnectionError: + logger.warning(f"Connection error for {endpoint} (attempt {attempt + 1})") + except requests.exceptions.HTTPError as e: + logger.error(f"HTTP error for {endpoint}: {e}") + raise + except Exception as e: + logger.error(f"Unexpected error for {endpoint}: {e}") + if attempt == self.max_retries: + raise + + if attempt < self.max_retries: + print(f"Retrying {endpoint} in {self.retry_delay * (2**attempt):.2f} seconds...", flush=True) + time.sleep(self.retry_delay * (2**attempt)) # Exponential backoff + + raise RuntimeError(f"Failed to complete request to {endpoint} after {self.max_retries + 1} attempts") + + def update_weights_from_tensor( + self, + serialized_named_tensors: List[str], + load_format: Optional[str] = None, + flush_cache: bool = False, + ) -> Dict[str, Any]: + """Update model weights from tensor data. + + The HTTP server will only post meta data, and the real weights will be + copied directly from GPUs. + + Args: + serialized_named_tensors (List[str]): List of serialized tensor data + load_format (Optional[str], optional): Format specification for loading weights. + Defaults to None. + flush_cache (bool, optional): Whether to flush cache after updating weights. + Defaults to False. + + Returns: + Dict[str, Any]: Server response containing update status + + Note: + The model should be on GPUs rather than CPU for this functionality to work properly. + If you encounter issues, ensure your model is loaded on GPU devices rather than CPU. + """ + return self._make_request( + "update_weights_from_tensor", + { + "serialized_named_tensors": serialized_named_tensors, + "load_format": load_format, + "flush_cache": flush_cache, + }, + ) + + def shutdown(self) -> None: + """Shutdown the HTTP server and clean up resources. + + This method performs the following cleanup operations: + 1. Unregisters the worker from the router (if configured) + 2. Terminates the server process tree + + All operations are performed with error handling to ensure graceful shutdown + even if individual steps fail. + + Note: + This method should be called when the adapter is no longer needed + to ensure proper cleanup of resources and processes. 
+ """ + # Unregister from router + if self.router_ip and self.router_port: + try: + url = f"http://{self.router_ip}:{self.router_port}/remove_worker" + params = {"url": f"http://{self.server_args.host}:{self.server_args.port}"} + requests.post(url, params=params, timeout=5.0) # Short timeout for shutdown + logger.info("Successfully unregistered from router") + except Exception as e: + logger.warning(f"Failed to unregister from router: {e}") + + # Kill server process + if hasattr(self, "process") and self.process is not None: + try: + kill_process_tree(self.process.pid) + logger.info("Server process terminated") + except Exception as e: + logger.error(f"Failed to terminate server process: {e}") + + def generate( + self, + prompt: Optional[str] = None, + sampling_params: Optional[Dict[str, Any]] = None, + input_ids: Optional[List[int]] = None, + image_data: Optional[Any] = None, + return_logprob: bool = False, + logprob_start_len: Optional[int] = None, + top_logprobs_num: Optional[int] = None, + token_ids_logprob: Optional[List[int]] = None, + lora_path: Optional[str] = None, + custom_logit_processor: Optional[Callable] = None, + ) -> Dict[str, Any]: + """Generate text using the SGLang server. + + Args: + prompt (Optional[str], optional): Text prompt for generation. Defaults to None. + sampling_params (Optional[Dict[str, Any]], optional): Parameters controlling + text generation sampling. Defaults to None. + input_ids (Optional[List[int]], optional): Alternative to prompt, direct token IDs input. + Defaults to None. + image_data (Optional[Any], optional): Image data for multimodal generation. + Defaults to None. + return_logprob (bool, optional): Whether to return log probabilities. + Defaults to False. + logprob_start_len (Optional[int], optional): Starting length for log probability calculation. + Defaults to None. + top_logprobs_num (Optional[int], optional): Number of top log probabilities to return. + Defaults to None. + token_ids_logprob (Optional[List[int]], optional): Specific token IDs for + log probability calculation. Defaults to None. + lora_path (Optional[str], optional): Path to LoRA adapter weights. Defaults to None. + custom_logit_processor (Optional[Callable], optional): Custom logit processing function. + Defaults to None. + + Returns: + Dict[str, Any]: Generated text and associated metadata from the server + + Note: + Either prompt or input_ids should be provided, but not both. + The response format depends on the server configuration and parameters. + """ + payload = { + "text": prompt, + "sampling_params": sampling_params, + "input_ids": input_ids, + "image_data": image_data, + "return_logprob": return_logprob, + "logprob_start_len": logprob_start_len, + "top_logprobs_num": top_logprobs_num, + "token_ids_logprob": token_ids_logprob, + "lora_path": lora_path, + "custom_logit_processor": custom_logit_processor, + } + # Filter out None values + payload = {k: v for k, v in payload.items() if v is not None} + + return self._make_request("generate", payload, only_master=False) + + def flush_cache(self) -> Dict[str, Any]: + """Flush the cache of the server. + + This method repeatedly attempts to flush the server cache until successful. + The flush operation will not return status 200 when there are pending requests. + + Returns: + Dict[str, Any]: Server response indicating cache flush status. + For non-master nodes, returns empty dict. + + Note: + Uses retry logic with limited attempts (max_retries * 2) to avoid infinite loops. 
+ Each retry includes a delay to allow pending requests to complete. + """ + if self.node_rank != 0: + return {} + + # Use retry logic with limited attempts to avoid infinite loops + for attempt in range(self.max_retries * 2): # Allow more retries for cache flush + try: + response = requests.get( + f"http://{self.server_args.host}:{self.server_args.port}/flush_cache", timeout=self.timeout + ) + if response.status_code == 200: + return {} + except Exception as e: + logger.warning(f"Error flushing cache (attempt {attempt + 1}): {e}") + + time.sleep(self.retry_delay) + + logger.error("Failed to flush cache after maximum attempts") + return {} + + def release_memory_occupation(self, tags: Optional[List[str]] = None) -> Dict[str, Any]: + """Release GPU memory occupation temporarily. + + Args: + tags (Optional[List[str]], optional): List of tags to specify which memory to release. + If None, releases all memory. Defaults to None. ["weights", "kv_cache"] + + Returns: + Dict[str, Any]: Server response indicating memory release status + """ + print("release_memory_occupation", flush=True) + return self._make_request("release_memory_occupation", {"tags": tags}) + + def resume_memory_occupation(self, tags: Optional[List[str]] = None) -> Dict[str, Any]: + """Resume GPU memory occupation. + + Args: + tags (Optional[List[str]], optional): List of tags to specify which memory to resume. + If None, resumes all memory. Defaults to None. ["weights", "kv_cache"] + + Returns: + Dict[str, Any]: Server response indicating memory resume status + """ + return self._make_request("resume_memory_occupation", {"tags": tags}) + + def init_weights_update_group( + self, master_address: str, master_port: int, rank_offset: int, world_size: int, group_name: str, backend: str + ) -> Dict[str, Any]: + """Initialize a distributed weights update group. + + Args: + master_address (str): Address of the master node for distributed communication + master_port (int): Port of the master node + rank_offset (int): Offset for process ranks in the group + world_size (int): Total number of processes in the distributed group + group_name (str): Name identifier for the process group + backend (str): Backend to use for distributed communication (e.g., 'nccl', 'gloo') + + Returns: + Dict[str, Any]: Server response indicating group initialization status + """ + return self._make_request( + "init_weights_update_group", + { + "master_address": master_address, + "master_port": master_port, + "rank_offset": rank_offset, + "world_size": world_size, + "group_name": group_name, + "backend": backend, + }, + ) + + def update_weights_from_distributed( + self, + names: List[str], + dtypes: List[Any], + shapes: List[Tuple[int, ...]], + group_name: str, + flush_cache: bool = False, + ) -> Dict[str, Any]: + """Update model weights from distributed tensors. + + Args: + names (List[str]): List of tensor names to update + dtypes (List[Any]): List of data types for each tensor (typically torch.dtype) + shapes (List[Tuple[int, ...]]): List of tensor shapes + group_name (str): Name of the distributed process group + flush_cache (bool, optional): Whether to flush cache after updating weights. + Defaults to False. 
+ + Returns: + Dict[str, Any]: Server response indicating distributed update status + """ + return self._make_request( + "update_weights_from_distributed", + { + "names": names, + "dtypes": [str(dtype).replace("torch.", "") for dtype in dtypes], + "shapes": shapes, + "group_name": group_name, + "flush_cache": flush_cache, + }, + ) + + def pause_generation(self) -> Dict[str, Any]: + """Pause text generation on the server. + + Returns: + Dict[str, Any]: Server response indicating pause status + """ + return self._make_request("pause_generation", {}) + + def continue_generation(self) -> Dict[str, Any]: + """Continue text generation on the server. + + Returns: + Dict[str, Any]: Server response indicating continuation status + """ + return self._make_request("continue_generation", {}) + + +class AsyncHttpServerEngineAdapter(HttpServerEngineAdapter): + """Asynchronous HTTP-based adapter for SGLang engines. + + This class inherits from HttpServerEngineAdapter and adds async capabilities + for non-blocking HTTP requests to the SGLang server. It provides the same + functionality as the synchronous version but with async/await support. + + The async adapter is useful when you need to make multiple concurrent requests + or integrate with async frameworks. It uses aiohttp for efficient async HTTP + communication and maintains connection pooling for better performance. + + Attributes: + _need_reload (bool): Flag indicating if weights need to be reloaded on first use + _session (Optional[aiohttp.ClientSession]): aiohttp ClientSession for making async HTTP requests + _session_lock (asyncio.Lock): Lock for thread-safe session access + max_connections (int): Maximum number of connections in the connection pool + """ + + def __init__( + self, + router_ip: Optional[str] = None, + router_port: Optional[int] = None, + timeout: float = DEFAULT_TIMEOUT, + max_retries: int = DEFAULT_MAX_RETRIES, + retry_delay: float = DEFAULT_RETRY_DELAY, + max_connections: int = DEFAULT_MAX_CONNECTIONS, + **kwargs: Any, + ) -> None: + """Initialize the async HTTP server engine adapter. + + Args: + router_ip (Optional[str], optional): IP address of router for worker registration. + Defaults to None. + router_port (Optional[int], optional): Port of router for worker registration. + Defaults to None. + timeout (float, optional): HTTP request timeout in seconds. + Defaults to DEFAULT_TIMEOUT. + max_retries (int, optional): Maximum number of retry attempts for failed requests. + Defaults to DEFAULT_MAX_RETRIES. + retry_delay (float, optional): Base delay between retries in seconds. + Defaults to DEFAULT_RETRY_DELAY. + max_connections (int, optional): Maximum number of connections in the connection pool. + Defaults to DEFAULT_MAX_CONNECTIONS. + **kwargs (Any): Additional arguments passed to ServerArgs + """ + super().__init__(router_ip, router_port, timeout, max_retries, retry_delay, **kwargs) + # Similar to AsyncEngine, track if we need to reload weights + self._need_reload: bool = True + self._session: Optional[aiohttp.ClientSession] = None + self._session_lock: asyncio.Lock = asyncio.Lock() + self.max_connections: int = max_connections + + @asynccontextmanager + async def _get_session(self) -> aiohttp.ClientSession: + """Context manager for safe session access with proper connection pooling. 
+ + Yields: + aiohttp.ClientSession: Session instance for making HTTP requests + + Note: + This method creates a new session (with its own connector) for each request to avoid + resource competition; the session is always closed once the request completes. + """ + # Create a new session for each request to avoid resource competition + connector = aiohttp.TCPConnector( + limit=self.max_connections, + limit_per_host=self.max_connections // 4, + ttl_dns_cache=300, + use_dns_cache=True, + ) + timeout = aiohttp.ClientTimeout(total=self.timeout) + session = aiohttp.ClientSession(connector=connector, timeout=timeout) + + try: + yield session + finally: + # Always close the session to free up resources + if not session.closed: + await session.close() + + async def _make_async_request( + self, + endpoint: str, + payload: Optional[Dict[str, Any]] = None, + method: str = "POST", + timeout: float = DEFAULT_TIMEOUT, + only_master: bool = True, + ) -> Dict[str, Any]: + """Make an async HTTP request with retry logic and consistent error handling. + + Args: + endpoint (str): The API endpoint to call (without leading slash) + payload (Optional[Dict[str, Any]], optional): The JSON payload to send. + Defaults to empty dict if None. + method (str, optional): HTTP method to use. Defaults to "POST". + timeout (float, optional): Per-request timeout in seconds. Defaults to DEFAULT_TIMEOUT. + only_master (bool, optional): If True, only the master node (node_rank == 0) issues the + request and other nodes return an empty dict immediately. Defaults to True. + + Returns: + Dict[str, Any]: The JSON response from the server + + Raises: + aiohttp.ClientResponseError: If the HTTP request fails with a client/server error + RuntimeError: If all retry attempts are exhausted + + Note: + - For non-master nodes (node_rank != 0), returns empty dict immediately + - Uses exponential backoff for retries + - Logs warnings for timeout and connection errors, errors for HTTP errors + """ + if only_master and self.node_rank != 0: + return {} + + url = f"http://{self.server_args.host}:{self.server_args.port}/{endpoint}" + + for attempt in range(self.max_retries + 1): + try: + async with self._get_session() as session: + if method.upper() == "GET": + async with session.get(url, timeout=timeout) as response: + response.raise_for_status() + return await response.json() + else: + async with session.post(url, json=payload or {}, timeout=timeout) as response: + response.raise_for_status() + return await response.json() + + except asyncio.TimeoutError: + logger.warning(f"Async request to {endpoint} timed out (attempt {attempt + 1})") + except aiohttp.ClientConnectorError: + logger.warning(f"Connection error for {endpoint} (attempt {attempt + 1})") + except aiohttp.ClientResponseError as e: + logger.error(f"HTTP error for {endpoint}: {e}") + raise + except Exception as e: + logger.error(f"Unexpected error for {endpoint}: {e}") + if attempt == self.max_retries: + raise + + if attempt < self.max_retries: + await asyncio.sleep(self.retry_delay * (2**attempt)) + + raise RuntimeError(f"Failed to complete async request to {endpoint} after {self.max_retries + 1} attempts") + + async def release_memory_occupation(self, tags: Optional[List[str]] = None) -> Dict[str, Any]: + """Release GPU memory occupation temporarily (async version). + + Args: + tags (Optional[List[str]], optional): List of tags to specify which memory to release. + If None, releases all memory. Valid tags are "weights" and "kv_cache". Defaults to None.
+ + Returns: + Dict[str, Any]: Server response indicating memory release status + """ + return await self._make_async_request("release_memory_occupation", {"tags": tags}) + + async def resume_memory_occupation(self, tags: Optional[List[str]] = None) -> Dict[str, Any]: + """Resume GPU memory occupation (async version). + + Similar to AsyncEngine, this method handles first-time weight reloading + by calling release_memory_occupation if needed. + + Args: + tags (Optional[List[str]], optional): List of tags to specify which memory to resume. + If None, resumes all memory. Valid tags are "weights" and "kv_cache". Defaults to None. + + Returns: + Dict[str, Any]: Server response indicating memory resume status + """ + # Similar to AsyncEngine, handle first-time reload + if self._need_reload: + await self.release_memory_occupation() + self._need_reload = False + + return await self._make_async_request("resume_memory_occupation", {"tags": tags}) + + async def update_weights_from_tensor( + self, + named_tensors: List[Tuple[str, Any]], + load_format: Optional[str] = None, + flush_cache: bool = True, + ) -> Dict[str, Any]: + """Update model weights from tensor data asynchronously. + + Args: + named_tensors (List[Tuple[str, Any]]): List of (name, tensor) pairs that are serialized + with MultiprocessingSerializer and sent to every tp rank of the server + load_format (Optional[str], optional): Format specification for loading weights. + Defaults to None. + flush_cache (bool, optional): Whether to flush cache after updating weights. + Defaults to True. + + Returns: + Dict[str, Any]: Server response containing update status + """ + import base64 + + serialized_named_tensors = [ + base64.b64encode(MultiprocessingSerializer.serialize(named_tensors)).decode("utf-8") + for _ in range(self.server_args.tp_size) + ] + return await self._make_async_request( + "update_weights_from_tensor", + { + "serialized_named_tensors": serialized_named_tensors, + "load_format": load_format, + "flush_cache": flush_cache, + }, + ) + + async def flush_cache(self) -> Dict[str, Any]: + """Flush the cache of the server asynchronously. + + Similar to the sync version, this method retries until the cache is successfully + flushed or the attempt budget is exhausted. It uses async sleep between retries. + + Returns: + Dict[str, Any]: Server response indicating cache flush status. + For non-master nodes, returns empty dict. + + Note: + Uses retry logic with a bounded number of attempts to avoid infinite loops. + Each retry includes an async delay to allow pending requests to complete.
+ """ + if self.node_rank != 0: + return {} + + # Use retry logic with limited attempts to avoid infinite loops + for attempt in range(self.max_retries * 5): # Allow more retries for cache flush + try: + async with self._get_session() as session: + url = f"http://{self.server_args.host}:{self.server_args.port}/flush_cache" + async with session.get(url) as response: + if response.status == 200: + return {} + except Exception as e: + logger.warning(f"Error flushing cache (attempt {attempt + 1}): {e}") + + await asyncio.sleep(self.retry_delay) + + logger.error("Failed to flush cache after maximum attempts") + return {} + + async def generate( + self, + prompt: Optional[str] = None, + sampling_params: Optional[Dict[str, Any]] = None, + input_ids: Optional[List[int]] = None, + image_data: Optional[Any] = None, + return_logprob: bool = False, + logprob_start_len: Optional[int] = None, + top_logprobs_num: Optional[int] = None, + token_ids_logprob: Optional[List[int]] = None, + lora_path: Optional[str] = None, + custom_logit_processor: Optional[Callable] = None, + ) -> Dict[str, Any]: + """Generate text using the SGLang server asynchronously.""" + payload = { + "text": prompt, + "sampling_params": sampling_params, + "input_ids": input_ids, + "image_data": image_data, + "return_logprob": return_logprob, + "logprob_start_len": logprob_start_len, + "top_logprobs_num": top_logprobs_num, + "token_ids_logprob": token_ids_logprob, + "lora_path": lora_path, + "custom_logit_processor": custom_logit_processor, + } + + # Filter out None values + payload = {k: v for k, v in payload.items() if v is not None} + + # Send the request from every node (only_master=False), not just the master + return await self._make_async_request("generate", payload, timeout=self.timeout, only_master=False) + + async def init_weights_update_group( + self, master_address: str, master_port: int, rank_offset: int, world_size: int, group_name: str, backend: str + ) -> Dict[str, Any]: + """Initialize a distributed weights update group asynchronously. + + Args: + master_address (str): Address of the master node for distributed communication + master_port (int): Port of the master node + rank_offset (int): Offset for process ranks in the group + world_size (int): Total number of processes in the distributed group + group_name (str): Name identifier for the process group + backend (str): Backend to use for distributed communication (e.g., 'nccl', 'gloo') + + Returns: + Dict[str, Any]: Server response indicating group initialization status + """ + return await self._make_async_request( + "init_weights_update_group", + { + "master_address": master_address, + "master_port": master_port, + "rank_offset": rank_offset, + "world_size": world_size, + "group_name": group_name, + "backend": backend, + }, + ) + + async def update_weights_from_distributed( + self, + names: List[str], + dtypes: List[Any], + shapes: List[Tuple[int, ...]], + group_name: str, + flush_cache: bool = False, + ) -> Dict[str, Any]: + """Update model weights from distributed tensors asynchronously.
+ + Args: + names (List[str]): List of tensor names to update + dtypes (List[Any]): List of data types for each tensor (typically torch.dtype) + shapes (List[Tuple[int, ...]]): List of tensor shapes + group_name (str): Name of the distributed process group + flush_cache (bool, optional): Whether to flush cache after updating weights. + Defaults to False. + + Returns: + Dict[str, Any]: Server response indicating distributed update status + """ + return await self._make_async_request( + "update_weights_from_distributed", + { + "names": names, + "dtypes": [str(dtype).replace("torch.", "") for dtype in dtypes], + "shapes": shapes, + "group_name": group_name, + "flush_cache": flush_cache, + }, + ) + + async def pause_generation(self) -> Dict[str, Any]: + """Pause text generation on the server asynchronously. + + Returns: + Dict[str, Any]: Server response indicating pause status + """ + return await self._make_async_request("pause_generation", {}) + + async def continue_generation(self) -> Dict[str, Any]: + """Continue text generation on the server asynchronously. + + Returns: + Dict[str, Any]: Server response indicating continuation status + """ + return await self._make_async_request("continue_generation", {}) + + async def async_generate( + self, + prompt: Optional[str] = None, + sampling_params: Optional[Dict[str, Any]] = None, + input_ids: Optional[List[int]] = None, + image_data: Optional[Any] = None, + return_logprob: bool = False, + logprob_start_len: Optional[int] = None, + top_logprobs_num: Optional[int] = None, + token_ids_logprob: Optional[List[int]] = None, + lora_path: Optional[str] = None, + custom_logit_processor: Optional[Callable] = None, + ) -> Dict[str, Any]: + """Async generate method that mirrors AsyncEngine.async_generate interface. + + This method provides compatibility with AsyncEngine's async_generate method + by forwarding the call to the generate method. It ensures API consistency + between direct engine usage and HTTP-based engine usage. + + Args: + prompt (Optional[str], optional): Text prompt for generation. Defaults to None. + sampling_params (Optional[Dict[str, Any]], optional): Parameters controlling + text generation sampling. Defaults to None. + input_ids (Optional[List[int]], optional): Alternative to prompt, direct token IDs input. + Defaults to None. + image_data (Optional[Any], optional): Image data for multimodal generation. + Defaults to None. + return_logprob (bool, optional): Whether to return log probabilities. + Defaults to False. + logprob_start_len (Optional[int], optional): Starting length for log probability calculation. + Defaults to None. + top_logprobs_num (Optional[int], optional): Number of top log probabilities to return. + Defaults to None. + token_ids_logprob (Optional[List[int]], optional): Specific token IDs for + log probability calculation. Defaults to None. + lora_path (Optional[str], optional): Path to LoRA adapter weights. Defaults to None. + custom_logit_processor (Optional[Callable], optional): Custom logit processing function. + Defaults to None. + + Returns: + Dict[str, Any]: Generated text and associated metadata from the server + + Note: + This method is provided for API compatibility with AsyncEngine. + It forwards all calls to the generate method. 
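+
+ Example: Illustrative call only; the prompt and sampling parameters shown here are hypothetical, not defaults::
+ result = await engine.async_generate(
+ prompt="Hello",
+ sampling_params={"temperature": 0.7, "max_new_tokens": 16},
+ )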
+ """ + return await self.generate( + prompt=prompt, + sampling_params=sampling_params, + input_ids=input_ids, + image_data=image_data, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + token_ids_logprob=token_ids_logprob, + lora_path=lora_path, + custom_logit_processor=custom_logit_processor, + ) + + async def close(self) -> None: + """Close the aiohttp session and clean up resources. + + This method should be called when the adapter is no longer needed + to ensure proper cleanup of HTTP connections and resources. + + Note: + This method is safe to call multiple times. If the session is + already closed or None, this method will do nothing. + """ + if self._session and not self._session.closed: + await self._session.close() + self._session = None + logger.info("HTTP session closed") + + async def __aenter__(self) -> "AsyncHttpServerEngineAdapter": + """Async context manager support. + + Returns: + AsyncHttpServerEngineAdapter: Self for use in async context + """ + return self + + async def __aexit__(self, exc_type: Optional[type], exc_val: Optional[Exception], exc_tb: Optional[Any]) -> None: + """Cleanup on context exit. + + Args: + exc_type (Optional[type]): Exception type if an exception occurred + exc_val (Optional[Exception]): Exception value if an exception occurred + exc_tb (Optional[Any]): Exception traceback if an exception occurred + """ + await self.close() + + def __del__(self) -> None: + """Cleanup when object is destroyed. + + This provides a fallback cleanup mechanism for the aiohttp session + in case the close() method wasn't called explicitly. Note that this + is not ideal for async cleanup but provides a safety net. + + Warning: + This method attempts async cleanup in a sync context, which may + not always work reliably. It's recommended to explicitly call + close() or use the async context manager instead. 
+ """ + if hasattr(self, "_session") and self._session and not self._session.closed: + # Note: This is not ideal for async cleanup, but provides a fallback + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(self._session.close()) + else: + loop.run_until_complete(self._session.close()) + except Exception: + pass # Ignore cleanup errors during destruction + diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index 3c6694325b0..4e029fedf73 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -17,8 +17,11 @@ import asyncio import logging +import multiprocessing import multiprocessing as mp import os +import random +import socket import time from copy import deepcopy from json import JSONDecodeError @@ -26,10 +29,12 @@ from uuid import uuid4 import numpy as np +import ray import sglang.srt.entrypoints.engine import torch import torch.distributed as dist from omegaconf import DictConfig +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from sglang.srt.managers.tokenizer_manager import ( ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, @@ -43,7 +48,7 @@ get_ip, get_open_port, is_cuda, - maybe_set_triton_cache_manager, + # maybe_set_triton_cache_manager, set_prometheus_multiproc_dir, set_ulimit, ) @@ -69,6 +74,7 @@ FinishReasonTypeEnum, Message, ) +from verl.workers.rollout.sglang_rollout.http_server_engine import AsyncHttpServerEngineAdapter from verl.workers.rollout.sglang_rollout.utils import broadcast_pyobj try: @@ -104,9 +110,9 @@ def _set_envs_and_config(server_args: ServerArgs): set_ulimit() # Fix triton bugs - if server_args.tp_size * server_args.dp_size > 1: - # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. - maybe_set_triton_cache_manager() + # if server_args.tp_size * server_args.dp_size > 1: + # # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. + # maybe_set_triton_cache_manager() # Check flashinfer version if server_args.attention_backend == "flashinfer": @@ -126,7 +132,7 @@ mp.set_start_method("spawn", force=True) -sglang.srt.entrypoints.engine._set_envs_and_config = _set_envs_and_config +# sglang.srt.entrypoints.engine._set_envs_and_config = _set_envs_and_config # because chatCompletion is an async method, it makes the whole ray actor be an async actor @@ -258,6 +264,8 @@ def __init__( port=None, trust_remote_code: bool = False, device_mesh: DeviceMesh | None = None, + sglang_router_ip=None, + sglang_router_port=None, **kwargs, ): """Synchronized SGLang rollout engine.
@@ -285,6 +293,8 @@ def __init__( super().__init__() self.config = config self._device_mesh_cpu = device_mesh + self.sglang_router_ip = sglang_router_ip + self.sglang_router_port = sglang_router_port os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true") ( @@ -307,7 +317,11 @@ def __init__( self._verify_config(model_hf_config=model_hf_config) # initialize the inference engine - self._init_inference_engine(trust_remote_code, actor_module, port) + self._init_inference_engine( + trust_remote_code, + actor_module, + port, + ) self._init_sampling_params(**kwargs) @@ -437,7 +451,7 @@ def _init_inference_engine(self, trust_remote_code, actor_module, port): if first_rank_in_node: rank = dist.get_rank() os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0" - self._engine = AsyncEngine( + self._engine = AsyncHttpServerEngineAdapter( model_path=actor_module, dtype=self.config.dtype, mem_fraction_static=self.config.gpu_memory_utilization, @@ -450,20 +464,26 @@ def _init_inference_engine(self, trust_remote_code, actor_module, port): dist_init_addr=dist_init_addr, nnodes=nnodes, trust_remote_code=trust_remote_code, + disable_cuda_graph=True, # NOTE(linjunrong): add rank to prevent SGLang generate same port inside PortArgs.init_new # when random.seed is being set during training - port=30000 + rank, + host=get_host_info()[1], + # port=port, + port=30000 + rank + 2, # NOTE(Chenyang): if you want to debug the SGLang engine output # please set the following parameters # Otherwise, it will make the engine run too slow - # log_level="INFO", + log_level="info", # log_requests=True, # log_requests_level=2, # max_running_requests=1, mm_attention_backend="fa3", - attention_backend="fa3", + # attention_backend=attention_backend if attention_backend is not None else "fa3", # In async mode, we want token in token out. skip_tokenizer_init=self.config.mode == "async", + router_ip=self.sglang_router_ip, + router_port=self.sglang_router_port, + skip_server_warmup=True, ) else: self._engine = None @@ -1389,3 +1409,172 @@ async def sleep(self): return await self.sharding_manager.sleep() self.is_sleep = True + + +def create_rollout_engines(args, pg): + if args.debug_train_only: + return [] + + num_gpu_per_engine = min(args.rollout_num_gpus_per_engine, args.rollout_num_gpus_per_node) + num_engines = args.rollout_num_gpus // num_gpu_per_engine + + pg, reordered_bundle_indices = pg + + RolloutRayActor = ray.remote(SGLangEngine) + + rollout_engines = [] + for i in range(num_engines): + num_gpus = 0.2 + num_cpus = num_gpus + + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=reordered_bundle_indices[i * num_gpu_per_engine], + ) + + rollout_engines.append( + RolloutRayActor.options( + num_cpus=num_cpus, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + ).remote(args, rank=i) + ) + + # get ports + # there are 4 ports we need to allocate + # 1. server port + # 2. nccl port + # 3. dist_init_addr port + # 4. other ports for dp_attention, which is of size 4 + dp_size + num_engines_per_node = max( + 1, min(args.rollout_num_gpus_per_node, args.rollout_num_gpus) // args.rollout_num_gpus_per_engine + ) + addr_and_ports = [{} for _ in range(num_engines)] + for rank, engine in enumerate(rollout_engines): + if rank % num_engines_per_node != 0: + continue + + def get_addr_and_ports(): + # use small ports to prevent ephemeral port between 32768 and 65536. 
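+ # port(consecutive=k), defined below, requests a free port on the engine's node (expected to be
+ # the first of a block of k consecutive free ports) and advances start_port past that block so
+ # later calls do not hand out overlapping ranges.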
+ start_port = 10000 + + def port(consecutive=1): + nonlocal start_port + _, port = ray.get( + engine._get_current_node_ip_and_free_port.remote( + start_port=start_port, + consecutive=consecutive, + ) + ) + start_port = port + consecutive + return port + + def addr(): + addr, _ = ray.get(engine._get_current_node_ip_and_free_port.remote()) + return addr + + return addr, port + + get_addr, get_port = get_addr_and_ports() + + for i in range(num_engines_per_node): + addr_and_ports[rank + i]["port"] = get_port() + addr_and_ports[rank + i]["nccl_port"] = get_port() + + if args.rollout_num_gpus_per_engine > args.rollout_num_gpus_per_node: + num_node_per_engine = args.rollout_num_gpus_per_engine // args.rollout_num_gpus_per_node + if rank % num_node_per_engine == 0: + # this is the first node in the engine, we need to allocate the dist_init_addr port + dist_init_addr = f"{get_addr()}:{get_port(6 + args.sglang_dp_size)}" + for i in range(num_node_per_engine): + addr_and_ports[rank + i]["dist_init_addr"] = dist_init_addr + else: + for i in range(num_engines_per_node): + addr_and_ports[rank + i]["dist_init_addr"] = f"{get_addr()}:{get_port(6 + args.sglang_dp_size)}" + + for i in range(num_engines): + for key in ["port", "nccl_port", "dist_init_addr"]: + assert key in addr_and_ports[i], f"Engine {i} {key} is not set." + print(f"Ports for engine {i}: {addr_and_ports[i]}") + + # TODO: don't ray.get here to overlap train actor init with rollout engine init. + # somehow if we don't sync here, the --debug-rollout-only mode will crash. + init_handles = [engine.init.remote(**ports) for engine, ports in zip(rollout_engines, addr_and_ports)] + ray.get(init_handles) + + return rollout_engines + + +def find_available_port(base_port: int): + port = base_port + random.randint(100, 1000) + while True: + if is_port_available(port): + return port + if port < 60000: + port += 42 + else: + port -= 43 + + +def is_port_available(port): + """Return whether a port is available.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", port)) + s.listen(1) + return True + except socket.error: + return False + except OverflowError: + return False + + +def get_host_info(): + hostname = socket.gethostname() + + local_ip = socket.gethostbyname(hostname) + + return hostname, local_ip + + +def run_router(args): + try: + from sglang_router.launch_router import launch_router + + router = launch_router(args) + if router is None: + return 1 + return 0 + except Exception as e: + print(e) + return 1 + + +def _start_router(): + from sglang_router.launch_router import RouterArgs + + sglang_router_host = get_host_info()[1] + sglang_router_port = find_available_port(random.randint(3000, 4000)) + + router_args = RouterArgs( + host=sglang_router_host, + port=sglang_router_port, + balance_abs_threshold=0, + ) + + if hasattr(router_args, "log_level"): + router_args.log_level = "warn" + + process = multiprocessing.Process( + target=run_router, + args=(router_args,), + ) + process.daemon = True # run the router as a daemon process + process.start() + # wait 3 seconds for the router to start + time.sleep(3) + assert process.is_alive() + print(f"SGLang router launched at {sglang_router_host}:{sglang_router_port}") + return sglang_router_host, sglang_router_port diff --git a/verl/workers/sharding_manager/megatron_sglang.py b/verl/workers/sharding_manager/megatron_sglang.py index d353c70e84a..ad934d92c05 100644 --- a/verl/workers/sharding_manager/megatron_sglang.py +++ b/verl/workers/sharding_manager/megatron_sglang.py @@ 
-20,11 +20,14 @@ import asyncio import logging import os +from contextlib import asynccontextmanager +from typing import Optional import torch.distributed as dist from omegaconf import DictConfig from sglang.srt.entrypoints.engine import Engine from sglang.srt.model_executor.model_runner import LocalSerializedTensor +from sglang.srt.patch_torch import monkey_patch_torch_reductions from sglang.srt.utils import MultiprocessingSerializer from torch import nn from torch.distributed.device_mesh import DeviceMesh @@ -101,6 +104,8 @@ def __init__( self.device_mesh = device_mesh self.bridge = bridge self.offload_param = offload_param + self.multi_stage_wake_up = True + self._need_reload = True if self.device_mesh is not None: self.infer_tp_size = self.device_mesh["tp"].mesh.size()[0] @@ -118,6 +123,18 @@ def __init__( else: self.gen_random_states = None + @GPUMemoryLogger(role="MegatronSGLangShardingManager prepare_for_generate", logger=logger) + def prepare_for_generate(self): + self.timing = {} + with simple_timer("reshard", self.timing): + loop = asyncio.get_event_loop() + loop.run_until_complete(self.wake_up()) + + @GPUMemoryLogger(role="MegatronSGLangShardingManager finish_generate ", logger=logger) + def finish_generate(self): + loop = asyncio.get_event_loop() + loop.run_until_complete(self.sleep()) + @GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger) def __enter__(self): self.timing = {} @@ -130,6 +147,14 @@ def __exit__(self, exc_type, exc_value, traceback): loop = asyncio.get_event_loop() loop.run_until_complete(self.sleep()) + async def release_memory(self): + if self.device_mesh["tp"].get_local_rank() == 0: + await self.inference_engine.release_memory_occupation() + + async def resume_memory(self, tags: Optional[list[str]] = None): + if self.device_mesh["tp"].get_local_rank() == 0: + await self.inference_engine.resume_memory_occupation(tags=tags) + async def update_weights(self, params): """ Update model weights using tensor buckets, similar to THUDM/slime's implementation. 
@@ -143,10 +168,11 @@ async def update_weights(self, params): - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452 - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39 """ - if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: - await self.inference_engine.resume_memory_occupation() + # if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: + # await self.inference_engine.resume_memory_occupation() named_tensors = params load_format = None + monkey_patch_torch_reductions() update_weights_bucket_bytes = int(self.rollout_config.update_weights_bucket_megabytes) << 20 for batch in get_named_tensor_buckets(named_tensors, update_weights_bucket_bytes): @@ -200,30 +226,51 @@ async def update_weights(self, params): ) if self.device_mesh["tp"].get_local_rank() == 0: - await self.inference_engine.flush_cache() + pass + # await self.inference_engine.flush_cache() - async def release_memory(self): - if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: - await self.inference_engine.release_memory_occupation() + @asynccontextmanager + async def offload_manager(self): + try: + if self.multi_stage_wake_up: + log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger) + await self.resume_memory(tags=["weights"]) + log_gpu_memory_usage("After resume SGLang weights in sharding manager", logger=logger) + else: + log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger) + await self.resume_memory() + log_gpu_memory_usage("After resume SGLang weights + kv_cache in sharding manager", logger=logger) + dist.barrier() + + if self.offload_param: + load_megatron_model_to_gpu(self.actor_module) + yield + finally: + if self.offload_param: + offload_megatron_model_to_cpu(self.actor_module) + get_torch_device().empty_cache() + dist.barrier() + + await self.resume_memory(tags=["kv_cache"]) + log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger) @GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger) async def wake_up(self): - if self.offload_param: - load_megatron_model_to_gpu(self.actor_module) - if self.bridge is not None: - per_tensor_param = self.bridge.export_weights(self.actor_module) - else: - per_tensor_param = per_tensor_generator( - self.actor_module, - self.model_config, - self.weight_converter, - self.transformer_config, - self.layer_name_mapping, - ) - await self.update_weights(per_tensor_param) - if self.offload_param: - offload_megatron_model_to_cpu(self.actor_module) - get_torch_device().empty_cache() + async with self.offload_manager(): + if self.bridge is not None: + per_tensor_param = self.bridge.export_weights(self.actor_module) + else: + per_tensor_param = per_tensor_generator( + self.actor_module, + self.model_config, + self.weight_converter, + self.transformer_config, + self.layer_name_mapping, + ) + await self.update_weights(per_tensor_param) + # if self.offload_param: + # offload_megatron_model_to_cpu(self.actor_module) + # get_torch_device().empty_cache() # important: need to manually set the random states of each tp to be identical. 
if self.device_mesh is not None: self.torch_random_states = get_torch_device().get_rng_state() @@ -231,10 +278,10 @@ async def wake_up(self): @GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger) async def sleep(self): - if self.rollout_config.free_cache_engine: - log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) - await self.release_memory() - log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) + log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) + await self.release_memory() + log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) + # dist.barrier() for model in self.actor_module: model.train()
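For orientation only, not part of the patch above: a minimal, standalone usage sketch of the AsyncHttpServerEngineAdapter introduced in this change. The router address, model path, tensor-parallel size, and sampling parameters are hypothetical placeholders, and the extra keyword arguments are assumed to be forwarded to SGLang's ServerArgs, as the adapter's __init__ docstring states.

import asyncio

from verl.workers.rollout.sglang_rollout.http_server_engine import AsyncHttpServerEngineAdapter


async def main() -> None:
    # Hypothetical settings: router_ip/router_port point at an sglang_router instance,
    # while model_path/tp_size are passed through **kwargs to the underlying ServerArgs.
    async with AsyncHttpServerEngineAdapter(
        router_ip="10.0.0.1",
        router_port=30500,
        model_path="/path/to/model",
        tp_size=1,
    ) as engine:
        out = await engine.generate(
            prompt="Hello",
            sampling_params={"temperature": 0.7, "max_new_tokens": 16},
        )
        print(out.get("text"))
        # The async context manager closes the adapter's HTTP session on exit.


if __name__ == "__main__":
    asyncio.run(main())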