Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/863.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Enable length-1 inputs to `ndcube.NDCube.crop`, not only scalars.
10 changes: 10 additions & 0 deletions ndcube/tests/test_ndcube_slice_and_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,16 @@ def test_crop_tuple_non_tuple_input(ndcube_2d_ln_lt):
helpers.assert_cubes_equal(cropped_by_tuples, cropped_by_coords)


def test_crop_length_1_input(ndcube_2d_ln_lt):
cube = ndcube_2d_ln_lt
frame = astropy.wcs.utils.wcs_to_celestial_frame(cube.wcs)
lower_corner = SkyCoord(Tx=[0359.99667], Ty=[-0.0011111111], unit="deg", frame=frame)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are these coordinates so specific?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied them from another test. Can't remember why that test had them so specific

upper_corner = SkyCoord(Tx=[[0.0044444444]], Ty=[[0.0011111111]], unit="deg", frame=frame)
cropped_by_shaped = cube.crop(lower_corner, upper_corner)
cropped_by_scalars = cube.crop((lower_corner.squeeze(),), (upper_corner.squeeze(),))
helpers.assert_cubes_equal(cropped_by_shaped, cropped_by_scalars)


def test_crop_with_nones(ndcube_4d_ln_lt_l_t):
cube = ndcube_4d_ln_lt_l_t
lower_corner = [None] * 3
Expand Down
16 changes: 13 additions & 3 deletions ndcube/utils/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import numpy as np

import astropy.nddata
from astropy.wcs.wcsapi import BaseHighLevelWCS, HighLevelWCSWrapper, SlicedLowLevelWCS
import astropy.units as u
from astropy.coordinates import SkyCoord, SpectralCoord
from astropy.time import Time
from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS, HighLevelWCSWrapper, SlicedLowLevelWCS

from ndcube.utils import wcs as wcs_utils

Expand Down Expand Up @@ -138,8 +141,16 @@
# Define a list of lists to hold the array indices of the points
# where each inner list gives the index of all points for that array axis.
combined_points_array_idx = [[]] * wcs.pixel_n_dim
high_level_wcs = HighLevelWCSWrapper(wcs) if isinstance(wcs, BaseLowLevelWCS) else wcs
wcs = high_level_wcs.low_level_wcs
# For each point compute the corresponding array indices.
for point in points:
# Sanitize input format
# Make point a tuple if given as a single high level coord object valid for this WCS.
if isinstance(point, tuple(v[0] for v in wcs.world_axis_object_classes.values())):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

activates APE-14 lawyer mode

This is technically invalid if wcs.seralized_classes is True, as the first element of world_axis_object_classes will be a string. I've never actually seen a WCS in the wild which uses seralized_classes so we could just throw an error here if it's True?

Copy link
Member Author

@DanRyanIrish DanRyanIrish Jun 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would simply not supporting this sanitization this check when wcs.seralized_classes is True be sufficient? i.e.

Suggested change
if isinstance(point, tuple(v[0] for v in wcs.world_axis_object_classes.values())):
if not isinstance(point, (tuple, list)) and not wcs.seralized_classes and isinstance(point, tuple(v[0] for v in wcs.world_axis_object_classes.values())):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(point, tuple(v[0] for v in wcs.world_axis_object_classes.values())):
if not isinstance(point, tuple) and isinstance(point, tuple(v[0] for v in wcs.world_axis_object_classes.values())):

point = (point,)

Check warning on line 151 in ndcube/utils/cube.py

View check run for this annotation

Codecov / codecov/patch

ndcube/utils/cube.py#L151

Added line #L151 was not covered by tests
# If point is a length-1 object, convert it to scalar.
point = tuple(p.squeeze() if hasattr(p, "squeeze") else p for p in point)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unrelated to the tuple thing right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Here, we can assume that point is a tuple or list. So this line is about converting length-1 objects to scalar ones.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little lost why we need to do this, can't we handle multi-dimensional inputs?

# Get the arrays axes associated with each element in point.
if crop_by_values:
point_inputs_array_axes = []
Expand All @@ -150,8 +161,7 @@
wcs_utils.convert_between_array_and_pixel_axes(pix_axes, wcs.pixel_n_dim)))
point_inputs_array_axes = tuple(point_inputs_array_axes)
else:
point_inputs_array_axes = wcs_utils.array_indices_for_world_objects(
HighLevelWCSWrapper(wcs))
point_inputs_array_axes = wcs_utils.array_indices_for_world_objects(high_level_wcs)
# Get indices of array axes which correspond to only None inputs in point
# as well as those that correspond to a coord.
point_indices_with_inputs = []
Expand Down
Loading