Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup internal data-structures when process has been forked #2676

Merged
merged 7 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions crates/re_sdk/src/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,41 @@ thread_local! {
static LOCAL_BLUEPRINT_RECORDING: RefCell<Option<RecordingStream>> = RefCell::new(None);
}

/// Check whether we are the child of a fork.
///
/// If so, then our globals need to be cleaned up because they don't have associated batching
/// or sink threads. The parent of the fork will continue to process any data in the original
/// globals so nothing is being lost by doing this.
pub fn cleanup_if_forked_child() {
if let Some(global_recording) = RecordingStream::global(StoreKind::Recording) {
if global_recording.is_forked_child() {
re_log::debug!("Fork detected. Forgetting global Recording");
RecordingStream::forget_global(StoreKind::Recording);
}
}

if let Some(global_blueprint) = RecordingStream::global(StoreKind::Blueprint) {
if global_blueprint.is_forked_child() {
re_log::debug!("Fork detected. Forgetting global Blueprint");
RecordingStream::forget_global(StoreKind::Recording);
}
}

if let Some(thread_recording) = RecordingStream::thread_local(StoreKind::Recording) {
if thread_recording.is_forked_child() {
re_log::debug!("Fork detected. Forgetting thread-local Recording");
RecordingStream::forget_thread_local(StoreKind::Recording);
}
}

if let Some(thread_blueprint) = RecordingStream::thread_local(StoreKind::Blueprint) {
if thread_blueprint.is_forked_child() {
re_log::debug!("Fork detected. Forgetting thread-local Blueprint");
RecordingStream::forget_thread_local(StoreKind::Blueprint);
}
}
}

impl RecordingStream {
/// Returns `overrides` if it exists, otherwise returns the most appropriate active recording
/// of the specified type (i.e. thread-local first, then global scope), if any.
Expand Down Expand Up @@ -106,6 +141,15 @@ impl RecordingStream {
Self::set_any(RecordingScope::Global, kind, rec)
}

/// Forgets the currently active recording of the specified type in the global scope.
///
/// WARNING: this intentionally bypasses any drop/flush logic. This should only ever be used in
/// cases where you know the batcher/sink threads have been lost such as in a forked process.
#[inline]
pub fn forget_global(kind: StoreKind) {
Self::forget_any(RecordingScope::Global, kind);
}

// --- Thread local ---

/// Returns the currently active recording of the specified type in the thread-local scope,
Expand All @@ -125,6 +169,15 @@ impl RecordingStream {
Self::set_any(RecordingScope::ThreadLocal, kind, rec)
}

/// Forgets the currently active recording of the specified type in the thread-local scope.
///
/// WARNING: this intentionally bypasses any drop/flush logic. This should only ever be used in
/// cases where you know the batcher/sink threads have been lost such as in a forked process.
#[inline]
pub fn forget_thread_local(kind: StoreKind) {
Self::forget_any(RecordingScope::ThreadLocal, kind);
}

// --- Internal helpers ---

fn get_any(scope: RecordingScope, kind: StoreKind) -> Option<RecordingStream> {
Expand Down Expand Up @@ -180,6 +233,35 @@ impl RecordingStream {
},
}
}

fn forget_any(scope: RecordingScope, kind: StoreKind) {
match kind {
StoreKind::Recording => match scope {
RecordingScope::Global => {
if let Some(global) = GLOBAL_DATA_RECORDING.get() {
std::mem::forget(global.write().take());
}
}
RecordingScope::ThreadLocal => LOCAL_DATA_RECORDING.with(|cell| {
if let Some(cell) = cell.take() {
std::mem::forget(cell);
}
}),
},
StoreKind::Blueprint => match scope {
RecordingScope::Global => {
if let Some(global) = GLOBAL_BLUEPRINT_RECORDING.get() {
std::mem::forget(global.write().take());
}
}
RecordingScope::ThreadLocal => LOCAL_BLUEPRINT_RECORDING.with(|cell| {
if let Some(cell) = cell.take() {
std::mem::forget(cell);
}
}),
},
}
}
}

// ---
Expand Down
2 changes: 2 additions & 0 deletions crates/re_sdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ pub use re_log_types::{
ApplicationId, Component, ComponentName, EntityPath, SerializableComponent, StoreId, StoreKind,
};

pub use global::cleanup_if_forked_child;

