Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def _req(req_id: str, status: RequestStatus, external_req_id: str | None = None)
prompt_token_ids=[],
num_computed_tokens=0,
additional_information=None,
is_finished=lambda: status == RequestStatus.FINISHED_STOPPED,
)


Expand All @@ -48,9 +49,9 @@ def _build(*, stage_id: int = 1, model_mode: str = "ar", max_num_seqs: int = 2):

def _fake_base_init(self, config):
self.config = config
self._pending_load_reqs = {}
self._pending_load_reqs = deque()
self._finished_load_reqs = set()
self._pending_save_reqs = {}
self._pending_save_reqs = deque()
self._finished_save_reqs = set()
self.stop_event = threading.Event()
self.lock = threading.Lock()
Expand Down Expand Up @@ -108,9 +109,8 @@ def test_load_poll(build_adapter):
adapter.load_async(request)
payload = {"code_predictor_codes": [[1]], "hidden_states": torch.tensor([[2.0]]), "finished": True}
connector.get.return_value = (payload, 16)
adapter._poll_single_request("req-1")
adapter._poll_single_request(request)

connector.get.assert_called_once_with("1", "2", "external-1_1_0")
assert request.additional_information == payload
assert adapter.get_req_chunk["req-1"] == 1
assert "req-1" in adapter._finished_load_reqs
Expand All @@ -120,17 +120,15 @@ def test_load_poll(build_adapter):

def test_save_async(build_adapter):
adapter, _ = build_adapter(stage_id=1)
request = SimpleNamespace(external_req_id="external-1")
request = _req("req-1", RequestStatus.WAITING, external_req_id="external-1")

adapter.custom_process_next_stage_input_func = lambda **kwargs: {"x": [1], "finished": False}
adapter.save_async(pooling_output=None, request=request)
adapter.custom_process_next_stage_input_func = lambda **kwargs: {}
adapter.save_async(pooling_output=None, request=request)

assert adapter.put_req_chunk["external-1"] == 1
queued = adapter._pending_save_reqs["external-1"]
assert len(queued) == 1
assert queued[0]["put_key"] == "external-1_1_0"
task = adapter._pending_save_reqs.popleft()
assert task["is_finished"] is False


def test_update_request_payload(build_adapter):
Expand Down
87 changes: 34 additions & 53 deletions vllm_omni/distributed/omni_connectors/connectors/shm_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import fcntl
import os
import time
from multiprocessing import shared_memory as shm_pkg
from typing import Any

