Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,15 @@
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger

logger = init_logger(__name__)
try:
from vllm.logger import init_logger

logger = init_logger(__name__)
except ImportError:
import logging

logger = logging.getLogger(__name__)

# Add uvloop for faster event loop if available
try:
Expand All @@ -149,6 +155,9 @@ class InstanceType:
DECODE: str = "decode"


TAINT_PRIORITY = 1e15


class ServerState:
def __init__(self, host, port):
self.host = host
Expand Down Expand Up @@ -186,6 +195,9 @@ def __repr__(self):

class ProxyState:
def __init__(self, prefiller_instances, decoder_instances):
self.request_num = 0
self.tainted_prefillers: list[ServerState] = []
self.tainted_decoders: list[ServerState] = []
self.node_listener = NodeListener(self)

self.prefillers: list[ServerState] = [ServerState(h, p) for h, p in prefiller_instances]
Expand Down Expand Up @@ -225,6 +237,8 @@ def abort_prefiller_request(self, server_idx: int, request_id): # Changed to sy
prefiller node.
"""
# No lock needed - atomic operation
if server_idx >= len(self.prefillers):
return
self.prefillers[server_idx].aborted_requests.add(request_id)

def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
Expand All @@ -233,6 +247,8 @@ def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to sync
This is used to release kv cache in prefiller node.
"""
# No lock needed - atomic operation
if server_idx >= len(self.prefillers):
return set()
aborted_requests = self.prefillers[server_idx].aborted_requests.copy()
self.prefillers[server_idx].aborted_requests.clear()
return aborted_requests
Expand All @@ -259,12 +275,16 @@ def select_prefiller(self, token_count): # Changed to synchronous

def release_prefiller(self, idx, token_count): # Changed to synchronous
# No lock needed - atomic operation
if idx >= len(self.prefillers):
return
self.prefillers[idx].active_tokens -= token_count
# Update priority queue after releasing
self._update_prefiller_priority(idx)

def release_prefiller_kv(self, idx, token_count): # Changed to synchronous
# No lock needed - atomic operation
if idx >= len(self.prefillers):
return
if self.prefillers[idx].active_kv_cache > 0:
self.prefillers[idx].active_kv_cache -= token_count
# Update priority queue after releasing
Expand All @@ -287,6 +307,8 @@ def select_decoder(self, token_count): # Changed to synchronous

def release_decoder(self, idx, token_count): # Changed to synchronous
# No lock needed - atomic operation
if idx >= len(self.decoders):
return
self.decoders[idx].active_tokens -= token_count
# Update priority queue after releasing
self._update_decoder_priority(idx)
Expand Down Expand Up @@ -317,24 +339,44 @@ async def add_instances(self, instance_type: str, instances: list[ServerState])
return added_nodes, waiting_nodes

def add_prefillers(self, instances: list[ServerState]) -> None:
num_prefillers = len(self.prefillers)
for idx, server in enumerate(instances):
if server not in self.prefillers:
for server in instances:
if server in self.tainted_prefillers:
self.tainted_prefillers.remove(server)
self.prefiller_heap = [
(0, idx, server) if srv == server else (priority, idx, srv)
for priority, idx, srv in self.prefiller_heap
]
heapq.heapify(self.prefiller_heap)
elif server not in self.prefillers:
self.prefillers.append(server)
# prefiller_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
heapq.heappush(self.prefiller_heap, (0, num_prefillers + idx, server))
heapq.heappush(self.prefiller_heap, (0, len(self.prefillers) - 1, server))
self.print_status(f"Add prefiller instances: {instances}.")

def add_decoders(self, instances: list[ServerState]) -> None:
num_decoders = len(self.decoders)
for idx, server in enumerate(instances):
if server not in self.decoders:
for server in instances:
if server in self.tainted_decoders:
self.tainted_decoders.remove(server)
self.decoder_heap = [
(0, idx, server) if srv == server else (priority, idx, srv)
for priority, idx, srv in self.decoder_heap
]
heapq.heapify(self.decoder_heap)
elif server not in self.decoders:
self.decoders.append(server)
# decoder_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
heapq.heappush(self.decoder_heap, (0, num_decoders + idx, server))
heapq.heappush(self.decoder_heap, (0, len(self.decoders) - 1, server))
self.print_status(f"Add decoder instances: {instances}.")

