Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/863.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Enable length-1 inputs to `ndcube.NDCube.crop`, not only scalars.
6 changes: 5 additions & 1 deletion ndcube/tests/test_ndcube_reproject_and_rebin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
import dask.array
import numpy as np
import pytest
from specutils import Spectrum

try:
from specutils import Spectrum
except ImportError:
from specutils import Spectrum1D as Spectrum

import astropy.units as u
import astropy.wcs
Expand Down
10 changes: 10 additions & 0 deletions ndcube/tests/test_ndcube_slice_and_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,16 @@ def test_crop_tuple_non_tuple_input(ndcube_2d_ln_lt):
helpers.assert_cubes_equal(cropped_by_tuples, cropped_by_coords)


def test_crop_length_1_input(ndcube_2d_ln_lt):
cube = ndcube_2d_ln_lt
frame = astropy.wcs.utils.wcs_to_celestial_frame(cube.wcs)
lower_corner = SkyCoord(Tx=[0359.99667], Ty=[-0.0011111111], unit="deg", frame=frame)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are these coordinates so specific?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied them from another test. Can't remember why that test had them so specific

upper_corner = SkyCoord(Tx=[[0.0044444444]], Ty=[[0.0011111111]], unit="deg", frame=frame)
cropped_by_shaped = cube.crop(lower_corner, upper_corner)
cropped_by_scalars = cube.crop((lower_corner.squeeze(),), (upper_corner.squeeze(),))
helpers.assert_cubes_equal(cropped_by_shaped, cropped_by_scalars)


def test_crop_with_nones(ndcube_4d_ln_lt_l_t):
cube = ndcube_4d_ln_lt_l_t
lower_corner = [None] * 3
Expand Down
9 changes: 6 additions & 3 deletions ndcube/utils/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

import astropy.nddata
from astropy.wcs.wcsapi import BaseHighLevelWCS, HighLevelWCSWrapper, SlicedLowLevelWCS
from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS, HighLevelWCSWrapper, SlicedLowLevelWCS

from ndcube.utils import wcs as wcs_utils

Expand Down Expand Up @@ -81,6 +81,8 @@ def sanitize_crop_inputs(points, wcs):
# Confirm whether point contains at least one None entry.
if all(coord is None for coord in points[i]):
values_are_none[i] = True
# Squeeze length-1 coordinate objects to scalars.
points[i] = [coord.squeeze() if hasattr(coord, "squeeze") else coord for coord in points[i]]
# If no points contain a coord, i.e. if all entries in all points are None,
# set no-op flag to True and exit.
if all(values_are_none):
Expand Down Expand Up @@ -138,6 +140,8 @@ def get_crop_item_from_points(points, wcs, crop_by_values, keepdims):
# Define a list of lists to hold the array indices of the points
# where each inner list gives the index of all points for that array axis.
combined_points_array_idx = [[]] * wcs.pixel_n_dim
high_level_wcs = HighLevelWCSWrapper(wcs) if isinstance(wcs, BaseLowLevelWCS) else wcs
wcs = high_level_wcs.low_level_wcs
# For each point compute the corresponding array indices.
for point in points:
# Get the arrays axes associated with each element in point.
Expand All @@ -150,8 +154,7 @@ def get_crop_item_from_points(points, wcs, crop_by_values, keepdims):
wcs_utils.convert_between_array_and_pixel_axes(pix_axes, wcs.pixel_n_dim)))
point_inputs_array_axes = tuple(point_inputs_array_axes)
else:
point_inputs_array_axes = wcs_utils.array_indices_for_world_objects(
HighLevelWCSWrapper(wcs))
point_inputs_array_axes = wcs_utils.array_indices_for_world_objects(high_level_wcs)
# Get indices of array axes which correspond to only None inputs in point
# as well as those that correspond to a coord.
point_indices_with_inputs = []
Expand Down