Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/892.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`~ndcube.NDCube` now accepts ``global_coords=`` and ``extra_coords=`` in the constructor of the class.
94 changes: 43 additions & 51 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@
except ImportError:
pass

COPY = object()


class NDCubeABC(astropy.nddata.NDDataBase):

Expand Down Expand Up @@ -381,25 +379,27 @@ class NDCubeBase(NDCubeABC, astropy.nddata.NDData, NDCubeSlicingMixin):
_global_coords = NDCubeLinkedDescriptor(GlobalCoords)

def __init__(self, data, wcs=None, uncertainty=None, mask=None, meta=None,
unit=None, copy=False, **kwargs):
unit=None, copy=False, psf=None, *, extra_coords=None, global_coords=None, **kwargs):

super().__init__(data, wcs=wcs, uncertainty=uncertainty, mask=mask,
meta=meta, unit=unit, copy=copy, **kwargs)
meta=meta, unit=unit, copy=copy, psf=psf, **kwargs)

# Enforce that the WCS object is not None
if self.wcs is None:
raise TypeError("The WCS argument can not be None.")

# Get existing extra_coords if initializing from an NDCube
if hasattr(data, "extra_coords"):
if extra_coords is None and getattr(data, "extra_coords", None) is not None:
extra_coords = data.extra_coords
if extra_coords is not None:
if copy:
extra_coords = deepcopy(extra_coords)
self._extra_coords = extra_coords

# Get existing global_coords if initializing from an NDCube
if hasattr(data, "global_coords"):
if global_coords is None and getattr(data, "global_coords", None) is not None:
global_coords = data._global_coords
if global_coords is not None:
if copy:
global_coords = deepcopy(global_coords)
self._global_coords = global_coords
Expand Down Expand Up @@ -1465,24 +1465,28 @@ def fill_masked(self, fill_value, uncertainty_fill_value=None, unmask=False, fil

def to_nddata(self,
*,
data=COPY,
wcs=COPY,
uncertainty=COPY,
mask=COPY,
unit=COPY,
meta=COPY,
psf=COPY,
extra_coords=COPY,
global_coords=COPY,
data="copy",
wcs="copy",
uncertainty="copy",
mask="copy",
unit="copy",
meta="copy",
psf="copy",
nddata_type=NDData,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only remaining question is if we want to add a copy=False kwarg here which changes the copy behaviour to be copy by value rather that the default of copy-by-reference?

**kwargs,
):
"""
Constructs new type instance with the same attribute values as this `~ndcube.NDCube`.

Attribute values can be altered on the output object by setting a kwarg with the new
value, e.g. ``data=new_data``.
Any attributes not supported by the new class (``nddata_type``), will be discarded.
Constructs a new `~astropy.nddata.NDData` instance from this object.

By default all known ``NDData`` attributes are copied (by reference) from
this object, values can be altered on the output object by
setting a kwarg with the new value, e.g. ``data=new_data``.
Custom attributes on this class can be passed by setting that
keyword to `"copy"`, for example ``mycube.to_nddata(spam="copy")``
is the equivalent of setting
``mycube.to_nddata(spam=mycube.spam)``.
Any attributes not supported by the new class
(``nddata_type``), will be discarded.

Parameters
----------
Expand All @@ -1500,20 +1504,16 @@ def to_nddata(self,
Metadata object of new instance. Default is to use data of this instance.
psf: Any, optional
PSF object of new instance. Default is to use data of this instance.
extra_coords: `ndcube.ExtraCoordsABC`, optional
Extra coords object of new instance. Default is to use data of this instance.
global_coords: `ndcube.GlobalCoordsABC`, optional
WCS object of new instance. Default is to use data of this instance.
nddata_type: Any, optional
The type of the returned object. Must be a subclass of `~astropy.nddata.NDData`
or a class that behaves like one. Default=`~astropy.nddata.NDData`.
kwargs:
Additional inputs to the ``nddata_type`` constructor that should differ from,
or are not represented by, the attributes of this instance. For example, to
Additional inputs to the ``nddata_type`` constructor. For example, to
set different data values on the returned object, set a kwarg ``data=new_data``,
where ``new_data`` is an array of compatible shape and dtype. Note that kwargs
given by the user and attributes on this instance that are not supported by the
``nddata_type`` constructor are ignored.
where ``new_data`` is an array of compatible shape and dtype.
Other keyword arguments can be specified to copy custom
attributes with the value ``"copy"``, for example
``global_coords="copy"``.

Returns
-------
Expand All @@ -1525,39 +1525,31 @@ def to_nddata(self,
Examples
--------
To create an `~astropy.nddata.NDData` instance which is a copy of an `~ndcube.NDCube`
(called ``cube``) without a WCS, do:
(called ``cube``) without a WCS, do::

>>> nddata_without_coords = cube.to_nddata(wcs=None) # doctest: +SKIP

To create a new `~ndcube.NDCube` instance which is a copy of
an `~ndcube.NDCube` (called ``cube``) without an uncertainty,
but with ``global_coords`` and ``extra_coords`` do::

>>> nddata_without_coords = cube.to_nddata(uncertainty=None, global_coords=True, extra_coords=True) # doctest: +SKIP
"""
# Build dictionary of new attribute values from this NDCube instance
# and update with user-defined kwargs. Remove any kwargs not set by user.
# Put all NDData kwargs in a dict
user_kwargs = {"data": data,
"wcs": wcs,
"uncertainty": uncertainty,
"mask": mask,
"unit": unit,
"meta": meta,
"psf": psf,
"extra_coords": extra_coords,
"global_coords": global_coords}
user_kwargs = {key: value for key, value in user_kwargs.items() if value is not COPY}
user_kwargs.update(kwargs)
all_kwargs = {key.strip("_"): value for key, value in self.__dict__.items()}
all_kwargs.update(user_kwargs)
# Inspect call signature of new_nddata class and
# remove unsupported items from new_kwargs.
all_kwargs = {key: value for key, value in all_kwargs.items()
if key in inspect.signature(nddata_type).parameters.keys()}
**kwargs}
# If any are "copy" then copy by reference
user_kwargs = {key: getattr(self, key)
if isinstance(value, str) and value == "copy" else value
for key, value in user_kwargs.items()}
# Construct and return new instance.
new_nddata = nddata_type(**all_kwargs)
if isinstance(new_nddata, NDCubeBase):
if extra_coords is COPY:
extra_coords = copy.copy(self._extra_coords)
extra_coords._ndcube = new_nddata
new_nddata._extra_coords = extra_coords
if global_coords is COPY:
new_nddata._global_coords = copy.copy(self._global_coords)
return new_nddata
return nddata_type(**user_kwargs)


