Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions swift/infer_engine/vllm_engine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import asyncio
import inspect
import multiprocessing
import os
import time
import torch
from contextlib import contextmanager, nullcontext
from copy import copy, deepcopy
Expand All @@ -15,7 +17,8 @@
from swift.metrics import Metric
from swift.model import get_processor
from swift.template import Template
from swift.utils import get_device, get_dist_setting, get_logger, is_dist, safe_snapshot_download
from swift.utils import (disable_deepspeed_zero3, get_device, get_dist_setting, get_logger, is_dist,
safe_snapshot_download)
from .infer_engine import InferEngine
from .patch import patch_auto_tokenizer
from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
Expand Down Expand Up @@ -48,6 +51,43 @@
dtype_mapping = {torch.float16: 'float16', torch.bfloat16: 'bfloat16', torch.float32: 'float32'}


def _patch_vllm_dp_coordinator_timeout():
# https://github.com/vllm-project/vllm/pull/37452 introduced a 30-second default timeout,
# which is prone to timing out in spawn scenarios. Patch it to 180 seconds here.
try:
from vllm.v1.engine import coordinator as coordinator_module
except ImportError:
return

coordinator_cls = coordinator_module.DPCoordinator
if not hasattr(coordinator_cls, '_wait_for_zmq_addrs'):
return

if getattr(coordinator_cls, '_swift_timeout_patched', False):
return

def _wait_for_zmq_addrs(self, zmq_addr_pipe):
t0 = time.monotonic()
try:
ready = multiprocessing.connection.wait([zmq_addr_pipe, self.proc.sentinel], timeout=180)
elapsed = time.monotonic() - t0
if not ready:
raise RuntimeError(f'DP Coordinator process failed to report ZMQ addresses '
f'within 180s (elapsed={elapsed:.1f}s).')
try:
return zmq_addr_pipe.recv()
except EOFError:
raise RuntimeError('DP Coordinator process failed during startup.') from None
finally:
zmq_addr_pipe.close()

coordinator_cls._wait_for_zmq_addrs = _wait_for_zmq_addrs
coordinator_cls._swift_timeout_patched = True


_patch_vllm_dp_coordinator_timeout()


class VllmEngine(InferEngine):

def __init__(
Expand Down Expand Up @@ -180,7 +220,8 @@ def _get_processor(self):
task_type=self.task_type)

def _prepare_engine(self) -> None:
with patch_auto_tokenizer(self.tokenizer), self._patch_auto_config():
with patch_auto_tokenizer(self.tokenizer), self._patch_auto_config(), \
disable_deepspeed_zero3():
llm_engine_cls = AsyncLLMEngine if self.use_async_engine else LLMEngine
engine = llm_engine_cls.from_engine_args(self.engine_args)
self.engine = engine
Expand Down
12 changes: 11 additions & 1 deletion swift/megatron/trainers/rollout_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def _init_rollout_engine(self):
# Server mode uses external vLLM server
if self.is_main_process:
self.vllm_client.get_engine_type()
self.vllm_client.reset_mm_cache()
Comment thread
hjh0119 marked this conversation as resolved.
enable_lora = [self.vllm_client.enable_lora]
else:
enable_lora = [False]
Expand All @@ -236,6 +237,7 @@ def _init_rollout_engine(self):
with context():
set_expandable_segments(False)
self.engine = self._prepare_vllm_engine()
self.engine.engine.reset_mm_cache()
if args.sleep_level > 0:
self.engine.engine.sleep(args.sleep_level)
set_expandable_segments(True)
Expand Down Expand Up @@ -322,11 +324,19 @@ def _move_model_to_vllm(self):
else:
self._move_adapter_to_vllm()

# Reset prefix cache
self._reset_vllm_cache()

def _reset_vllm_cache(self):
# Reset prefix cache and encoder cache
vllm_ge_16 = check_vllm_version_ge('0.16')
Comment thread
hjh0119 marked this conversation as resolved.
if self.vllm_mode == 'server' and self.is_main_process:
self.vllm_client.reset_prefix_cache()
if vllm_ge_16:
self.vllm_client.reset_encoder_cache()
elif self.vllm_mode == 'colocate':
self.engine.engine.reset_prefix_cache()
if vllm_ge_16:
self.engine.engine.reset_encoder_cache()

