
Merge pull request #748 from roboflow/preview-pipelines
consumption timeout for inference pipeline processes
PawelPeczek-Roboflow authored Oct 18, 2024
2 parents 61da195 + 0d7d2d8 commit 4090999
Showing 7 changed files with 161 additions and 36 deletions.
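
In brief, this change (a) starts the stream manager process before the model registry is set up in the docker HTTP configs, (b) adds a health-check daemon thread to the stream manager that reaps dead pipeline processes and pipelines whose sources have all ended or errored, (c) guards each pipeline's command/response queues with a lock, and (d) introduces a consumption_timeout so a pipeline whose results nobody consumes is terminated, along with a webrtc_peer_timeout for WebRTC sources. A minimal, self-contained sketch of the consumption-timeout idea (illustrative names only, not the actual inference API):

    import time

    class ConsumptionWatchdog:
        """Distilled version of the timeout bookkeeping added in this commit."""

        def __init__(self, consumption_timeout: float):
            self.consumption_timeout = consumption_timeout
            self._last_consume_time = time.monotonic()

        def mark_consumed(self) -> None:
            # refreshed whenever a client pulls results from the buffer sink
            self._last_consume_time = time.monotonic()

        def timed_out(self) -> bool:
            return time.monotonic() - self._last_consume_time > self.consumption_timeout

    watchdog = ConsumptionWatchdog(consumption_timeout=30.0)
    print(watchdog.timed_out())  # False right after creation
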
14 changes: 7 additions & 7 deletions docker/config/cpu_http.py
@@ -21,6 +21,13 @@
)
from inference.models.utils import ROBOFLOW_MODEL_TYPES


if ENABLE_STREAM_API:
stream_manager_process = Process(
target=start,
)
stream_manager_process.start()

model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)

if ACTIVE_LEARNING_ENABLED:
@@ -39,10 +46,3 @@
model_manager.init_pingback()
interface = HttpInterface(model_manager)
app = interface.app


if ENABLE_STREAM_API:
stream_manager_process = Process(
target=start,
)
stream_manager_process.start()
13 changes: 6 additions & 7 deletions docker/config/gpu_http.py
@@ -21,6 +21,12 @@
from inference.models.utils import ROBOFLOW_MODEL_TYPES


if ENABLE_STREAM_API:
stream_manager_process = Process(
target=start,
)
stream_manager_process.start()

model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)

if ACTIVE_LEARNING_ENABLED:
@@ -41,10 +47,3 @@
model_manager,
)
app = interface.app


if ENABLE_STREAM_API:
stream_manager_process = Process(
target=start,
)
stream_manager_process.start()
101 changes: 86 additions & 15 deletions inference/core/interfaces/stream_manager/manager_app/app.py
@@ -2,20 +2,27 @@
import signal
import socket
import sys
import time
import uuid
from functools import partial
from multiprocessing import Process, Queue
from socketserver import BaseRequestHandler, BaseServer
from threading import Lock, Thread
from types import FrameType
from typing import Any, Dict, Optional, Tuple
from uuid import uuid4

from inference.core import logger
from inference.core.interfaces.camera.video_source import StreamState
from inference.core.interfaces.stream_manager.manager_app.communication import (
receive_socket_data,
send_data_trough_socket,
)
from inference.core.interfaces.stream_manager.manager_app.entities import (
PIPELINE_ID_KEY,
REPORT_KEY,
SOURCES_METADATA_KEY,
STATE_KEY,
STATUS_KEY,
TYPE_KEY,
CommandType,
@@ -37,7 +44,7 @@
RoboflowTCPServer,
)

PROCESSES_TABLE: Dict[str, Tuple[Process, Queue, Queue]] = {}
PROCESSES_TABLE: Dict[str, Tuple[Process, Queue, Queue, Lock]] = {}
HEADER_SIZE = 4
SOCKET_BUFFER_SIZE = 16384
HOST = os.getenv("STREAM_MANAGER_HOST", "127.0.0.1")
@@ -51,9 +58,9 @@ def __init__(
request: socket.socket,
client_address: Any,
server: BaseServer,
processes_table: Dict[str, Tuple[Process, Queue, Queue]],
processes_table: Dict[str, Tuple[Process, Queue, Queue, Lock]],
):
self._processes_table = processes_table # in this case it's required to set the state of class before superclass init - as it invokes handle()
self._processes_table = processes_table  # in this case it's required to set the state of class before superclass init - as it invokes handle()
super().__init__(request, client_address, server)

def handle(self) -> None:
@@ -159,6 +166,7 @@ def _initialise_pipeline(self, request_id: str, command: dict) -> None:
inference_pipeline_manager,
command_queue,
responses_queue,
Lock(),
)
command_queue.put((request_id, command))
response = get_response_ignoring_thrash(
@@ -189,6 +197,7 @@ def _start_webrtc(self, request_id: str, command: dict):
inference_pipeline_manager,
command_queue,
responses_queue,
Lock(),
)
command_queue.put((request_id, command))
response = get_response_ignoring_thrash(
@@ -237,7 +246,7 @@ def _terminate_pipeline(


def handle_command(
processes_table: Dict[str, Tuple[Process, Queue, Queue]],
processes_table: Dict[str, Tuple[Process, Queue, Queue, Lock]],
request_id: str,
pipeline_id: str,
command: dict,
@@ -246,13 +255,14 @@ def handle_command(
return describe_error(
exception=None,
error_type=ErrorType.NOT_FOUND,
public_error_message=f"Could not found InferencePipeline with id={pipeline_id}.",
public_error_message=f"Could not find InferencePipeline with id={pipeline_id}.",
)
_, command_queue, responses_queue, command_lock = processes_table[pipeline_id]
with command_lock:
command_queue.put((request_id, command))
return get_response_ignoring_thrash(
responses_queue=responses_queue, matching_request_id=request_id
)
_, command_queue, responses_queue = processes_table[pipeline_id]
command_queue.put((request_id, command))
return get_response_ignoring_thrash(
responses_queue=responses_queue, matching_request_id=request_id
)


def get_response_ignoring_thrash(
Expand All @@ -270,7 +280,7 @@ def get_response_ignoring_thrash(
def execute_termination(
signal_number: int,
frame: FrameType,
processes_table: Dict[str, Tuple[Process, Queue, Queue]],
processes_table: Dict[str, Tuple[Process, Queue, Queue, Lock]],
) -> None:
pipeline_ids = list(processes_table.keys())
for pipeline_id in pipeline_ids:
@@ -285,22 +295,83 @@ def execute_termination(


def join_inference_pipeline(
processes_table: Dict[str, Tuple[Process, Queue, Queue]], pipeline_id: str
processes_table: Dict[str, Tuple[Process, Queue, Queue, Lock]], pipeline_id: str
) -> None:
inference_pipeline_manager, command_queue, responses_queue = processes_table[
pipeline_id
]
inference_pipeline_manager, *_ = processes_table[pipeline_id]
inference_pipeline_manager.join()
del processes_table[pipeline_id]


def check_process_health() -> None:
while True:
for pipeline_id, (process, *_) in list(PROCESSES_TABLE.items()):
if not process.is_alive():
logger.warning(
"Process for pipeline_id=%s is not alive. Terminating...",
pipeline_id,
)
process.terminate()
process.join()
del PROCESSES_TABLE[pipeline_id]
continue
command = {
TYPE_KEY: CommandType.STATUS,
PIPELINE_ID_KEY: pipeline_id,
}
response = handle_command(
processes_table=PROCESSES_TABLE,
request_id=uuid.uuid4().hex,
pipeline_id=pipeline_id,
command=command,
)
if (
REPORT_KEY not in response
or SOURCES_METADATA_KEY not in response[REPORT_KEY]
):
continue
all_sources_statues = set(
source_metadata[STATE_KEY]
for source_metadata in response[REPORT_KEY][SOURCES_METADATA_KEY]
if STATE_KEY in source_metadata
)
if not all_sources_statues:
continue
if all_sources_statues.issubset({StreamState.ENDED, StreamState.ERROR}):
logger.info(
"All sources depleted in pipeline %s, terminating", pipeline_id
)
command = {
TYPE_KEY: CommandType.TERMINATE,
PIPELINE_ID_KEY: pipeline_id,
}
response = handle_command(
processes_table=PROCESSES_TABLE,
request_id=uuid.uuid4().hex,
pipeline_id=pipeline_id,
command=command,
)
if not response.get(STATUS_KEY) == "success":
logger.error(
"Malformed response returned by termination command, '%s'",
response,
)
continue
process.join()
del PROCESSES_TABLE[pipeline_id]
time.sleep(1)


def start() -> None:
signal.signal(
signal.SIGINT, partial(execute_termination, processes_table=PROCESSES_TABLE)
)
signal.signal(
signal.SIGTERM, partial(execute_termination, processes_table=PROCESSES_TABLE)
)

# check process health in daemon thread
Thread(target=check_process_health, daemon=True).start()

with RoboflowTCPServer(
server_address=(HOST, PORT),
handler_class=partial(
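A note on the added Lock: each pipeline's command/response queues are now shared between the TCP request handlers and the new check_process_health daemon thread, so the lock keeps a request and its matching response paired instead of letting two callers swap replies. A minimal sketch of that pattern (hypothetical names, single-process queues for illustration, not the actual manager code):

    import uuid
    from queue import Queue  # the manager uses multiprocessing.Queue; queue.Queue keeps this sketch single-process
    from threading import Lock, Thread

    command_queue: Queue = Queue()
    responses_queue: Queue = Queue()
    command_lock = Lock()

    def worker() -> None:
        # echoes each command back, tagged with the same request id
        while True:
            request_id, command = command_queue.get()
            responses_queue.put((request_id, {"status": "success", "echo": command}))

    def send_command(command: dict) -> dict:
        request_id = uuid.uuid4().hex
        with command_lock:  # without the lock, two concurrent callers could receive each other's responses
            command_queue.put((request_id, command))
            returned_id, response = responses_queue.get()
        assert returned_id == request_id
        return response

    Thread(target=worker, daemon=True).start()
    print(send_command({"type": "status"}))
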
inference/core/interfaces/stream_manager/manager_app/entities.py
@@ -9,6 +9,9 @@
)

STATUS_KEY = "status"
STATE_KEY = "state"
SOURCES_METADATA_KEY = "sources_metadata"
REPORT_KEY = "report"
TYPE_KEY = "type"
ERROR_TYPE_KEY = "error_type"
REQUEST_ID_KEY = "request_id"
@@ -79,6 +82,7 @@ class InitialisePipelinePayload(BaseModel):
sink_configuration: MemorySinkConfiguration = MemorySinkConfiguration(
type="MemorySinkConfiguration"
)
consumption_timeout: Optional[float] = None
api_key: Optional[str] = None


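As a usage note, the new consumption_timeout is part of the pipeline initialisation payload and is measured in seconds (it is compared against a time.monotonic() gap). A hedged sketch of such a payload fragment — only the two fields visible in this diff are shown; the remaining InitialisePipelinePayload fields are elided rather than guessed:

    # Illustrative initialisation payload fragment; field names beyond those shown
    # in this diff (video/processing/sink configuration) are intentionally omitted.
    initialisation_payload = {
        "consumption_timeout": 30.0,  # seconds without a result consumption before the pipeline is terminated
        "api_key": "<ROBOFLOW_API_KEY>",
    }
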
inference/core/interfaces/stream_manager/manager_app/inference_pipeline_manager.py
@@ -2,10 +2,12 @@
import os
import signal
import threading
import time
from dataclasses import asdict
from functools import partial
from multiprocessing import Process, Queue
from threading import Event
from queue import Empty
from threading import Event, Lock
from types import FrameType
from typing import Dict, Optional, Tuple

@@ -76,17 +78,48 @@ def __init__(self, pipeline_id: str, command_queue: Queue, responses_queue: Queu
self._watchdog: Optional[PipelineWatchDog] = None
self._stop = False
self._buffer_sink: Optional[InMemoryBufferSink] = None
self._last_consume_time = (
time.monotonic()
) # Track last consume time for the pipeline
self._consumption_timeout: Optional[float] = (
None # Track zero consume timeout for the pipeline
)

def run(self) -> None:
signal.signal(signal.SIGINT, ignore_signal)
signal.signal(signal.SIGTERM, self._handle_termination_signal)

while not self._stop:
command: Optional[Tuple[str, dict]] = self._command_queue.get()
self._check_pipeline_timeout()
# Handle commands from the queue
try:
command: Optional[Tuple[str, dict]] = self._command_queue.get(timeout=1)
except Empty:
continue
if command is None:
break
request_id, payload = command
self._handle_command(request_id=request_id, payload=payload)

def _check_pipeline_timeout(self) -> None:
if self._inference_pipeline and self._consumption_timeout is not None:
time_since_last_consume = time.monotonic() - self._last_consume_time
if time_since_last_consume > self._consumption_timeout:
logger.info("Terminating pipeline due to zero consume timeout...")
try:
pid = os.getpid()
logger.info(
f"Terminating pipeline due to timeout (no consumption):{pid}..."
)
if self._inference_pipeline is not None:
self._execute_termination()
self._command_queue.put(None)
logger.info(f"Timeout Termination successful in process:{pid}...")
except Exception as error:
logger.warning(
f"Could not terminate pipeline gracefully. Error: {error}"
)

def _handle_command(self, request_id: str, payload: dict) -> None:
try:
logger.info(f"Processing request={request_id}...")
@@ -159,6 +192,8 @@ def _initialise_pipeline(self, request_id: str, payload: dict) -> None:
batch_collection_timeout=parsed_payload.video_configuration.batch_collection_timeout,
)
self._watchdog = watchdog
self._consumption_timeout = parsed_payload.consumption_timeout
self._last_consume_time = time.monotonic()
self._inference_pipeline.start(use_main_thread=False)
self._responses_queue.put(
(request_id, {STATUS_KEY: OperationStatus.SUCCESS})
@@ -232,6 +267,7 @@ def start_loop(loop: asyncio.AbstractEventLoop):
to_inference_queue=to_inference_queue,
stop_event=stop_event,
webrtc_video_transform_track=peer_connection.video_transform_track,
webrtc_peer_timeout=parsed_payload.webrtc_peer_timeout,
)

def webrtc_sink(
@@ -430,6 +466,7 @@ def _consume_results(self, request_id: str, payload: dict) -> None:
return None
excluded_fields = payload.get("excluded_fields")
predictions, frames = self._buffer_sink.consume_prediction()
self._last_consume_time = time.monotonic()
predictions = [
(
serialise_single_workflow_result_element(
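The practical consequence of the manager changes above is that a client must keep consuming results at least as often as consumption_timeout allows, since each _consume_results call refreshes _last_consume_time. A minimal polling loop illustrating that contract (the consume callable is a stand-in, not the real client API):

    import time
    from typing import Callable, Optional

    def poll_results(consume: Callable[[], Optional[dict]], consumption_timeout: float) -> None:
        # Poll noticeably more often than the timeout so the pipeline is never
        # considered abandoned while the client is still interested in results.
        interval = min(1.0, consumption_timeout / 3)
        while True:
            result = consume()
            if result is None:  # stand-in for "pipeline terminated / nothing left"
                break
            time.sleep(interval)

    poll_results(consume=lambda: None, consumption_timeout=30.0)  # trivial demo run
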
18 changes: 15 additions & 3 deletions inference/core/interfaces/stream_manager/manager_app/webrtc.py
@@ -133,30 +133,42 @@ def __init__(
to_inference_queue: "SyncAsyncQueue[VideoFrame]",
stop_event: Event,
webrtc_video_transform_track: VideoTransformTrack,
webrtc_peer_timeout: float,
):
self.to_inference_queue: "SyncAsyncQueue[VideoFrame]" = to_inference_queue
self._stop_event = stop_event
self._w: Optional[int] = None
self._h: Optional[int] = None
self._video_transform_track = webrtc_video_transform_track
self._is_opened = True
self.webrtc_peer_timeout = webrtc_peer_timeout

def grab(self) -> bool:
if self._stop_event.is_set():
logger.info("Received termination signal, closing.")
self._is_opened = False
return False

self.to_inference_queue.sync_get()
try:
self.to_inference_queue.sync_get(timeout=self.webrtc_peer_timeout)
except asyncio.TimeoutError:
logger.error("Timeout while grabbing frame, considering source depleted.")
return False
return True

def retrieve(self) -> Tuple[bool, np.ndarray]:
def retrieve(self) -> Tuple[bool, Optional[np.ndarray]]:
if self._stop_event.is_set():
logger.info("Received termination signal, closing.")
self._is_opened = False
return False, None

frame: VideoFrame = self.to_inference_queue.sync_get()
try:
frame: VideoFrame = self.to_inference_queue.sync_get(
timeout=self.webrtc_peer_timeout
)
except asyncio.TimeoutError:
logger.error("Timeout while retrieving frame, considering source depleted.")
return False, None
img = frame.to_ndarray(format="bgr24")

return True, img
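For the WebRTC source, the effect of webrtc_peer_timeout is that grab() and retrieve() now give up after the configured wait instead of blocking forever: on asyncio.TimeoutError they return False / (False, None), so the pipeline treats the peer as a depleted source, which the new health check in app.py can then clean up. A small sketch of the same give-up-after-timeout pattern using a plain queue (illustrative only; the real code waits on a SyncAsyncQueue and catches asyncio.TimeoutError):

    from queue import Empty, Queue
    from typing import Optional, Tuple

    frames: "Queue[bytes]" = Queue()

    def retrieve(timeout: float) -> Tuple[bool, Optional[bytes]]:
        # Wait for the next frame, but only up to `timeout` seconds; a timeout is
        # reported as "source depleted" rather than blocking the pipeline forever.
        try:
            return True, frames.get(timeout=timeout)
        except Empty:
            return False, None

    print(retrieve(timeout=0.1))  # (False, None) because no peer ever sent a frame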
