Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose batch APIs for oriented bounding boxes #2823

Merged
merged 10 commits into from
Jul 26, 2023
3 changes: 2 additions & 1 deletion rerun_py/rerun_sdk/rerun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
"log_mesh_file",
"log_meshes",
"log_obb",
"log_obbs",
"log_path",
"log_pinhole",
"log_point",
Expand Down Expand Up @@ -95,7 +96,7 @@
)
from .log.annotation import AnnotationInfo, ClassDescription, log_annotation_context
from .log.arrow import log_arrow
from .log.bounding_box import log_obb
from .log.bounding_box import log_obb, log_obbs
from .log.camera import log_pinhole
from .log.clear import log_cleared
from .log.extension_components import log_extension_components
Expand Down
2 changes: 1 addition & 1 deletion rerun_py/rerun_sdk/rerun/components/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class Box3DArray(pa.ExtensionArray): # type: ignore[misc]
def from_numpy(array: npt.NDArray[np.float32]) -> Box3DArray:
"""Build a `Box3DArray` from an Nx3 numpy array."""
assert array.shape[1] == 3
assert len(array) == 0 or array.shape[1] == 3
storage = pa.FixedSizeListArray.from_arrays(array.flatten(), type=Box3DType.storage_type)
# TODO(john) enable extension type wrapper
# return cast(Box3DArray, pa.ExtensionArray.from_storage(Box3DType(), storage))
Expand Down
2 changes: 1 addition & 1 deletion rerun_py/rerun_sdk/rerun/components/quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __array__(self) -> npt.NDArray[np.float32]:
class QuaternionArray(pa.ExtensionArray): # type: ignore[misc]
def from_numpy(array: npt.NDArray[np.float32]) -> QuaternionArray:
"""Build a `QuaternionArray` from an Nx4 numpy array."""
assert array.shape[1] == 4
assert len(array) == 0 or array.shape[1] == 4
storage = pa.FixedSizeListArray.from_arrays(array.flatten(), type=QuaternionType.storage_type)
# TODO(john) enable extension type wrapper
# return cast(QuaternionArray, pa.ExtensionArray.from_storage(QuaternionType(), storage))
Expand Down
4 changes: 2 additions & 2 deletions rerun_py/rerun_sdk/rerun/components/vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class Vec2DArray(pa.ExtensionArray): # type: ignore[misc]
def from_numpy(array: npt.NDArray[np.float32]) -> Vec2DArray:
"""Build a `Vec2DArray` from an Nx2 numpy array."""
assert array.shape[1] == 2
assert len(array) == 0 or array.shape[1] == 2
storage = pa.FixedSizeListArray.from_arrays(array.flatten(), type=Vec2DType.storage_type)
# TODO(john) enable extension type wrapper
# return cast(Vec2DArray, pa.ExtensionArray.from_storage(Vec2DType(), storage))
Expand All @@ -32,7 +32,7 @@ def from_numpy(array: npt.NDArray[np.float32]) -> Vec2DArray:
class Vec3DArray(pa.ExtensionArray): # type: ignore[misc]
def from_numpy(array: npt.NDArray[np.float32]) -> Vec3DArray:
"""Build a `Vec3DArray` from an Nx3 numpy array."""
assert array.shape[1] == 3
assert len(array) == 0 or array.shape[1] == 3
storage = pa.FixedSizeListArray.from_arrays(array.flatten(), type=Vec3DType.storage_type)
# TODO(john) enable extension type wrapper
# return cast(Vec3DArray, pa.ExtensionArray.from_storage(Vec3DType(), storage))
Expand Down
130 changes: 128 additions & 2 deletions rerun_py/rerun_sdk/rerun/log/bounding_box.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any
from typing import Any, Sequence

import numpy as np
import numpy.typing as npt
Expand All @@ -14,13 +14,22 @@
from rerun.components.quaternion import QuaternionArray
from rerun.components.radius import RadiusArray
from rerun.components.vec import Vec3DArray
from rerun.log import Color, _normalize_colors, _normalize_ids, _normalize_radii
from rerun.log import (
Color,
Colors,
OptionalClassIds,
_normalize_colors,
_normalize_ids,
_normalize_labels,
_normalize_radii,
)
from rerun.log.extension_components import _add_extension_components
from rerun.log.log_decorator import log_decorator
from rerun.recording_stream import RecordingStream

__all__ = [
"log_obb",
"log_obbs",
]