def _move_full_model_to_vllm(self):
"""Transfer full model weights to vLLM engine.
Expand Down
61 changes: 55 additions & 6 deletions swift/pipelines/infer/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from itertools import chain
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
from transformers.utils import is_torch_npu_available
from typing import Dict, List, Optional, Union

from swift.arguments import RolloutArguments
Expand Down Expand Up @@ -105,10 +106,17 @@ def init_communicator(self, host: str, port: int, world_size: int) -> None:
Total number of participating processes in the update group.
"""
if self.communicator is not None:
raise RuntimeError('Weight update group already initialized. Call close_communicator first.')
return

# Get the rank of the current worker in the global world group.
rank = get_world_group().rank
parallel_config = getattr(getattr(self, 'vllm_config', None), 'parallel_config', None)
dp_index = int(getattr(parallel_config, 'data_parallel_index', 0)) if parallel_config is not None else 0
if dp_index > 0:
dp_rank = dp_index
tp_size = int(parallel_config.tensor_parallel_size)
else:
dp_rank = int(os.environ.get('SWIFT_ROLLOUT_DP_RANK', '0'))
tp_size = int(os.environ.get('SWIFT_ROLLOUT_TP_RANK', '1'))
rank = get_world_group().rank + dp_rank * tp_size

# Create a stateless process group to manage communication between training processes and vLLM workers.
# Initialize the NCCL-based communicator for weight synchronization.
Expand Down Expand Up @@ -302,13 +310,34 @@ def get_rollout_engine_type(args: RolloutArguments, engine: GRPOVllmEngine):
return rollout_engine


def _set_visible_devices_for_dp_rank(data_parallel_rank: int, tensor_parallel_size: int):

def _get_device_env_var():
if is_torch_npu_available():
return 'ASCEND_RT_VISIBLE_DEVICES'
return 'CUDA_VISIBLE_DEVICES'

env_var = _get_device_env_var()
current = os.environ.get(env_var)
if current:
all_devices = current.split(',')
else:
from swift.utils import get_device_count
all_devices = [str(i) for i in range(get_device_count())]

start = data_parallel_rank * tensor_parallel_size
end = start + tensor_parallel_size
selected = all_devices[start:end]
os.environ[env_var] = ','.join(selected)


def llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int, connection: Connection) -> None:
try:
args._import_external_plugins()
os.environ['VLLM_DP_RANK'] = str(data_parallel_rank)
os.environ['VLLM_DP_RANK_LOCAL'] = str(data_parallel_rank)
os.environ['VLLM_DP_SIZE'] = str(args.vllm_data_parallel_size)
_set_visible_devices_for_dp_rank(data_parallel_rank, args.vllm_tensor_parallel_size)
os.environ['VLLM_DP_MASTER_PORT'] = str(master_port)
os.environ['SWIFT_ROLLOUT_DP_RANK'] = str(data_parallel_rank)
os.environ['SWIFT_ROLLOUT_TP_RANK'] = str(args.vllm_tensor_parallel_size)
worker_seed = get_seed()
engine = SwiftRolloutDeploy.get_infer_engine(args, template=args.get_template(), seed=worker_seed)
rollout_engine = get_rollout_engine_type(args, engine)
Expand Down Expand Up @@ -345,6 +374,8 @@ async def async_llm_worker(args: RolloutArguments, data_parallel_rank: int, mast
connection: Connection) -> None:
try:
args._import_external_plugins()
os.environ['SWIFT_ROLLOUT_DP_RANK'] = str(data_parallel_rank)
os.environ['SWIFT_ROLLOUT_TP_RANK'] = str(args.vllm_tensor_parallel_size)
worker_seed = get_seed()
engine = SwiftRolloutDeploy.get_infer_engine(args, template=args.get_template(), seed=worker_seed)
rollout_engine = get_rollout_engine_type(args, engine)
Expand Down Expand Up @@ -396,6 +427,8 @@ def _register_rl_rollout_app(self):
self.app.post('/update_adapter_param/')(self.update_adapter_param)
self.app.post('/update_flattened_params/')(self.update_flattened_params)
self.app.post('/reset_prefix_cache/')(self.reset_prefix_cache)
self.app.post('/reset_encoder_cache/')(self.reset_encoder_cache)
self.app.post('/reset_mm_cache/')(self.reset_mm_cache)
self.app.post('/close_communicator/')(self.close_communicator)
self.app.post('/infer/', response_model=None)(self.infer)
self.app.post('/get_engine_type/')(self.get_engine_type)
Expand Down Expand Up @@ -632,6 +665,22 @@ async def reset_prefix_cache(self):
success = all(output for output in all_outputs)
return {'message': 'Request received, resetting prefix cache status: ' + str(success)}

