Cleanup internal data-structures when process has been forked (#2676)
Closes #1921

### What

The crux of the problem is the following:
> The child process is created with a single thread—the one that called
fork(). The entire virtual address space of the parent is replicated in
the child, ...

The major consequence of this is that our global `RecordingStream`
context is duplicated into the child's memory space, but none of the
threads (batcher, tcp-sender, dropper, etc.) are. When we
call `connect()` inside the forked process, we try to replace the
global recording stream, which subsequently tries to call drop on the
forked copy of `RecordingStreamInner`. However, without any existing
threads to process the flush, things just hang inside that flush call.
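
To make the failure mode concrete, here is a minimal, Unix-only sketch (not part of this PR) that starts a background thread, forks, and reports how many Python threads each side sees. The helper name `count_threads_across_fork` is made up for illustration.

```python
import os
import threading
import warnings
from typing import Tuple


def count_threads_across_fork() -> Tuple[int, int]:
    """Return (threads seen by the parent, threads seen by the forked child)."""
    stop = threading.Event()
    # Stand-in for a batcher/sink thread: it just waits until told to stop.
    worker = threading.Thread(target=stop.wait, daemon=True)
    worker.start()
    parent_count = threading.active_count()

    read_fd, write_fd = os.pipe()
    with warnings.catch_warnings():
        # Python 3.12+ emits a DeprecationWarning when forking a multi-threaded process.
        warnings.simplefilter("ignore", DeprecationWarning)
        pid = os.fork()
    if pid == 0:
        # Child: only the thread that called fork() exists here.
        try:
            os.close(read_fd)
            os.write(write_fd, str(threading.active_count()).encode())
        finally:
            os._exit(0)

    # Parent: read the child's answer, then clean up.
    os.close(write_fd)
    child_count = int(os.read(read_fd, 16).decode())
    os.close(read_fd)
    os.waitpid(pid, 0)
    stop.set()
    worker.join()
    return parent_count, child_count
```

The parent should see its worker thread, while the child sees only the thread that called `fork()`, so nothing is left in the child to service a blocking flush.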

We take a few actions to alleviate this problem:
1. Introduce a new SDK function, `cleanup_if_forked_child`, which compares the
process IDs recorded on existing globals and forgets them as necessary.
1. In Python, use `os.register_at_fork` to proactively call
`cleanup_if_forked_child` in any forked child processes.
1. Also add a call to `cleanup_if_forked_child` inside of `init()` in case we're
forking through a more exotic mechanism.
1. Check for the forked state anywhere we potentially flush, to avoid
deadlocks and produce a visible user error.
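
The PID comparison behind these steps can be sketched in plain Python. This is only an illustration of the idea, not the SDK's code: `Stream` and `GLOBALS` are hypothetical stand-ins, and the real implementation lives in Rust, where the stale globals are forgotten without running their drop logic.

```python
import os
from typing import Dict, List


class Stream:
    """Hypothetical stand-in for a recording stream backed by worker threads."""

    def __init__(self) -> None:
        # Remember which process created the stream (and its worker threads).
        self.pid_at_creation = os.getpid()

    def is_forked_child(self) -> bool:
        # A different PID means we are running in a copy made by fork().
        return self.pid_at_creation != os.getpid()

    def flush(self) -> bool:
        """Return True if a flush was attempted, False if it was refused."""
        if self.is_forked_child():
            # No worker threads exist here; blocking on them would hang forever.
            return False
        return True


# Hypothetical registry of global streams, keyed by store kind.
GLOBALS: Dict[str, Stream] = {}


def cleanup_if_forked_child() -> List[str]:
    """Forget every stream inherited through a fork; return the kinds forgotten."""
    forgotten = [kind for kind, stream in GLOBALS.items() if stream.is_forked_child()]
    for kind in forgotten:
        del GLOBALS[kind]
    return forgotten
```

Storing the PID at creation time keeps the check cheap enough to run on every flush, which is what lets the last step in the list turn a silent hang into a visible error.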

Additionally, it turns out that forked processes bypass the normal
Python `atexit` handler, which means we don't get proper shutdown/flush
behavior when the forked processes terminate. To help users work around
this, we introduce a `@shutdown_at_exit` decorator which can be used to
decorate functions launched via multiprocessing.
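
The pattern behind the decorator is a plain `try`/`finally` wrapper. The sketch below shows it in isolation so it can run without the SDK: `fake_shutdown` and `shutdown_calls` are hypothetical stand-ins for the real shutdown routine.

```python
import functools
from typing import Any, Callable, List, TypeVar, cast

_TFunc = TypeVar("_TFunc", bound=Callable[..., Any])

# Records calls so the behavior can be observed without the real SDK.
shutdown_calls: List[str] = []


def fake_shutdown() -> None:
    """Hypothetical stand-in for the SDK's flush-and-shutdown routine."""
    shutdown_calls.append("shutdown")


def shutdown_at_exit(func: _TFunc) -> _TFunc:
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            return func(*args, **kwargs)
        finally:
            # Runs on normal return and also when `func` raises.
            fake_shutdown()

    return cast(_TFunc, wrapper)


@shutdown_at_exit
def task(child_index: int) -> str:
    if child_index < 0:
        raise ValueError("child_index must be non-negative")
    return f"task {child_index}"
```

Because the cleanup sits in a `finally` block, it runs whether the task returns or raises, and it does not depend on interpreter exit hooks firing in the child.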

### Testing

On Linux:
```
$ python examples/python/multiprocessing/main.py
```
Observe that the demo exits cleanly and all data shows up in the viewer.

### Checklist
* [x] I have read and agree to [Contributor
Guide](https://github.com/rerun-io/rerun/blob/main/CONTRIBUTING.md) and
the [Code of
Conduct](https://github.com/rerun-io/rerun/blob/main/CODE_OF_CONDUCT.md)
* [x] I've included a screenshot or gif (if applicable)
* [x] I have tested [demo.rerun.io](https://demo.rerun.io/pr/2676) (if
applicable)

- [PR Build Summary](https://build.rerun.io/pr/2676)
- [Docs
preview](https://rerun.io/preview/pr%3Ajleibs%2Fcleanup_if_forked/docs)
- [Examples
preview](https://rerun.io/preview/pr%3Ajleibs%2Fcleanup_if_forked/examples)
jleibs authored Jul 12, 2023
1 parent 9589002 commit fdd53b3
Showing 6 changed files with 180 additions and 8 deletions.
82 changes: 82 additions & 0 deletions crates/re_sdk/src/global.rs
@@ -38,6 +38,41 @@ thread_local! {
static LOCAL_BLUEPRINT_RECORDING: RefCell<Option<RecordingStream>> = RefCell::new(None);
}

/// Check whether we are the child of a fork.
///
/// If so, then our globals need to be cleaned up because they don't have associated batching
/// or sink threads. The parent of the fork will continue to process any data in the original
/// globals so nothing is being lost by doing this.
pub fn cleanup_if_forked_child() {
if let Some(global_recording) = RecordingStream::global(StoreKind::Recording) {
if global_recording.is_forked_child() {
re_log::debug!("Fork detected. Forgetting global Recording");
RecordingStream::forget_global(StoreKind::Recording);
}
}

if let Some(global_blueprint) = RecordingStream::global(StoreKind::Blueprint) {
if global_blueprint.is_forked_child() {
re_log::debug!("Fork detected. Forgetting global Blueprint");
RecordingStream::forget_global(StoreKind::Blueprint);
}
}

if let Some(thread_recording) = RecordingStream::thread_local(StoreKind::Recording) {
if thread_recording.is_forked_child() {
re_log::debug!("Fork detected. Forgetting thread-local Recording");
RecordingStream::forget_thread_local(StoreKind::Recording);
}
}

if let Some(thread_blueprint) = RecordingStream::thread_local(StoreKind::Blueprint) {
if thread_blueprint.is_forked_child() {
re_log::debug!("Fork detected. Forgetting thread-local Blueprint");
RecordingStream::forget_thread_local(StoreKind::Blueprint);
}
}
}

impl RecordingStream {
/// Returns `overrides` if it exists, otherwise returns the most appropriate active recording
/// of the specified type (i.e. thread-local first, then global scope), if any.
@@ -106,6 +141,15 @@ impl RecordingStream {
Self::set_any(RecordingScope::Global, kind, rec)
}

/// Forgets the currently active recording of the specified type in the global scope.
///
/// WARNING: this intentionally bypasses any drop/flush logic. This should only ever be used in
/// cases where you know the batcher/sink threads have been lost such as in a forked process.
#[inline]
pub fn forget_global(kind: StoreKind) {
Self::forget_any(RecordingScope::Global, kind);
}

// --- Thread local ---

/// Returns the currently active recording of the specified type in the thread-local scope,
@@ -125,6 +169,15 @@ impl RecordingStream {
Self::set_any(RecordingScope::ThreadLocal, kind, rec)
}

/// Forgets the currently active recording of the specified type in the thread-local scope.
///
/// WARNING: this intentionally bypasses any drop/flush logic. This should only ever be used in
/// cases where you know the batcher/sink threads have been lost such as in a forked process.
#[inline]
pub fn forget_thread_local(kind: StoreKind) {
Self::forget_any(RecordingScope::ThreadLocal, kind);
}

// --- Internal helpers ---

fn get_any(scope: RecordingScope, kind: StoreKind) -> Option<RecordingStream> {
@@ -180,6 +233,35 @@ impl RecordingStream {
},
}
}

fn forget_any(scope: RecordingScope, kind: StoreKind) {
match kind {
StoreKind::Recording => match scope {
RecordingScope::Global => {
if let Some(global) = GLOBAL_DATA_RECORDING.get() {
std::mem::forget(global.write().take());
}
}
RecordingScope::ThreadLocal => LOCAL_DATA_RECORDING.with(|cell| {
if let Some(cell) = cell.take() {
std::mem::forget(cell);
}
}),
},
StoreKind::Blueprint => match scope {
RecordingScope::Global => {
if let Some(global) = GLOBAL_BLUEPRINT_RECORDING.get() {
std::mem::forget(global.write().take());
}
}
RecordingScope::ThreadLocal => LOCAL_BLUEPRINT_RECORDING.with(|cell| {
if let Some(cell) = cell.take() {
std::mem::forget(cell);
}
}),
},
}
}
}

// ---
2 changes: 2 additions & 0 deletions crates/re_sdk/src/lib.rs
@@ -26,6 +26,8 @@ pub use re_log_types::{
ApplicationId, Component, ComponentName, EntityPath, SerializableComponent, StoreId, StoreKind,
};

pub use global::cleanup_if_forked_child;

#[cfg(not(target_arch = "wasm32"))]
impl crate::sink::LogSink for re_log_encoding::FileSink {
fn send(&self, msg: re_log_types::LogMsg) {
30 changes: 30 additions & 0 deletions crates/re_sdk/src/recording_stream.rs
@@ -349,10 +349,17 @@ struct RecordingStreamInner {

batcher: DataTableBatcher,
batcher_to_sink_handle: Option<std::thread::JoinHandle<()>>,

pid_at_creation: u32,
}

impl Drop for RecordingStreamInner {
fn drop(&mut self) {
if self.is_forked_child() {
re_log::error_once!("Fork detected while dropping RecordingStreamInner. cleanup_if_forked_child() should always be called after forking. This is likely a bug in the SDK.");
return;
}

// NOTE: The command channel is private, if we're here, nothing is currently capable of
// sending data down the pipeline.
self.batcher.flush_blocking();
@@ -410,8 +417,14 @@ impl RecordingStreamInner {
cmds_tx,
batcher,
batcher_to_sink_handle: Some(batcher_to_sink_handle),
pid_at_creation: std::process::id(),
})
}

#[inline]
pub fn is_forked_child(&self) -> bool {
self.pid_at_creation != std::process::id()
}
}

enum Command {
@@ -591,6 +604,18 @@ impl RecordingStream {
pub fn store_info(&self) -> Option<&StoreInfo> {
(*self.inner).as_ref().map(|inner| &inner.info)
}

/// Determine whether a fork has happened since creating this `RecordingStream`. In general, this means our
/// batcher/sink threads are gone and all data logged since the fork has been dropped.
///
/// It is essential that [`crate::cleanup_if_forked_child`] be called after forking the process. SDK-implementations
/// should do this during their initialization phase.
#[inline]
pub fn is_forked_child(&self) -> bool {
(*self.inner)
.as_ref()
.map_or(false, |inner| inner.is_forked_child())
}
}

impl RecordingStream {
@@ -737,6 +762,11 @@ impl RecordingStream {
///
/// See [`RecordingStream`] docs for ordering semantics and multithreading guarantees.
pub fn flush_blocking(&self) {
if self.is_forked_child() {
re_log::error_once!("Fork detected during flush. cleanup_if_forked_child() should always be called after forking. This is likely a bug in the SDK.");
return;
}

let Some(this) = &*self.inner else {
re_log::warn_once!("Recording disabled - call to flush_blocking() ignored");
return;
19 changes: 11 additions & 8 deletions examples/python/multiprocessing/main.py
@@ -10,11 +10,19 @@
import rerun as rr # pip install rerun-sdk


# Python does not guarantee that the normal atexit-handlers will be called at the
# termination of a multiprocessing.Process. Explicitly add the `shutdown_at_exit`
# decorator to ensure data is flushed when the task completes.
@rr.shutdown_at_exit
def task(child_index: int) -> None:
# All processes spawned with `multiprocessing` will automatically
# be assigned the same default recording_id.
# We just need to connect each process to the Rerun viewer:
# In the new process, we always need to call init with the same `application_id`.
# By default, the `recording_id` will match the `recording_id` of the parent process,
# so all of these processes will have their log data merged in the viewer.
# Caution: if you manually specified `recording_id` in the parent, you also must
# pass the same `recording_id` here.
rr.init("multiprocessing")

# We then have to connect to the viewer instance.
rr.connect()

title = f"task {child_index}"
@@ -37,11 +45,6 @@ def main() -> None:

task(0)

# Using multiprocessing with "fork" results in a hang on shutdown so
# always use "spawn"
# TODO(https://github.com/rerun-io/rerun/issues/1921)
multiprocessing.set_start_method("spawn")

for i in [1, 2, 3]:
p = multiprocessing.Process(target=task, args=(i,))
p.start()
48 changes: 48 additions & 0 deletions rerun_py/rerun_sdk/rerun/__init__.py
@@ -1,6 +1,9 @@
"""The Rerun Python SDK, which is a wrapper around the re_sdk crate."""
from __future__ import annotations

import functools
from typing import Any, Callable, TypeVar, cast

# NOTE: The imports determine what is public API. Avoid importing globally anything that is not public API. Use
# (private) function and local import if needed.
import rerun_bindings as bindings # type: ignore[attr-defined]
@@ -155,6 +158,10 @@ def init(
global _strict_mode
_strict_mode = strict

# Always check whether we are a forked child when calling init. This should have happened
# via `_register_on_fork` but it's worth being conservative.
cleanup_if_forked_child()

if init_logging:
new_recording(
application_id,
@@ -311,6 +318,47 @@ def unregister_shutdown() -> None:
atexit.unregister(rerun_shutdown)


def cleanup_if_forked_child() -> None:
    bindings.cleanup_if_forked_child()


def _register_on_fork() -> None:
    # `os.register_at_fork` only exists on Unix-like platforms (missing on e.g. Windows)
    try:
        import os

        os.register_at_fork(after_in_child=cleanup_if_forked_child)
    except (AttributeError, NotImplementedError):
        pass


_register_on_fork()


_TFunc = TypeVar("_TFunc", bound=Callable[..., Any])


def shutdown_at_exit(func: _TFunc) -> _TFunc:
    """
    Decorator to shut down Rerun cleanly when this function exits.
    Normally, Rerun installs an atexit-handler that attempts to shut down cleanly and
    flush all outgoing data before terminating. However, in some cases, such as forked
    processes, this atexit-handler is skipped. In these cases, you can use this
    decorator on the entry-point of your subprocess to ensure cleanup happens as
    expected without losing data.
    """

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            return func(*args, **kwargs)
        finally:
            rerun_shutdown()

    return cast(_TFunc, wrapper)


# ---


7 changes: 7 additions & 0 deletions rerun_py/src/python_bridge.rs
@@ -119,6 +119,7 @@ fn rerun_bindings(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(new_recording, m)?)?;
m.add_function(wrap_pyfunction!(new_blueprint, m)?)?;
m.add_function(wrap_pyfunction!(shutdown, m)?)?;
m.add_function(wrap_pyfunction!(cleanup_if_forked_child, m)?)?;

// recordings
m.add_function(wrap_pyfunction!(get_application_id, m)?)?;
@@ -349,6 +350,12 @@ fn get_global_data_recording() -> Option<PyRecordingStream> {
RecordingStream::global(rerun::StoreKind::Recording).map(PyRecordingStream)
}

/// Cleans up internal state if this is the child of a forked process.
#[pyfunction]
fn cleanup_if_forked_child() {
rerun::cleanup_if_forked_child();
}

/// Replaces the currently active recording in the global scope with the specified one.
///
/// Returns the previous one, if any.