Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python backport: add log_any() #2581

Merged
merged 15 commits into from
Jul 3, 2023
8 changes: 4 additions & 4 deletions crates/re_types_builder/src/codegen/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ fn quote_objects(

let manifest = quote_manifest(names);
let base_include = match kind {
ObjectKind::Archetype => "from ._base import Archetype",
ObjectKind::Component => "from ._base import Component",
ObjectKind::Archetype => "from .._baseclasses import Archetype",
ObjectKind::Component => "from .._baseclasses import Component",
ObjectKind::Datatype => "",
};
code.push_unindented_text(
Expand Down Expand Up @@ -232,8 +232,8 @@ fn quote_objects(
let manifest = quote_manifest(mods.iter().flat_map(|(_, names)| names.iter()));

let (base_manifest, base_include) = match kind {
ObjectKind::Archetype => ("\"Archetype\", ", "from ._base import Archetype\n"),
ObjectKind::Component => ("\"Component\", ", "from ._base import Component\n"),
ObjectKind::Archetype => ("\"Archetype\", ", "from .._baseclasses import Archetype\n"),
ObjectKind::Component => ("\"Component\", ", "from .._baseclasses import Component\n"),
ObjectKind::Datatype => ("", ""),
};

Expand Down
43 changes: 30 additions & 13 deletions examples/python/api_demo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,38 @@ def run_segmentation() -> None:
rr.log_segmentation_image("seg_demo/img", segmentation_img)

# Log a bunch of classified 2D points
# Note: this uses the new, WIP object-oriented API
rr.log_any("seg_demo/single_point", rr.Points2D(np.array([64, 64]), class_ids=13))
rr.log_any("seg_demo/single_point_labeled", rr.Points2D(np.array([90, 50]), class_ids=13, labels="labeled point"))
rr.log_any("seg_demo/several_points0", rr.Points2D(np.array([[20, 50], [100, 70], [60, 30]]), class_ids=42))
rr.log_any(
"seg_demo/several_points1",
rr.Points2D(np.array([[40, 50], [120, 70], [80, 30]]), class_ids=np.array([13, 42, 99], dtype=np.uint8)),
)
rr.log_any(
"seg_demo/many points",
rr.Points2D(
if rr.ENABLE_NEXT_GEN_API:
# Note: this uses the new, WIP object-oriented API
rr.log_any("seg_demo/single_point", rr.Points2D(np.array([64, 64]), class_ids=13))
rr.log_any(
"seg_demo/single_point_labeled", rr.Points2D(np.array([90, 50]), class_ids=13, labels="labeled point")
)
rr.log_any("seg_demo/several_points0", rr.Points2D(np.array([[20, 50], [100, 70], [60, 30]]), class_ids=42))
rr.log_any(
"seg_demo/several_points1",
rr.Points2D(np.array([[40, 50], [120, 70], [80, 30]]), class_ids=np.array([13, 42, 99], dtype=np.uint8)),
)
rr.log_any(
"seg_demo/many points",
rr.Points2D(
np.array([[100 + (int(i / 5)) * 2, 100 + (i % 5) * 2] for i in range(25)]),
class_ids=np.array([42], dtype=np.uint8),
),
)
abey79 marked this conversation as resolved.
Show resolved Hide resolved
else:
rr.log_point("seg_demo/single_point", np.array([64, 64]), class_id=13)
rr.log_point("seg_demo/single_point_labeled", np.array([90, 50]), class_id=13, label="labeled point")
rr.log_points("seg_demo/several_points0", np.array([[20, 50], [100, 70], [60, 30]]), class_ids=42)
rr.log_points(
"seg_demo/several_points1",
np.array([[40, 50], [120, 70], [80, 30]]),
class_ids=np.array([13, 42, 99], dtype=np.uint8),
)
rr.log_points(
"seg_demo/many points",
np.array([[100 + (int(i / 5)) * 2, 100 + (i % 5) * 2] for i in range(25)]),
class_ids=np.array([42], dtype=np.uint8),
),
)
)

rr.log_text_entry("logs/seg_demo_log", "default colored rects, default colored points, a single point has a label")

Expand Down
5 changes: 2 additions & 3 deletions rerun_py/rerun_sdk/rerun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,8 @@
from .time import reset_time, set_time_nanos, set_time_seconds, set_time_sequence

# Next-gen API imports
# TODO(ab): remove this guard, here to make it easy to "hide" the next gen API if needed in the short term.
_ENABLE_NEXT_GEN_API = True
if _ENABLE_NEXT_GEN_API:
ENABLE_NEXT_GEN_API = True
if ENABLE_NEXT_GEN_API:
from ._rerun2.archetypes import *
from ._rerun2.log_any import log_any

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from __future__ import annotations

from dataclasses import dataclass

import pyarrow as pa


@dataclass
class Archetype:
pass


class Component(pa.ExtensionArray): # type: ignore[misc]
@property
def extension_name(self) -> str:
Expand Down
8 changes: 0 additions & 8 deletions rerun_py/rerun_sdk/rerun/_rerun2/archetypes/_base.py

This file was deleted.

49 changes: 38 additions & 11 deletions rerun_py/rerun_sdk/rerun/_rerun2/log_any.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import fields
from typing import Any
from typing import Any, Iterable

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -65,32 +65,59 @@ def _add_extension_components(
instanced[name] = pa_value # noqa


def _extract_components(entity: Archetype) -> Iterable[tuple[Component, bool]]:
"""Extract the components from an entity, yielding (component, is_primary) tuples."""
for fld in fields(entity):
if "component" in fld.metadata:
yield getattr(entity, fld.name), fld.metadata["component"] == "primary"


def log_any(
entity_path: str,
entity: Archetype,
ext: dict[str, Any] | None = None,
timeless: bool = False,
recording: RecordingStream | None = None,
) -> None:
"""
Log an entity.

Parameters
----------
entity_path:
Path to the points in the space hierarchy.
abey79 marked this conversation as resolved.
Show resolved Hide resolved
entity: Archetype
The archetype object representing the entity.
ext:
Optional dictionary of extension components. See [rerun.log_extension_components][]
timeless:
If true, the points will be timeless (default: False).
abey79 marked this conversation as resolved.
Show resolved Hide resolved
recording:
Specifies the [`rerun.RecordingStream`][] to use.
If left unspecified, defaults to the current active data recording, if there is one.
See also: [`rerun.init`][], [`rerun.set_global_data_recording`][].

"""

from .. import strict_mode

if strict_mode():
if not isinstance(entity, Archetype):
raise TypeError(f"Expected Archetype, got {type(entity)}")

# 0 = instanced, 1 = splat
instanced: dict[str, Component] = {}
splats: dict[str, Component] = {}

for fld in fields(entity):
if "component" in fld.metadata:
comp: Component = getattr(entity, fld.name)
if fld.metadata["component"] == "primary":
instanced[comp.extension_name] = comp.storage
elif len(comp) == 1:
splats[comp.extension_name] = comp.storage
elif len(comp) > 1:
instanced[comp.extension_name] = comp.storage
# find canonical length of this entity by extracting the maximum length of the
abey79 marked this conversation as resolved.
Show resolved Hide resolved
archetype_length = max(len(comp) for comp, primary in _extract_components(entity) if primary)

for comp, primary in _extract_components(entity):
if primary:
instanced[comp.extension_name] = comp.storage
elif len(comp) == 1 and archetype_length > 1:
splats[comp.extension_name] = comp.storage
elif len(comp) > 1:
instanced[comp.extension_name] = comp.storage

if ext:
_add_extension_components(instanced, splats, ext, None)
Expand Down