async def reset_encoder_cache(self):
"""Resets the encoder cache (vision encoder embeddings) for the model."""
for connection in self.connections:
connection.send({'type': 'call', 'method': 'reset_encoder_cache'})
all_outputs = [connection.recv() for connection in self.connections]
Comment thread
hjh0119 marked this conversation as resolved.
success = all(output for output in all_outputs)
return {'message': 'Request received, resetting encoder cache status: ' + str(success)}

async def reset_mm_cache(self):
"""Resets the multimodal processor cache for the model."""
for connection in self.connections:
connection.send({'type': 'call', 'method': 'reset_mm_cache'})
all_outputs = [connection.recv() for connection in self.connections]
Comment thread
hjh0119 marked this conversation as resolved.
success = all(output for output in all_outputs)
return {'message': 'Request received, resetting mm cache status: ' + str(success)}
Comment thread
hjh0119 marked this conversation as resolved.

async def get_engine_type(self):
"""
Return a dictionary describing the runtime engine configuration.
Expand Down
12 changes: 11 additions & 1 deletion swift/rlhf_trainers/rollout_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def _prepare_vllm(self):
if self.vllm_mode == 'server':
if self.accelerator.is_main_process:
self.vllm_client.get_engine_type()
self.vllm_client.reset_mm_cache()
Comment thread
hjh0119 marked this conversation as resolved.
vllm_use_async_engine = [self.vllm_client.use_async_engine]
use_gym_env = [self.vllm_client.use_gym_env]
enable_multi_turn = [self.vllm_client.enable_multi_turn]
Expand Down Expand Up @@ -246,6 +247,7 @@ def _prepare_vllm(self):

with context():
self.engine = self._prepare_vllm_engine()
self.engine.engine.reset_mm_cache()
if args.sleep_level > 0:
self.engine.engine.sleep(args.sleep_level)
self.dynamic_num_samples = False # grpo multi-turn
Expand Down Expand Up @@ -451,11 +453,19 @@ def _move_model_to_vllm(self, skip_async_check=False):
else:
self._move_adapter_to_vllm()

# Reset prefix cache
self._reset_vllm_cache()

def _reset_vllm_cache(self):
# Reset prefix cache and encoder cache after weight update
vllm_ge_16 = check_vllm_version_ge('0.16')
Comment thread
hjh0119 marked this conversation as resolved.
if self.vllm_mode == 'server' and self.accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()
if vllm_ge_16:
self.vllm_client.reset_encoder_cache()
elif self.vllm_mode == 'colocate':
self.engine.engine.reset_prefix_cache()
if vllm_ge_16:
self.engine.engine.reset_encoder_cache()

def _move_adapter_to_vllm(self):
"""Transfer LoRA adapter weights to vLLM engine"""
Expand Down
38 changes: 38 additions & 0 deletions swift/rlhf_trainers/vllm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,44 @@ def _reset_single_server(i):
if all_errors:
raise RuntimeError(f'Multiple errors on reset_prefix_cache: {all_errors}')

def reset_encoder_cache(self):
errors = [None] * self.num_servers

def _reset_single_server(i):
try:
response = self.sessions[i].post(f'{self.base_urls[i]}/reset_encoder_cache/')
if response.status_code != 200:
raise Exception(f'Server {i} reset failed: {response.text}')
except Exception as e:
errors[i] = e

with ThreadPoolExecutor(max_workers=self.num_servers) as executor:
futures = [executor.submit(_reset_single_server, i) for i in range(self.num_servers)]
for future in futures:
future.result()
all_errors = [e for e in errors if e is not None]
if all_errors:
raise RuntimeError(f'Multiple errors on reset_encoder_cache: {all_errors}')

def reset_mm_cache(self):
errors = [None] * self.num_servers

def _reset_single_server(i):
try:
response = self.sessions[i].post(f'{self.base_urls[i]}/reset_mm_cache/')
if response.status_code != 200:
raise Exception(f'Server {i} reset failed: {response.text}')
except Exception as e:
errors[i] = e

with ThreadPoolExecutor(max_workers=self.num_servers) as executor:
futures = [executor.submit(_reset_single_server, i) for i in range(self.num_servers)]
for future in futures:
future.result()
all_errors = [e for e in errors if e is not None]
if all_errors:
raise RuntimeError(f'Multiple errors on reset_mm_cache: {all_errors}')

Comment thread
hjh0119 marked this conversation as resolved.
def get_engine_type(self):
# assume that all server has same engine type
response = self.sessions[0].post(f'{self.base_urls[0]}/get_engine_type/')
Expand Down
Loading