Skip to content

Commit d2a6df3

Browse files
committed
Simplify NDCube.to_nddata
1 parent 71984db commit d2a6df3

File tree

3 files changed

+72
-49
lines changed

3 files changed

+72
-49
lines changed

changelog/892.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`~ndcube.NDCube` now accepts ``global_coords=`` and ``extra_coords=`` in the constructor of the class.

ndcube/ndcube.py

Lines changed: 37 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@
4848
except ImportError:
4949
pass
5050

51-
COPY = object()
52-
5351

5452
class NDCubeABC(astropy.nddata.NDDataBase):
5553

@@ -381,25 +379,27 @@ class NDCubeBase(NDCubeABC, astropy.nddata.NDData, NDCubeSlicingMixin):
381379
_global_coords = NDCubeLinkedDescriptor(GlobalCoords)
382380

383381
def __init__(self, data, wcs=None, uncertainty=None, mask=None, meta=None,
384-
unit=None, copy=False, **kwargs):
382+
unit=None, copy=False, psf=None, *, extra_coords=None, global_coords=None, **kwargs):
385383

386384
super().__init__(data, wcs=wcs, uncertainty=uncertainty, mask=mask,
387-
meta=meta, unit=unit, copy=copy, **kwargs)
385+
meta=meta, unit=unit, copy=copy, psf=psf, **kwargs)
388386

389387
# Enforce that the WCS object is not None
390388
if self.wcs is None:
391389
raise TypeError("The WCS argument can not be None.")
392390

393391
# Get existing extra_coords if initializing from an NDCube
394-
if hasattr(data, "extra_coords"):
392+
if extra_coords is None and getattr(data, "extra_coords", None) is not None:
395393
extra_coords = data.extra_coords
394+
if extra_coords is not None:
396395
if copy:
397396
extra_coords = deepcopy(extra_coords)
398397
self._extra_coords = extra_coords
399398

400399
# Get existing global_coords if initializing from an NDCube
401-
if hasattr(data, "global_coords"):
400+
if global_coords is None and getattr(data, "global_coords", None) is not None:
402401
global_coords = data._global_coords
402+
if global_coords is not None:
403403
if copy:
404404
global_coords = deepcopy(global_coords)
405405
self._global_coords = global_coords
@@ -1465,24 +1465,28 @@ def fill_masked(self, fill_value, uncertainty_fill_value=None, unmask=False, fil
14651465

14661466
def to_nddata(self,
14671467
*,
1468-
data=COPY,
1469-
wcs=COPY,
1470-
uncertainty=COPY,
1471-
mask=COPY,
1472-
unit=COPY,
1473-
meta=COPY,
1474-
psf=COPY,
1475-
extra_coords=COPY,
1476-
global_coords=COPY,
1468+
data=True,
1469+
wcs=True,
1470+
uncertainty=True,
1471+
mask=True,
1472+
unit=True,
1473+
meta=True,
1474+
psf=True,
14771475
nddata_type=NDData,
14781476
**kwargs,
14791477
):
14801478
"""
1481-
Constructs new type instance with the same attribute values as this `~ndcube.NDCube`.
1482-
1483-
Attribute values can be altered on the output object by setting a kwarg with the new
1484-
value, e.g. ``data=new_data``.
1485-
Any attributes not supported by the new class (``nddata_type``), will be discarded.
1479+
Constructs a new `~astopy.nddata.NDData` instance from this object.
1480+
1481+
By default all known ``NDData`` attributes are copied (by reference) from
1482+
this object, values can be altered on the output object by
1483+
setting a kwarg with the new value, e.g. ``data=new_data``.
1484+
Custom attributes on this class can be passed by setting that
1485+
keyword to `True`, for example ``mycube.to_nddata(spam=True)``
1486+
is the equivalent of setting
1487+
``mycube.to_nddata(spam=mycube.spam)``.
1488+
Any attributes not supported by the new class
1489+
(``nddata_type``), will be discarded.
14861490
14871491
Parameters
14881492
----------
@@ -1500,16 +1504,11 @@ def to_nddata(self,
15001504
Metadata object of new instance. Default is to use data of this instance.
15011505
psf: Any, optional
15021506
PSF object of new instance. Default is to use data of this instance.
1503-
extra_coords: `ndcube.ExtraCoordsABC`, optional
1504-
Extra coords object of new instance. Default is to use data of this instance.
1505-
global_coords: `ndcube.GlobalCoordsABC`, optional
1506-
WCS object of new instance. Default is to use data of this instance.
15071507
nddata_type: Any, optional
15081508
The type of the returned object. Must be a subclass of `~astropy.nddata.NDData`
15091509
or a class that behaves like one. Default=`~astropy.nddata.NDData`.
15101510
kwargs:
1511-
Additional inputs to the ``nddata_type`` constructor that should differ from,
1512-
or are not represented by, the attributes of this instance. For example, to
1511+
Additional inputs to the ``nddata_type`` constructor. For example, to
15131512
set different data values on the returned object, set a kwarg ``data=new_data``,
15141513
where ``new_data`` is an array of compatible shape and dtype. Note that kwargs
15151514
given by the user and attributes on this instance that are not supported by the
@@ -1525,39 +1524,29 @@ def to_nddata(self,
15251524
Examples
15261525
--------
15271526
To create an `~astropy.nddata.NDData` instance which is a copy of an `~ndcube.NDCube`
1528-
(called ``cube``) without a WCS, do:
1527+
(called ``cube``) without a WCS, do::
15291528
15301529
>>> nddata_without_coords = cube.to_nddata(wcs=None) # doctest: +SKIP
1530+
1531+
To create a new `~ndcube.NDCube` instance which is a copy of
1532+
an `~ndcube.NDCube` (called ``cube``) without an uncertainty,
1533+
but with ``global_coords`` and ``extra_coords`` do::
1534+
1535+
>>> nddata_without_coords = cube.to_nddata(uncertainty=None, global_coords=True, extra_coords=True) # doctest: +SKIP
15311536
"""
1532-
# Build dictionary of new attribute values from this NDCube instance
1533-
# and update with user-defined kwargs. Remove any kwargs not set by user.
1537+
# Put all NDData kwargs in a dict
15341538
user_kwargs = {"data": data,
15351539
"wcs": wcs,
15361540
"uncertainty": uncertainty,
15371541
"mask": mask,
15381542
"unit": unit,
15391543
"meta": meta,
15401544
"psf": psf,
1541-
"extra_coords": extra_coords,
1542-
"global_coords": global_coords}
1543-
user_kwargs = {key: value for key, value in user_kwargs.items() if value is not COPY}
1544-
user_kwargs.update(kwargs)
1545-
all_kwargs = {key.strip("_"): value for key, value in self.__dict__.items()}
1546-
all_kwargs.update(user_kwargs)
1547-
# Inspect call signature of new_nddata class and
1548-
# remove unsupported items from new_kwargs.
1549-
all_kwargs = {key: value for key, value in all_kwargs.items()
1550-
if key in inspect.signature(nddata_type).parameters.keys()}
1545+
**kwargs}
1546+
# If any are True then copy by reference
1547+
user_kwargs = {key: getattr(self, key) if value is True else value for key, value in user_kwargs.items()}
15511548
# Construct and return new instance.
1552-
new_nddata = nddata_type(**all_kwargs)
1553-
if isinstance(new_nddata, NDCubeBase):
1554-
if extra_coords is COPY:
1555-
extra_coords = copy.copy(self._extra_coords)
1556-
extra_coords._ndcube = new_nddata
1557-
new_nddata._extra_coords = extra_coords
1558-
if global_coords is COPY:
1559-
new_nddata._global_coords = copy.copy(self._global_coords)
1560-
return new_nddata
1549+
return nddata_type(**user_kwargs)
15611550