#[cfg(not(target_arch = "wasm32"))]
impl crate::sink::LogSink for re_log_encoding::FileSink {
fn send(&self, msg: re_log_types::LogMsg) {
Expand Down
30 changes: 30 additions & 0 deletions crates/re_sdk/src/recording_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,10 +349,17 @@ struct RecordingStreamInner {

batcher: DataTableBatcher,
batcher_to_sink_handle: Option<std::thread::JoinHandle<()>>,

pid_at_creation: u32,
}

impl Drop for RecordingStreamInner {
fn drop(&mut self) {
if self.is_forked_child() {
re_log::error_once!("Fork detected while dropping RecordingStreamInner. cleanup_if_forked() should always be called after forking. This is likely a bug in the SDK.");
return;
}

// NOTE: The command channel is private, if we're here, nothing is currently capable of
// sending data down the pipeline.
self.batcher.flush_blocking();
Expand Down Expand Up @@ -410,8 +417,14 @@ impl RecordingStreamInner {
cmds_tx,
batcher,
batcher_to_sink_handle: Some(batcher_to_sink_handle),
pid_at_creation: std::process::id(),
})
}

#[inline]
pub fn is_forked_child(&self) -> bool {
self.pid_at_creation != std::process::id()
}
}

enum Command {
Expand Down Expand Up @@ -591,6 +604,18 @@ impl RecordingStream {
pub fn store_info(&self) -> Option<&StoreInfo> {
(*self.inner).as_ref().map(|inner| &inner.info)
}

/// Determine whether a fork has happened since creating this `RecordingStream`. In general, this means our
/// batcher/sink threads are gone and all data logged since the fork has been dropped.
///
/// It is essential that [`crate::cleanup_if_forked_child`] be called after forking the process. SDK-implementations
/// should do this during their initialization phase.
#[inline]
pub fn is_forked_child(&self) -> bool {
(*self.inner)
.as_ref()
.map_or(false, |inner| inner.is_forked_child())
}
}

impl RecordingStream {
Expand Down Expand Up @@ -737,6 +762,11 @@ impl RecordingStream {
///
/// See [`RecordingStream`] docs for ordering semantics and multithreading guarantees.
pub fn flush_blocking(&self) {
if self.is_forked_child() {
re_log::error_once!("Fork detected during flush. cleanup_if_forked() should always be called after forking. This is likely a bug in the SDK.");
return;
}

let Some(this) = &*self.inner else {
re_log::warn_once!("Recording disabled - call to flush_blocking() ignored");
return;
Expand Down
19 changes: 11 additions & 8 deletions examples/python/multiprocessing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,19 @@
import rerun as rr # pip install rerun-sdk


# Python does not guarantee that the normal atexit-handlers will be called at the
# termination of a multiprocessing.Process. Explicitly add the `shutdown_at_exit`
# decorator to ensure data is flushed when the task completes.
@rr.shutdown_at_exit
def task(child_index: int) -> None:
# All processes spawned with `multiprocessing` will automatically
# be assigned the same default recording_id.
# We just need to connect each process to the the rerun viewer:
# In the new process, we always need to call init with the same `application_id`.
# By default, the `recording_id`` will match the `recording_id`` of the parent process,
# so all of these processes will have their log data merged in the viewer.
# Caution: if you manually specified `recording_id` in the parent, you also must
# pass the same `recording_id` here.
rr.init("multiprocessing")

# We then have to connect to the viewer instance.
rr.connect()

title = f"task {child_index}"
Expand All @@ -37,11 +45,6 @@ def main() -> None:

task(0)

# Using multiprocessing with "fork" results in a hang on shutdown so
# always use "spawn"
# TODO(https://github.com/rerun-io/rerun/issues/1921)
multiprocessing.set_start_method("spawn")

for i in [1, 2, 3]:
p = multiprocessing.Process(target=task, args=(i,))
p.start()
Expand Down
48 changes: 48 additions & 0 deletions rerun_py/rerun_sdk/rerun/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""The Rerun Python SDK, which is a wrapper around the re_sdk crate."""
from __future__ import annotations

import functools
from typing import Any, Callable, TypeVar, cast

# NOTE: The imports determine what is public API. Avoid importing globally anything that is not public API. Use
# (private) function and local import if needed.
import rerun_bindings as bindings # type: ignore[attr-defined]
Expand Down Expand Up @@ -155,6 +158,10 @@ def init(
global _strict_mode
_strict_mode = strict

# Always check whether we are a forked child when calling init. This should have happened
# via `_register_on_fork` but it's worth being conservative.
cleanup_if_forked_child()

if init_logging:
new_recording(
application_id,
Expand Down Expand Up @@ -311,6 +318,47 @@ def unregister_shutdown() -> None:
atexit.unregister(rerun_shutdown)


def cleanup_if_forked_child() -> None:
bindings.cleanup_if_forked_child()


def _register_on_fork() -> None:
# Only relevant on Linux
try:
import os

os.register_at_fork(after_in_child=cleanup_if_forked_child)
except NotImplementedError:
pass


_register_on_fork()


_TFunc = TypeVar("_TFunc", bound=Callable[..., Any])


def shutdown_at_exit(func: _TFunc) -> _TFunc:
"""
Decorator to shutdown Rerun cleanly when this function exits.

Normally, Rerun installs an atexit-handler that attempts to shutdown cleanly and
flush all outgoing data before terminating. However, some cases, such as forked
processes will always skip this at-exit handler. In these cases, you can use this
decorator on the entry-point to your subprocess to ensure cleanup happens as
expected without losing data.
"""

@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
try:
return func(*args, **kwargs)
finally:
rerun_shutdown()

return cast(_TFunc, wrapper)


# ---


Expand Down
7 changes: 7 additions & 0 deletions rerun_py/src/python_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ fn rerun_bindings(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(new_recording, m)?)?;
m.add_function(wrap_pyfunction!(new_blueprint, m)?)?;
m.add_function(wrap_pyfunction!(shutdown, m)?)?;
m.add_function(wrap_pyfunction!(cleanup_if_forked_child, m)?)?;

// recordings
m.add_function(wrap_pyfunction!(get_application_id, m)?)?;
Expand Down Expand Up @@ -349,6 +350,12 @@ fn get_global_data_recording() -> Option<PyRecordingStream> {
RecordingStream::global(rerun::StoreKind::Recording).map(PyRecordingStream)
}

/// Cleans up internal state if this is the child of a forked process.
#[pyfunction]
fn cleanup_if_forked_child() {
rerun::cleanup_if_forked_child();
}

/// Replaces the currently active recording in the global scope with the specified one.
///
/// Returns the previous one, if any.
Expand Down