Skip to content

Commit

Permalink
Added support for PyTorch array to Boxes2D's array convenience ar…
Browse files Browse the repository at this point in the history
…gument (#3719)

### What

Fixes:
* #3708

Relates to (but doesn't implement):
* #3718 

### Checklist
* [x] I have read and agree to [Contributor
Guide](https://github.com/rerun-io/rerun/blob/main/CONTRIBUTING.md) and
the [Code of
Conduct](https://github.com/rerun-io/rerun/blob/main/CODE_OF_CONDUCT.md)
* [x] ~~I've included a screenshot or gif (if applicable)~~
* [x] ~~I have tested [demo.rerun.io](https://demo.rerun.io/pr/3719) (if
applicable)~~

- [PR Build Summary](https://build.rerun.io/pr/3719)
- [Docs
preview](https://rerun.io/preview/1b9c5d8d1d5ce8b9c29176c3b8b51e6fde4faced/docs)
<!--DOCS-PREVIEW-->
- [Examples
preview](https://rerun.io/preview/1b9c5d8d1d5ce8b9c29176c3b8b51e6fde4faced/examples)
<!--EXAMPLES-PREVIEW-->
- [Recent benchmark results](https://ref.rerun.io/dev/bench/)
- [Wasm size tracking](https://ref.rerun.io/dev/sizes/)
  • Loading branch information
abey79 authored Oct 6, 2023
1 parent 37fef60 commit 87b905b
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/contrib_rerun_py.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
# TODO(jleibs): understand why deps can't be installed in the same step as the wheel
shell: bash
run: |
pip install attrs>=23.1.0 numpy>=1.23 pillow pyarrow==10.0.1 pytest==7.1.2 typing_extensions>=4.5
pip install attrs>=23.1.0 numpy>=1.23 pillow pyarrow==10.0.1 pytest==7.1.2 torch==2.1.0 typing_extensions>=4.5
- name: Get version
id: get-version
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/reusable_build_and_test_wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ jobs:
# TODO(jleibs): understand why deps can't be installed in the same step as the wheel
shell: bash
run: |
pip install attrs>=23.1.0 numpy>=1.23 pillow pyarrow==10.0.1 pytest==7.1.2 typing_extensions>=4.5
pip install attrs>=23.1.0 numpy>=1.23 pillow pyarrow==10.0.1 pytest==7.1.2 torch==2.1.0 typing_extensions>=4.5
- name: Get version
id: get-version
Expand Down
1 change: 1 addition & 0 deletions rerun_py/requirements-lint.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ black==23.3.0
blackdoc==0.3.8
mypy==1.4.1
numpy>=1.24 # For mypy plugin
torch>=2.1.0
pip-check-reqs==2.4.4 # Checks for missing deps in requirements.txt files
pytest # For mypy to work
ruff==0.0.276
Expand Down
4 changes: 3 additions & 1 deletion rerun_py/rerun_sdk/rerun/archetypes/boxes2d_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ def __init__(
if centers is not None:
raise ValueError("Cannot specify both `array` and `centers` at the same time.")

if type(array) is not np.ndarray:
array = np.array(array)

if np.any(array):
array = np.asarray(array, dtype="float32")
if array.ndim == 1:
array = np.expand_dims(array, axis=0)
else:
Expand Down
7 changes: 7 additions & 0 deletions rerun_py/tests/unit/common_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any

import numpy as np
import torch
from rerun.components import (
ClassId,
ClassIdBatch,
Expand Down Expand Up @@ -83,6 +84,8 @@ def none_empty_or_value(obj: Any, value: Any) -> Any:
np.array([1, 2, 3, 4], dtype=np.float32),
# Vec2DArrayLike: npt.NDArray[np.float32]
np.array([1, 2, 3, 4], dtype=np.float32).reshape((2, 2, 1, 1, 1)),
# PyTorch array
torch.asarray([1, 2, 3, 4], dtype=torch.float32),
]


Expand Down Expand Up @@ -118,6 +121,8 @@ def vec2ds_expected(obj: Any, type_: Any | None = None) -> Any:
np.array([1, 2, 3, 4, 5, 6], dtype=np.float32),
# Vec3DArrayLike: npt.NDArray[np.float32]
np.array([1, 2, 3, 4, 5, 6], dtype=np.float32).reshape((2, 3, 1, 1, 1)),
# PyTorch array
torch.asarray([1, 2, 3, 4, 5, 6], dtype=torch.float32),
]


Expand Down Expand Up @@ -153,6 +158,8 @@ def vec3ds_expected(obj: Any, type_: Any | None = None) -> Any:
np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.float32),
# Vec4DArrayLike: npt.NDArray[np.float32]
np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.float32).reshape((2, 4, 1, 1, 1)),
# PyTorch array
torch.asarray([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float32),
]


Expand Down
16 changes: 16 additions & 0 deletions rerun_py/tests/unit/test_box2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
import itertools
from typing import Optional, cast

import numpy as np
import numpy.typing as npt
import pytest
import rerun as rr
import torch
from rerun.components import (
DrawOrderLike,
HalfSizes2DBatch,
Expand Down Expand Up @@ -137,6 +140,19 @@ def test_with_array_xcycw2h2() -> None:
assert rr.Boxes2D(mins=[1, 1], sizes=[2, 4]) == rr.Boxes2D(array=[2, 3, 1, 2], array_format=rr.Box2DFormat.XCYCW2H2)


@pytest.mark.parametrize(
"array",
[
[1, 2, 3, 4],
[1, 2, 3, 4],
np.array([1, 2, 3, 4], dtype=np.float32),
torch.asarray([1, 2, 3, 4], dtype=torch.float32),
],
)
def test_with_array_types(array: npt.ArrayLike) -> None:
assert rr.Boxes2D(mins=[1, 2], sizes=[3, 4]) == rr.Boxes2D(array=array, array_format=rr.Box2DFormat.XYWH)


def test_invalid_parameter_combinations() -> None:
rr.set_strict_mode(True)

Expand Down

0 comments on commit 87b905b

Please sign in to comment.