Introduce new API to send a dataframe to Rerun #8461

Merged · 8 commits · Dec 16, 2024
15 changes: 15 additions & 0 deletions crates/store/re_chunk_store/src/dataframe.rs
@@ -104,6 +104,20 @@ impl Ord for TimeColumnDescriptor {
}

impl TimeColumnDescriptor {
    fn metadata(&self) -> arrow2::datatypes::Metadata {
        let Self {
            timeline,
            datatype: _,
        } = self;

        std::iter::once(Some((
            "sorbet.index_name".to_owned(),
            timeline.name().to_string(),
        )))
        .flatten()
        .collect()
    }

    #[inline]
    // Time column must be nullable since static data doesn't have a time.
    pub fn to_arrow_field(&self) -> Arrow2Field {
@@ -113,6 +127,7 @@ impl TimeColumnDescriptor {
            datatype.clone(),
            true, /* nullable */
        )
        .with_metadata(self.metadata())
    }
}
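
For context (editor's note, not part of the diff): this metadata travels with the Arrow schema all the way to Python clients. A minimal sketch of inspecting it with pyarrow; the file name and timeline name are placeholders, not from this PR:

import rerun as rr

recording = rr.dataframe.load_recording("recording.rrd")  # hypothetical file
table = recording.view(index="my_index", contents="/**").select().read_all()

for field in table.schema:
    meta = field.metadata or {}
    if b"sorbet.index_name" in meta:
        # For the index column this now prints e.g.: my_index b'my_index'
        print(field.name, meta[b"sorbet.index_name"])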

2 changes: 1 addition & 1 deletion rerun_py/rerun_sdk/rerun/any_value.py
@@ -8,7 +8,7 @@

 from rerun._baseclasses import ComponentDescriptor

-from . import ComponentColumn
+from ._baseclasses import ComponentColumn
 from ._log import AsComponents, ComponentBatchLike
 from .error_utils import catch_and_log_exceptions

100 changes: 100 additions & 0 deletions rerun_py/rerun_sdk/rerun/dataframe.py
@@ -1,5 +1,9 @@
from __future__ import annotations

from collections import defaultdict
from typing import Any, Optional

import pyarrow as pa
from rerun_bindings import (
    ComponentColumnDescriptor as ComponentColumnDescriptor,
    ComponentColumnSelector as ComponentColumnSelector,
@@ -18,3 +22,99 @@
    ComponentLike as ComponentLike,
    ViewContentsLike as ViewContentsLike,
)

from ._baseclasses import ComponentColumn, ComponentDescriptor
from ._log import IndicatorComponentBatch
from ._send_columns import TimeColumnLike, send_columns
from .recording_stream import RecordingStream

SORBET_INDEX_NAME = b"sorbet.index_name"
SORBET_ENTITY_PATH = b"sorbet.path"
SORBET_ARCHETYPE_NAME = b"sorbet.semantic_family"
SORBET_ARCHETYPE_FIELD = b"sorbet.logical_type"
SORBET_COMPONENT_NAME = b"sorbet.semantic_type"
RERUN_KIND = b"rerun.kind"
RERUN_KIND_CONTROL = b"control"
RERUN_KIND_INDEX = b"time"


class RawIndexColumn(TimeColumnLike):
    def __init__(self, metadata: dict[bytes, bytes], col: pa.Array):
        self.metadata = metadata
        self.col = col

    def timeline_name(self) -> str:
        name = self.metadata.get(SORBET_INDEX_NAME, "unknown")
        if isinstance(name, bytes):
            name = name.decode("utf-8")
        return name

    def as_arrow_array(self) -> pa.Array:
        return self.col


class RawComponentBatchLike(ComponentColumn):
    def __init__(self, metadata: dict[bytes, bytes], col: pa.Array):
        self.metadata = metadata
        self.col = col

    def component_descriptor(self) -> ComponentDescriptor:
        kwargs = {}
        if SORBET_ARCHETYPE_NAME in self.metadata:
            kwargs["archetype_name"] = "rerun.archetypes." + self.metadata[SORBET_ARCHETYPE_NAME].decode("utf-8")
        if SORBET_COMPONENT_NAME in self.metadata:
            kwargs["component_name"] = "rerun.components." + self.metadata[SORBET_COMPONENT_NAME].decode("utf-8")
        if SORBET_ARCHETYPE_FIELD in self.metadata:
            kwargs["archetype_field_name"] = self.metadata[SORBET_ARCHETYPE_FIELD].decode("utf-8")

        if "component_name" not in kwargs:
            kwargs["component_name"] = "Unknown"

        return ComponentDescriptor(**kwargs)

    def as_arrow_array(self) -> pa.Array:
        return self.col


def send_record_batch(batch: pa.RecordBatch, rec: Optional[RecordingStream] = None) -> None:
    """Coerce a single pyarrow `RecordBatch` to Rerun structure."""

    indexes = []
    data: defaultdict[str, list[Any]] = defaultdict(list)
    archetypes: defaultdict[str, set[Any]] = defaultdict(set)
    for col in batch.schema:
        metadata = col.metadata or {}
        if metadata.get(RERUN_KIND) == RERUN_KIND_CONTROL:
            continue
        if SORBET_INDEX_NAME in metadata or metadata.get(RERUN_KIND) == RERUN_KIND_INDEX:
            if SORBET_INDEX_NAME not in metadata:
                metadata[SORBET_INDEX_NAME] = col.name
            indexes.append(RawIndexColumn(metadata, batch.column(col.name)))
        else:
            entity_path = metadata.get(SORBET_ENTITY_PATH, col.name.split(":")[0])
            if isinstance(entity_path, bytes):
                entity_path = entity_path.decode("utf-8")
            data[entity_path].append(RawComponentBatchLike(metadata, batch.column(col.name)))
            if SORBET_ARCHETYPE_NAME in metadata:
                archetypes[entity_path].add(metadata[SORBET_ARCHETYPE_NAME].decode("utf-8"))
    for entity_path, archetype_set in archetypes.items():
        for archetype in archetype_set:
            data[entity_path].append(IndicatorComponentBatch("rerun.archetypes." + archetype))

    for entity_path, columns in data.items():
        send_columns(
            entity_path,
            indexes,
            columns,
            # This is fine, send_columns will handle the conversion
            recording=rec,  # NOLINT
        )


def send_dataframe(df: pa.RecordBatchReader | pa.Table, rec: Optional[RecordingStream] = None) -> None:
    """Coerce a pyarrow `RecordBatchReader` or `Table` to Rerun structure."""
    if isinstance(df, pa.Table):
        df = df.to_reader()

    for batch in df:
        send_record_batch(batch, rec)
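
A usage sketch for the new entry point (editor's illustration; the recording file, app id, and timeline name are assumptions, and spawn=True expects a local viewer):

import rerun as rr

recording = rr.dataframe.load_recording("recording.rrd")  # hypothetical file
df = recording.view(index="my_index", contents="/**").select().read_all()

rr.init("rerun_example_send_dataframe", spawn=True)
# Accepts a pa.Table or a pa.RecordBatchReader; a Table is converted
# with to_reader() and then streamed batch by batch.
rr.dataframe.send_dataframe(df)
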
16 changes: 16 additions & 0 deletions rerun_py/tests/unit/test_dataframe.py
@@ -380,3 +380,19 @@ def test_view_syntax(self) -> None:
        table = pa.Table.from_batches(batches, batches.schema)
        assert table.num_columns == 3
        assert table.num_rows == 0

    def test_roundtrip_send(self) -> None:
        df = self.recording.view(index="my_index", contents="/**").select().read_all()

        with tempfile.TemporaryDirectory() as tmpdir:
            rrd = tmpdir + "/tmp.rrd"

            rr.init("rerun_example_test_recording")
            rr.dataframe.send_dataframe(df)
            rr.save(rrd)

            round_trip_recording = rr.dataframe.load_recording(rrd)

        df_round_trip = round_trip_recording.view(index="my_index", contents="/**").select().read_all()

        assert df == df_round_trip
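
Closing sketch (editor's illustration, not part of the PR): the sorbet metadata keys above can also be attached by hand, which is enough for send_record_batch to route a column. The entity path, component name, and values below are made up; whether the viewer can visualize the result depends on the component type:

import pyarrow as pa
import rerun as rr

# Index column: recognized via the sorbet.index_name metadata key.
index_field = pa.field("frame", pa.int64(), metadata={b"sorbet.index_name": b"frame"})

# Data column: sorbet.path picks the entity; sorbet.semantic_type is
# prefixed with "rerun.components." by component_descriptor().
data_field = pa.field(
    "positions",
    pa.list_(pa.float32(), 3),
    metadata={
        b"sorbet.path": b"/world/points",
        b"sorbet.semantic_type": b"Position3D",
    },
)

batch = pa.record_batch(
    [
        pa.array([0, 1], type=pa.int64()),
        pa.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], type=pa.list_(pa.float32(), 3)),
    ],
    schema=pa.schema([index_field, data_field]),
)

rr.init("rerun_example_send_record_batch", spawn=True)
rr.dataframe.send_record_batch(batch)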