Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow e2c e2p c2e to operate on 2D arrays. #29

Merged
merged 2 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions py360convert/c2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,23 +82,48 @@ def c2e(
if cube_format == "horizon":
if not isinstance(cubemap, np.ndarray):
raise TypeError('cubemap must be a numpy array for cube_format="horizon"')
if cubemap.ndim == 2:
cubemap = cubemap[..., None]
squeeze = True
else:
squeeze = False
elif cube_format == "list":
if not isinstance(cubemap, list):
raise TypeError('cubemap must be a list for cube_format="list"')
if len({x.shape for x in cubemap}) != 1:
raise ValueError("All cubemap elements must have same shape")
if cubemap[0].ndim == 2:
cubemap = [x[..., None] for x in cubemap]
squeeze = True
else:
squeeze = False
cubemap = cube_list2h(cubemap)
elif cube_format == "dict":
if not isinstance(cubemap, dict):
raise TypeError('cubemap must be a dict for cube_format="dict"')
if len({x.shape for x in cubemap.values()}) != 1:
raise ValueError("All cubemap elements must have same shape")
if cubemap["F"].ndim == 2:
cubemap = {k: v[..., None] for k, v in cubemap.items()}
squeeze = True
else:
squeeze = False
cubemap = cube_dict2h(cubemap)
elif cube_format == "dice":
if not isinstance(cubemap, np.ndarray):
raise TypeError('cubemap must be a numpy array for cube_format="dice"')
if cubemap.ndim == 2:
cubemap = cubemap[..., None]
squeeze = True
else:
squeeze = False
cubemap = cube_dice2h(cubemap)
else:
raise ValueError('Unknown cube_format "{cube_format}".')

if cubemap.ndim != 3:
raise ValueError(f"Cubemap must have 3 dimensions; got {cubemap.ndim}.")
raise ValueError(f"Cubemap must have 2 or 3 dimensions; got {cubemap.ndim}.")

if cubemap.shape[0] * 6 != cubemap.shape[1]:
raise ValueError("Cubemap's width must by 6x its height.")
if w % 8 != 0:
Expand Down Expand Up @@ -143,4 +168,4 @@ def c2e(
axis=-1,
)

return equirec
return equirec[..., 0] if squeeze else equirec
21 changes: 17 additions & 4 deletions py360convert/e2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def e2c(
Parameters
----------
e_img: ndarray
Equirectangular image in shape of [H, W, *].
Equirectangular image in shape of [H,W] or [H, W, *].
face_w: int
Length of each face of the cubemap
mode: Literal["bilinear", "nearest"]
Expand All @@ -68,8 +68,14 @@ def e2c(
Union[NDArray, list[NDArray], dict[str, NDArray]]
Cubemap in format specified by `cube_format`.
"""
if e_img.ndim != 3:
raise ValueError("e_img must have 3 dimensions.")
if e_img.ndim not in (2, 3):
raise ValueError("e_img must have 2 or 3 dimensions.")
if e_img.ndim == 2:
e_img = e_img[..., None]
squeeze = True
else:
squeeze = False

h, w = e_img.shape[:2]
if mode == "bilinear":
order = 1
Expand All @@ -89,13 +95,20 @@ def e2c(
)

if cube_format == "horizon":
pass
if squeeze:
cubemap = cubemap[..., 0]
elif cube_format == "list":
cubemap = cube_h2list(cubemap)
if squeeze:
cubemap = [x[..., 0] for x in cubemap]
elif cube_format == "dict":
cubemap = cube_h2dict(cubemap)
if squeeze:
cubemap = {k: v[..., 0] for k, v in cubemap.items()}
elif cube_format == "dice":
cubemap = cube_h2dice(cubemap)
if squeeze:
cubemap = cubemap[..., 0]
else:
raise NotImplementedError

Expand Down
14 changes: 10 additions & 4 deletions py360convert/e2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def e2p(
Parameters
----------
e_img: ndarray
Equirectangular image in shape of [H, W, *].
Equirectangular image in shape of [H,W] or [H, W, *].
fov_deg: scalar or (scalar, scalar) field of view in degree
Field of view given in float or tuple (h_fov_deg, v_fov_deg).
u_deg: horizon viewing angle in range [-180, 180]
Expand All @@ -47,8 +47,14 @@ def e2p(
np.ndarray
Perspective image.
"""
if e_img.ndim != 3:
raise ValueError("e_img must have 3 dimensions.")
if e_img.ndim not in (2, 3):
raise ValueError("e_img must have 2 or 3 dimensions.")
if e_img.ndim == 2:
e_img = e_img[..., None]
squeeze = True
else:
squeeze = False

h, w = e_img.shape[:2]

if isinstance(fov_deg, Real):
Expand All @@ -73,4 +79,4 @@ def e2p(

pers_img = np.stack([sample_equirec(e_img[..., i], coor_xy, order=order) for i in range(e_img.shape[2])], axis=-1)

return pers_img
return pers_img[..., 0] if squeeze else pers_img
25 changes: 25 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,31 @@ def dice_image():
return np.array(Image.open("assets/demo_cube.png"))


@pytest.fixture
def dice_image_mono(dice_image):
return dice_image.mean(axis=-1)


@pytest.fixture
def equirec_image():
return np.array(Image.open("assets/demo_equirec.png"))


@pytest.fixture
def equirec_image_mono(equirec_image):
return equirec_image.mean(axis=-1)


@pytest.fixture
def horizon_image():
return np.array(Image.open("assets/demo_horizon.png"))


@pytest.fixture
def horizon_image_mono(horizon_image):
return horizon_image.mean(axis=-1)


@pytest.fixture
def list_image(horizon_image):
return [
Expand All @@ -30,6 +45,11 @@ def list_image(horizon_image):
]


@pytest.fixture
def list_image_mono(list_image):
return [x.mean(axis=-1) for x in list_image]


@pytest.fixture
def dict_image(list_image):
return {
Expand All @@ -40,3 +60,8 @@ def dict_image(list_image):
"U": list_image[4],
"D": list_image[5],
}


@pytest.fixture
def dict_image_mono(dict_image):
return {k: v.mean(axis=-1) for k, v in dict_image.items()}
55 changes: 55 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


def diff(x, y):
assert x.shape == y.shape
return np.abs(x.astype(float) - y.astype(float))


Expand All @@ -16,12 +17,24 @@ def test_c2e_dice(equirec_image, dice_image):
assert equirec_diff.mean() < AVG_DIFF_THRESH


def test_c2e_dice_2d(equirec_image_mono, dice_image_mono):
equirec_actual = py360convert.c2e(dice_image_mono, 512, 1024)
equirec_diff = diff(equirec_image_mono, equirec_actual)
assert equirec_diff.mean() < AVG_DIFF_THRESH


def test_c2e_horizon(equirec_image, horizon_image):
equirec_actual = py360convert.c2e(horizon_image, 512, 1024, cube_format="horizon")
equirec_diff = diff(equirec_image, equirec_actual)
assert equirec_diff.mean() < AVG_DIFF_THRESH


def test_c2e_horizon_2d(equirec_image_mono, horizon_image_mono):
equirec_actual = py360convert.c2e(horizon_image_mono, 512, 1024, cube_format="horizon")
equirec_diff = diff(equirec_image_mono, equirec_actual)
assert equirec_diff.mean() < AVG_DIFF_THRESH


def test_c2e_dict(equirec_image, dict_image):
equirec_actual = py360convert.c2e(dict_image, 512, 1024, cube_format="dict")
equirec_diff = diff(equirec_image, equirec_actual)
Expand All @@ -34,18 +47,36 @@ def test_c2e_list(equirec_image, list_image):
assert equirec_diff.mean() < AVG_DIFF_THRESH


def test_c2e_list_mono(equirec_image_mono, list_image_mono):
equirec_actual = py360convert.c2e(list_image_mono, 512, 1024, cube_format="list")
equirec_diff = diff(equirec_image_mono, equirec_actual)
assert equirec_diff.mean() < AVG_DIFF_THRESH


def test_e2c_dice(equirec_image, dice_image):
dice_actual = py360convert.e2c(equirec_image, cube_format="dice")
dice_diff = diff(dice_image, dice_actual)
assert dice_diff.mean() < AVG_DIFF_THRESH


def test_e2c_dice_mono(equirec_image_mono, dice_image_mono):
dice_actual = py360convert.e2c(equirec_image_mono, cube_format="dice")
dice_diff = diff(dice_image_mono, dice_actual)
assert dice_diff.mean() < AVG_DIFF_THRESH


def test_e2c_horizon(equirec_image, horizon_image):
horizon_actual = py360convert.e2c(equirec_image, 256, cube_format="horizon")
horizon_diff = diff(horizon_image, horizon_actual)
assert horizon_diff.mean() < AVG_DIFF_THRESH


def test_e2c_horizon_mono(equirec_image_mono, horizon_image_mono):
horizon_actual = py360convert.e2c(equirec_image_mono, 256, cube_format="horizon")
horizon_diff = diff(horizon_image_mono, horizon_actual)
assert horizon_diff.mean() < AVG_DIFF_THRESH


def test_e2c_dict(equirec_image, dict_image):
dict_actual = py360convert.e2c(equirec_image, 256, cube_format="dict")
dict_diff = {k: diff(dict_image[k], dict_actual[k]) for k in "FRBLUD"}
Expand All @@ -58,6 +89,18 @@ def test_e2c_dict(equirec_image, dict_image):
assert dict_diff["D"].mean() < AVG_DIFF_THRESH


def test_e2c_dict_mono(equirec_image_mono, dict_image_mono):
dict_actual = py360convert.e2c(equirec_image_mono, 256, cube_format="dict")
dict_diff = {k: diff(dict_image_mono[k], dict_actual[k]) for k in "FRBLUD"}

assert dict_diff["F"].mean() < AVG_DIFF_THRESH
assert dict_diff["R"].mean() < AVG_DIFF_THRESH
assert dict_diff["B"].mean() < AVG_DIFF_THRESH
assert dict_diff["L"].mean() < AVG_DIFF_THRESH
assert dict_diff["U"].mean() < AVG_DIFF_THRESH
assert dict_diff["D"].mean() < AVG_DIFF_THRESH


def test_e2c_list(equirec_image, list_image):
list_actual = py360convert.e2c(equirec_image, 256, cube_format="list")
list_diff = [diff(list_image[i], list_actual[i]) for i in range(6)]
Expand All @@ -68,3 +111,15 @@ def test_e2c_list(equirec_image, list_image):
assert list_diff[3].mean() < AVG_DIFF_THRESH
assert list_diff[4].mean() < AVG_DIFF_THRESH
assert list_diff[5].mean() < AVG_DIFF_THRESH


def test_e2c_list_mono(equirec_image_mono, list_image_mono):
list_actual = py360convert.e2c(equirec_image_mono, 256, cube_format="list")
list_diff = [diff(list_image_mono[i], list_actual[i]) for i in range(6)]

assert list_diff[0].mean() < AVG_DIFF_THRESH
assert list_diff[1].mean() < AVG_DIFF_THRESH
assert list_diff[2].mean() < AVG_DIFF_THRESH
assert list_diff[3].mean() < AVG_DIFF_THRESH
assert list_diff[4].mean() < AVG_DIFF_THRESH
assert list_diff[5].mean() < AVG_DIFF_THRESH
Loading