def remove_prefillers(self, instances: list[ServerState]) -> None:
def remove_prefillers(self, instances: list[ServerState]) -> bool:
if not instances:
return False

if self.request_num > 0:
logger.warning(f"Start to taint prefill instances {instances}.")
self._taint_prefillers(instances)
return True

instances_to_remove = set(instances)
self.prefillers = [server for server in self.prefillers if server not in instances_to_remove]
prefiller_heap_copy = self.prefiller_heap.copy()
Expand All @@ -350,8 +392,17 @@ def remove_prefillers(self, instances: list[ServerState]) -> None:
self.prefiller_heap = prefiller_heap
heapq.heapify(self.prefiller_heap)
self.print_status(f"Remove prefiller instances: {instances}.")
return False

def remove_decoders(self, instances: list[ServerState]) -> bool:
if not instances:
return False

if self.request_num > 0:
logger.warning(f"Start to taint decode instances {instances}.")
self._taint_decoders(instances)
return True

def remove_decoders(self, instances: list[ServerState]) -> None:
instances_to_remove = set(instances)
self.decoders = [server for server in self.decoders if server not in instances_to_remove]
decoder_heap_copy = self.decoder_heap.copy()
Expand All @@ -367,6 +418,31 @@ def remove_decoders(self, instances: list[ServerState]) -> None:
self.decoder_heap = decoder_heap
heapq.heapify(self.decoder_heap)
self.print_status(f"Remove decoder instances: {instances}.")
return False

def _taint_prefillers(self, instances: list[ServerState]) -> None:
instances_to_taint = set(instances)
for server in self.prefillers:
if server in instances_to_taint and server not in self.tainted_prefillers:
self.tainted_prefillers.append(server)
Comment thread
yuxinshan marked this conversation as resolved.

self.prefiller_heap = [
(TAINT_PRIORITY, idx, srv) if srv in instances_to_taint else (priority, idx, srv)
for priority, idx, srv in self.prefiller_heap
]
heapq.heapify(self.prefiller_heap)

def _taint_decoders(self, instances: list[ServerState]) -> None:
instances_to_taint = set(instances)
for server in self.decoders:
if server in instances_to_taint and server not in self.tainted_decoders:
self.tainted_decoders.append(server)
Comment thread
yuxinshan marked this conversation as resolved.

self.decoder_heap = [
(TAINT_PRIORITY, idx, srv) if srv in instances_to_taint else (priority, idx, srv)
for priority, idx, srv in self.decoder_heap
]
heapq.heapify(self.decoder_heap)

def print_status(self, msg: str) -> None:
status = {
Expand Down Expand Up @@ -403,6 +479,16 @@ def _node_listener(self) -> None:
self.waiting_nodes.pop(node)
else:
self.waiting_nodes[node] = (instance_type, server, check_times)

if self.proxy_state.tainted_prefillers and not self.proxy_state.request_num:
need_waiting = self.proxy_state.remove_prefillers(self.proxy_state.tainted_prefillers)
if not need_waiting:
self.proxy_state.tainted_prefillers.clear()

if self.proxy_state.tainted_decoders and not self.proxy_state.request_num:
need_waiting = self.proxy_state.remove_decoders(self.proxy_state.tainted_decoders)
if not need_waiting:
self.proxy_state.tainted_decoders.clear()
time.sleep(global_args.waiting_retry_interval)

@staticmethod
Expand Down Expand Up @@ -623,6 +709,7 @@ class InstanceInfo:

async def _handle_completions(api: str, request: Request):
try:
proxy_state.request_num += 1
req_data = await request.json()
req_body = await request.body()
request_length = len(req_body)
Expand Down Expand Up @@ -736,6 +823,8 @@ async def generate_stream():
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
finally:
proxy_state.request_num -= 1


async def _handle_adjust_instances(adjust_mode: str, request: Request):
Expand Down Expand Up @@ -763,9 +852,12 @@ async def _handle_adjust_instances(adjust_mode: str, request: Request):
)
elif adjust_mode == "remove":
if instance_type == InstanceType.PREFILL:
proxy_state.remove_prefillers(instances)
need_waiting = proxy_state.remove_prefillers(instances)
else:
proxy_state.remove_decoders(instances)
need_waiting = proxy_state.remove_decoders(instances)

if need_waiting:
all_msg = f"Instances {instances} are isolated and waiting to be removed."
return {
"message": all_msg,
"current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],
Expand Down