15621551

15631552
def _create_masked_array_for_rebinning(data, mask, operation_ignores_mask):

ndcube/tests/test_ndcube.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,24 @@ def test_initialize_from_ndcube(ndcube_3d_l_ln_lt_ectime):
4444
assert ec is not ec3
4545

4646

47+
def test_initialize_with_extra_global_coords(ndcube_3d_ln_lt_l_ec_all_axes):
48+
ndc = ndcube_3d_ln_lt_l_ec_all_axes[:, :, 0]
49+
data = ndc.data
50+
wcs = ndc.wcs
51+
ec = ndc.extra_coords
52+
gc = ndc.global_coords
53+
54+
new_cube = NDCube(data, wcs=wcs, extra_coords=ec, global_coords=gc)
55+
assert new_cube.extra_coords is ec
56+
assert new_cube.global_coords is gc
57+
58+
new_cube_copy = NDCube(data, wcs=wcs, extra_coords=ec, global_coords=gc, copy=True)
59+
assert new_cube_copy.extra_coords is not ec
60+
assert new_cube_copy.global_coords is not gc
61+
helpers.assert_extra_coords_equal(new_cube_copy.extra_coords, ec)
62+
helpers.assert_global_coords_equal(new_cube_copy.global_coords, gc)
63+
64+
4765
def test_wcs_type_after_init(ndcube_3d_ln_lt_l, wcs_3d_l_lt_ln):
4866
# Generate a low level WCS
4967
slices = np.s_[:, :, 0]
@@ -254,8 +272,23 @@ def test_to_nddata_type_ndcube(ndcube_2d_ln_lt_uncert_ec):
254272
ndc = ndcube_2d_ln_lt_uncert_ec
255273
ndc.global_coords.add("wavelength", "em.wl", 100*u.nm)
256274
new_data = ndc.data * 2
257-
output = ndc.to_nddata(data=new_data, nddata_type=NDCube)
275+
output = ndc.to_nddata(data=new_data, extra_coords=True, global_coords=True, nddata_type=NDCube)
258276
assert type(output) is NDCube
259277
assert (output.data == new_data).all()
260278
helpers.assert_extra_coords_equal(output.extra_coords, ndc.extra_coords)
261279
helpers.assert_global_coords_equal(output.global_coords, ndc.global_coords)
280+
281+
282+
def test_custom_tonddata_type(ndcube_2d_ln_lt):
283+
ndc = ndcube_2d_ln_lt
284+
ndc.spam = "Eggs"
285+
286+
class MyNDData(astropy.nddata.NDData):
287+
def __init__(self, data, *, spam=None, **kwargs):
288+
super().__init__(data, **kwargs)
289+
self.spam = spam
290+
291+
new_ndd = ndc.to_nddata(spam=True, nddata_type=MyNDData)
292+
assert new_ndd.spam == "Eggs"
293+
assert new_ndd.data is ndc.data
294+
assert new_ndd.wcs is ndc.wcs

0 commit comments

Comments
 (0)