Skip to content

Commit

Permalink
Add a helper decorator to do the right thing with recording stream co…
Browse files Browse the repository at this point in the history
…ntext managers
  • Loading branch information
jleibs committed May 6, 2024
1 parent 84969ba commit 935f63c
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 2 deletions.
1 change: 1 addition & 0 deletions rerun_py/rerun_sdk/rerun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
get_thread_local_data_recording,
is_enabled,
new_recording,
recording_stream_generator_ctx,
set_global_data_recording,
set_thread_local_data_recording,
thread_local_stream,
Expand Down
91 changes: 89 additions & 2 deletions rerun_py/rerun_sdk/rerun/recording_stream.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextvars
import functools
import inspect
import uuid
Expand Down Expand Up @@ -136,6 +137,14 @@ def new_recording(
return recording


active_recording_stream: contextvars.ContextVar[RecordingStream] = contextvars.ContextVar("active_recording_stream")
"""
A context variable that tracks the currently active recording stream.
Used to managed and detect interactions between generators and RecordingStream context-manager objects.
"""


class RecordingStream:
"""
A RecordingStream is used to send data to Rerun.
Expand Down Expand Up @@ -202,14 +211,28 @@ class RecordingStream:
def __init__(self, inner: bindings.PyRecordingStream) -> None:
self.inner = inner
self._prev: RecordingStream | None = None
self.context_token: contextvars.Token[RecordingStream] | None = None

def __enter__(self): # type: ignore[no-untyped-def]
self.context_token = active_recording_stream.set(self)
self._prev = set_thread_local_data_recording(self)
return self

def __exit__(self, type, value, traceback): # type: ignore[no-untyped-def]
current_recording = active_recording_stream.get(None)

if self.context_token is not None:
active_recording_stream.reset(self.context_token)

self._prev = set_thread_local_data_recording(self._prev) # type: ignore[arg-type]

# Sanity check: we set this context-var on enter. If it's not still set, something weird
# happened. The user is probably doing something sketch with generators or async code.
if current_recording is not self:
raise RuntimeError(
"RecordingStream context manager exited while not active. Likely mixing context managers with generators or async code. See: `recording_stream_generator_ctx`."
)

# NOTE: The type is a string because we cannot reference `RecordingStream` yet at this point.
def to_native(self: RecordingStream | None) -> bindings.PyRecordingStream | None:
return self.inner if self is not None else None
Expand Down Expand Up @@ -450,9 +473,9 @@ def generator_wrapper(*args: Any, **kwargs: Any) -> Any:
with stream:
value = next(gen) # Start the generator inside the context
while True:
cont = yield value # Continue the generator
cont = yield value # Yield the value, suspending the generator
with stream:
value = gen.send(cont)
value = gen.send(cont) # Resume the generator inside the context
except StopIteration:
pass
finally:
Expand All @@ -470,3 +493,67 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return wrapper # type: ignore[return-value]

return decorator


def recording_stream_generator_ctx(func: _TFunc) -> _TFunc:
"""
Decorator to manage recording stream context for generator functions.
This is only necessary if you need to implement a generator which yields while holding an open
recording stream context. This decorator will ensure that the recording stream context is suspended
and then properly resumed upon re-entering the generator.
See: https://github.com/rerun-io/rerun/issues/6238 for context on why this is necessary.
Example
-------
```python
@rr.recording_stream.recording_stream_generator_ctx
def my_generator(name: str) -> Iterator[None]:
with rr.new_recording(name):
rr.save(f"{name}.rrd")
for i in range(10):
rr.log("stream", rr.TextLog(f"{name} {i}"))
yield i
for i in my_generator("foo"):
pass
```
"""
if inspect.isgeneratorfunction(func): # noqa: F821

@functools.wraps(func)
def generator_wrapper(*args: Any, **kwargs: Any) -> Any:
gen = func(*args, **kwargs)
current_recording = None
try:
value = next(gen) # Get the first generated value
while True:
current_recording = active_recording_stream.get(None)

if current_recording is not None:
# TODO(jleibs): Do we need to pass something through here?
# Probably not, since __exit__ doesn't use those args, but
# keep an eye on this.
current_recording.__exit__(None, None, None) # Exit our context before we yield

cont = yield value # Yield the value, suspending the generator

if current_recording is not None:
current_recording.__enter__() # Restore our context before we continue

value = gen.send(cont) # Resume the generator inside the context

except StopIteration:
pass
finally:
# It's important to re-enter the generator to avoid getting our warning
# on the final (real) exit.
if current_recording is not None:
current_recording.__enter__()
gen.close()

return generator_wrapper # type: ignore[return-value]
else:
raise ValueError("Only generator functions can be decorated with `recording_stream_generator_ctx`")

0 comments on commit 935f63c

Please sign in to comment.