from vllm_omni.entrypoints.stage_utils import shm_read_bytes, shm_write_bytes
Expand Down Expand Up @@ -51,7 +51,7 @@ def put(
if True:
# Use Shared Memory
lock_file = f"/dev/shm/shm_{put_key}_lockfile.lock"
with open(lock_file, "w") as lockf:
with open(lock_file, "wb+") as lockf:
fcntl.flock(lockf, fcntl.LOCK_EX)
meta = shm_write_bytes(payload, name=put_key)
fcntl.flock(lockf, fcntl.LOCK_UN)
Expand All @@ -75,6 +75,23 @@ def put(
logger.error(f"SharedMemoryConnector put failed for req {put_key}: {e}")
return False, 0, None

def _get_data_with_lock(self, lock_file: str, shm_handle: dict):
obj = None
try:
with open(lock_file, "rb+") as lockf:
fcntl.flock(lockf, fcntl.LOCK_EX)
data_bytes = shm_read_bytes(shm_handle)
fcntl.flock(lockf, fcntl.LOCK_UN)
obj = self.deserialize_obj(data_bytes)
return obj, int(shm_handle.get("size", 0))
except Exception as e:
logger.error(f"SharedMemoryConnector shm get failed for req : {e}")
return None, 0
finally:
# If data has been received, delete lock_file.
if obj and os.path.exists(lock_file):
os.remove(lock_file)

def get(
self,
from_stage: str,
Expand All @@ -88,71 +105,35 @@ def get(
metadata = metadata.get(get_key)

if not isinstance(metadata, dict):
return None
return None, 0

if "inline_bytes" in metadata:
try:
obj = self.deserialize_obj(metadata["inline_bytes"])
return obj, int(metadata.get("size", 0))
except Exception as e:
logger.error(f"SharedMemoryConnector inline get failed for req {get_key}: {e}")
return None
return None, 0

if "shm" in metadata:
try:
shm_handle = metadata["shm"]
lock_file = f"/dev/shm/shm_{shm_handle['name']}_lockfile.lock"
with open(lock_file, "w") as lockf:
fcntl.flock(lockf, fcntl.LOCK_SH)
data_bytes = shm_read_bytes(shm_handle)
fcntl.flock(lockf, fcntl.LOCK_UN)
if os.path.exists(lock_file):
os.remove(lock_file)
obj = self.deserialize_obj(data_bytes)
return obj, int(metadata.get("size", 0))
except Exception as e:
logger.error(f"SharedMemoryConnector shm get failed for req {get_key}: {e}")
return None

return None

from multiprocessing import shared_memory as shm_pkg
shm_handle = metadata["shm"]
lock_file = f"/dev/shm/shm_{shm_handle['name']}_lockfile.lock"
return self._get_data_with_lock(lock_file, shm_handle)

# Wait for shared memory to be available (with retry logic)
max_retries = 30
retry_delay = 0.1 # 100ms between retries
return None, 0
shm = None

for attempt in range(max_retries):
try:
shm = shm_pkg.SharedMemory(name=get_key)
break # Successfully opened, exit retry loop
except FileNotFoundError:
if attempt < max_retries - 1:
time.sleep(retry_delay)
else:
# Max retries reached, return None
logger.warning(f"Shared memory '{get_key}' not found after {max_retries} retries")
return None

if shm is None:
return None

try:
shm = shm_pkg.SharedMemory(name=get_key)
if shm is None or shm.size == 0:
return None, 0
lock_file = f"/dev/shm/shm_{get_key}_lockfile.lock"
with open(lock_file) as lockf:
fcntl.flock(lockf, fcntl.LOCK_SH)
data_bytes = shm_read_bytes({"name": get_key, "size": shm.size})
fcntl.flock(lockf, fcntl.LOCK_UN)
# Clean up the temporary file if it still exists.
if os.path.exists(lock_file):
os.remove(lock_file)
obj = self.deserialize_obj(data_bytes)
return obj, shm.size
shm_handle = {"name": get_key, "size": shm.size}
return self._get_data_with_lock(lock_file, shm_handle)
except Exception:
return None, 0
finally:
shm.close()

# TODO: update another read method
if shm:
shm.close()

def cleanup(self, request_id: str) -> None:
# SHM segments are automatically unlinked during 'get' (shm_read_bytes).
Expand Down
39 changes: 16 additions & 23 deletions vllm_omni/distributed/omni_connectors/transfer_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import threading
import time
from collections import deque
from typing import Any

from ..utils.logging import get_connector_logger
Expand All @@ -22,17 +23,16 @@ def __init__(self, config: Any):
if not hasattr(self, "connector"):
self.connector = None
# Requests that are waiting to be polled
self._pending_load_reqs = {}
self._pending_load_reqs = deque()
# Requests that have successfully retrieved data
self._finished_load_reqs = set()

# Requests that are waiting to be saved
self._pending_save_reqs = {}
self._pending_save_reqs = deque()
# Requests that have successfully saved data
self._finished_save_reqs = set()

self.stop_event = threading.Event()
self.lock = threading.Lock()

self.recv_thread = threading.Thread(target=self.recv_loop, daemon=True)
self.recv_thread.start()
Expand All @@ -48,37 +48,30 @@ def recv_loop(self):
"""Loop to poll for incoming data."""
while not self.stop_event.is_set():
# Iterate over a snapshot of pending requests
with self.lock:
pending_reqs_ids = list(self._pending_load_reqs.keys())

for req_id in pending_reqs_ids:
while self._pending_load_reqs:
request = self._pending_load_reqs.popleft()
request_id = request.request_id
self.request_ids_mapping[request_id] = request.external_req_id
try:
self._poll_single_request(req_id)
is_success = self._poll_single_request(request)
if not is_success:
self._pending_load_reqs.append(request)
except Exception as e:
logger.warning(f"Error receiving data for {req_id}: {e}")
self._pending_load_reqs.append(request)
logger.warning(f"Error receiving data for {request_id}: {e}")

time.sleep(0.001)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sleep step can be triggered? When 'not ready', your code does this: it calls popleft() (which temporarily empties the queue), the poll fails, and then it immediately calls append() to put the same request back. Sleep can not be triggered forever.


def save_loop(self):
"""Loop to send outgoing data."""
while not self.stop_event.is_set():
task = None
with self.lock:
pending_save_reqs_ids = list(self._pending_save_reqs.keys())
for req_id in pending_save_reqs_ids:
if self._pending_save_reqs[req_id]:
task = self._pending_save_reqs[req_id].popleft()
if not self._pending_save_reqs[req_id]:
del self._pending_save_reqs[req_id]
break

if task:
while self._pending_save_reqs:
task = self._pending_save_reqs.popleft()
try:
self._send_single_request(task)
except Exception as e:
logger.error(f"Error saving data for {task.get('request_id')}: {e}")
else:
time.sleep(0.001)
logger.warning(f"Error saving data for {task.get('request_id')}: {e}")
time.sleep(0.001)

def _poll_single_request(self, *args, **kwargs):
"""Poll connector for a single request task.
Expand Down
Loading