Expand Down Expand Up @@ -136,3 +145,120 @@ def log_obb(
# Always the primary component last so range-based queries will include the other data. See(#1215)
if instanced:
bindings.log_arrow_msg(entity_path, components=instanced, timeless=timeless, recording=recording)


@log_decorator
def log_obbs(
entity_path: str,
*,
half_sizes: npt.ArrayLike | None,
positions: npt.ArrayLike | None = None,
rotations_q: npt.ArrayLike | None = None,
colors: Color | Colors | None = None,
stroke_widths: npt.ArrayLike | None = None,
labels: Sequence[str] | None = None,
class_ids: OptionalClassIds | None = None,
ext: dict[str, Any] | None = None,
timeless: bool = False,
recording: RecordingStream | None = None,
) -> None:
"""
Log a 3D Oriented Bounding Box, or OBB.

Example:
--------
```
rr.log_obb("my_obb", half_size=[1.0, 2.0, 3.0], position=[0, 0, 0], rotation_q=[0, 0, 0, 1])
```

Parameters
----------
entity_path:
The path to the oriented bounding box in the space hierarchy.
half_sizes:
Nx3 Array. Each row is the [x, y, z] half dimensions of an OBB.
positions:
Optional Nx3 array. Each row is [x, y, z] positions of an OBB in world space.
rotations_q:
Optional Nx3 array. Each row is quaternion coordinates [x, y, z, w] for the rotation from model to world space.
colors:
Optional Nx3 or Nx4 array. Each row is RGB or RGBA in sRGB gamma-space as either 0-1 floats or 0-255 integers,
with separate alpha.
stroke_widths:
Optional array of the width of the line edges.
labels:
Optional array of text labels placed at `position`.
class_ids:
Optional array of class id for the OBBs. The class id provides colors and labels if not specified explicitly.
ext:
Optional dictionary of extension components. See [rerun.log_extension_components][]
timeless:
If true, the bounding box will be timeless (default: False).
recording:
Specifies the [`rerun.RecordingStream`][] to use.
If left unspecified, defaults to the current active data recording, if there is one.
See also: [`rerun.init`][], [`rerun.set_global_data_recording`][].

"""
recording = RecordingStream.to_native(recording)

colors = _normalize_colors(colors)
radii = _normalize_radii(stroke_widths)
radii / 2
Wumpf marked this conversation as resolved.
Show resolved Hide resolved
labels = _normalize_labels(labels)
class_ids = _normalize_ids(class_ids)

# 0 = instanced, 1 = splat
comps = [{}, {}] # type: ignore[var-annotated]

if half_sizes is not None:
half_sizes = np.require(half_sizes, dtype="float32")

if len(half_sizes) == 0 or half_sizes.shape[1] == 3:
comps[0]["rerun.box3d"] = Box3DArray.from_numpy(half_sizes)
else:
raise TypeError("half_size should be Nx3")

if positions is not None:
positions = np.require(positions, dtype="float32")

if len(positions) == 0 or positions.shape[1] == 3:
comps[0]["rerun.vec3d"] = Vec3DArray.from_numpy(positions)
else:
raise TypeError("position should be 1x3")

if rotations_q is not None:
rotations_q = np.require(rotations_q, dtype="float32")

if len(rotations_q) == 0 or rotations_q.shape[1] == 4:
comps[0]["rerun.quaternion"] = QuaternionArray.from_numpy(rotations_q)
else:
raise TypeError("rotation should be 1x4")

if len(colors):
is_splat = len(colors.shape) == 1
if is_splat:
colors = colors.reshape(1, len(colors))
comps[is_splat]["rerun.colorrgba"] = ColorRGBAArray.from_numpy(colors)

if len(radii):
is_splat = len(radii) == 1
comps[is_splat]["rerun.radius"] = RadiusArray.from_numpy(radii)

if len(labels):
is_splat = len(labels) == 1
comps[is_splat]["rerun.label"] = LabelArray.new(labels)

if len(class_ids):
is_splat = len(class_ids) == 1
comps[is_splat]["rerun.class_id"] = ClassIdArray.from_numpy(class_ids)

if ext:
_add_extension_components(comps[0], comps[1], ext, None)

if comps[1]:
comps[1]["rerun.instance_key"] = InstanceArray.splat()
bindings.log_arrow_msg(entity_path, components=comps[1], timeless=timeless, recording=recording)

# Always the primary component last so range-based queries will include the other data. See(#1215)
bindings.log_arrow_msg(entity_path, components=comps[0], timeless=timeless, recording=recording)