def _create_masked_array_for_rebinning(data, mask, operation_ignores_mask):
Expand Down
35 changes: 34 additions & 1 deletion ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,24 @@ def test_initialize_from_ndcube(ndcube_3d_l_ln_lt_ectime):
assert ec is not ec3


def test_initialize_with_extra_global_coords(ndcube_3d_ln_lt_l_ec_all_axes):
ndc = ndcube_3d_ln_lt_l_ec_all_axes[:, :, 0]
data = ndc.data
wcs = ndc.wcs
ec = ndc.extra_coords
gc = ndc.global_coords

new_cube = NDCube(data, wcs=wcs, extra_coords=ec, global_coords=gc)
assert new_cube.extra_coords is ec
assert new_cube.global_coords is gc

new_cube_copy = NDCube(data, wcs=wcs, extra_coords=ec, global_coords=gc, copy=True)
assert new_cube_copy.extra_coords is not ec
assert new_cube_copy.global_coords is not gc
helpers.assert_extra_coords_equal(new_cube_copy.extra_coords, ec)
helpers.assert_global_coords_equal(new_cube_copy.global_coords, gc)


def test_wcs_type_after_init(ndcube_3d_ln_lt_l, wcs_3d_l_lt_ln):
# Generate a low level WCS
slices = np.s_[:, :, 0]
Expand Down Expand Up @@ -254,8 +272,23 @@ def test_to_nddata_type_ndcube(ndcube_2d_ln_lt_uncert_ec):
ndc = ndcube_2d_ln_lt_uncert_ec
ndc.global_coords.add("wavelength", "em.wl", 100*u.nm)
new_data = ndc.data * 2
output = ndc.to_nddata(data=new_data, nddata_type=NDCube)
output = ndc.to_nddata(data=new_data, extra_coords="copy", global_coords="copy", nddata_type=NDCube)
assert type(output) is NDCube
assert (output.data == new_data).all()
helpers.assert_extra_coords_equal(output.extra_coords, ndc.extra_coords)
helpers.assert_global_coords_equal(output.global_coords, ndc.global_coords)


def test_custom_tonddata_type(ndcube_2d_ln_lt):
ndc = ndcube_2d_ln_lt
ndc.spam = "Eggs"

class MyNDData(astropy.nddata.NDData):
def __init__(self, data, *, spam=None, **kwargs):
super().__init__(data, **kwargs)
self.spam = spam

new_ndd = ndc.to_nddata(spam="copy", nddata_type=MyNDData)
assert new_ndd.spam == "Eggs"
assert new_ndd.data is ndc.data
assert new_ndd.wcs is ndc.wcs
Loading