diff --git a/lib/iris/experimental/xarray_dataset_wrapper.py b/lib/iris/experimental/xarray_dataset_wrapper.py
index 455057afad..c453adfcff 100644
--- a/lib/iris/experimental/xarray_dataset_wrapper.py
+++ b/lib/iris/experimental/xarray_dataset_wrapper.py
@@ -14,9 +14,11 @@
 However, this is a convenient place to test, for now.
 
 """
 from collections import OrderedDict
+from collections.abc import Iterable
 from typing import Optional
 
+import dask.array as da
 import netCDF4 as nc
 import numpy as np
 import xarray
@@ -177,8 +179,74 @@ def size(self):
     #
     # writing
     #
+
+    # This identifies a VariableMimic as an array-like object to which LAZY
+    # data (i.e. a dask array) can be written.
+    _ACCEPTS_LAZY_DATA_WRITES = True
+    # Iris netcdf saving can recognise this, and will "save" the lazy content
+    # into a VariableMimic, instead of streaming real data into it as it
+    # would for a normal file variable.
+    # Xarray itself understands dask content, and will stream it into a file.
+
     def __setitem__(self, keys, data):
-        self._xr[keys] = data
+        # Note: the assigned 'data' may be real or lazy.
+        # Lazy input can *stay* lazy, in at least some cases, since xarray
+        # can handle lazy content.
+        if self._indexes_address_full_shape(keys):
+            # This assignment replaces the entire array content, so we can
+            # just write it into the xarray variable.data : this preserves
+            # laziness, since xarray supports lazy data content.
+            self._xr.data = data
+        else:
+            if hasattr(self._xr.data, "compute"):
+                # Attempted *partial* overwrite of lazy content.
+                # This can't be done without losing the laziness (or, at
+                # least, it is not implemented at present).
+                # Fetch the real array content, so we can overwrite part of
+                # it.
+                self._xr.data = self._xr.data.compute()
+            if hasattr(data, "compute"):
+                # The assigned value is lazy : fetch real data.
+                data = data.compute()
+            # Rewrite the selected portion of the array content.
+            self._xr[keys] = data
+
+    def _indexes_address_full_shape(self, keys):
+        """Check whether a set of indexing keys addresses an entire array."""
+        shape = self.shape
+        if not isinstance(keys, Iterable):
+            keys = (keys,)
+
+        # First check there aren't more keys than dims.
+        # N.B. this means we don't support np.newaxis / None.
+        # N.B. *fewer* keys is ok, as the remainder is then equivalent to a
+        # trailing ", ...".
+        ok = len(keys) <= len(shape)
+
+        # Strip trailing Ellipsis keys, which have no effect here.
+        # (Non-final ones are rejected in the loop below.)
+        if ok:
+            while keys and keys[-1] is Ellipsis:
+                keys = keys[:-1]
+
+        # What remains can *only* be a slice(a, b) with suitable a and b,
+        # *or* 0 for a length-1 dim (nothing else being "full").
+        if ok:
+            for key, dimlen in zip(keys, shape):
+                if key == 0:
+                    ok = dimlen == 1
+                elif isinstance(key, slice):
+                    start, stop, step = key.start, key.stop, key.step
+                    if step not in (1, None):
+                        ok = False
+                    elif start not in (0, None):
+                        ok = False
+                    elif stop is not None and stop < dimlen:
+                        ok = False
+                else:
+                    # We don't support *any* other key types,
+                    # e.g. None / newaxis / non-final Ellipsis.
+                    ok = False
+                if not ok:
+                    break
+        return ok
 
 
 class DatasetMimic(_Nc4AttrsMimic):
@@ -364,7 +432,10 @@ def createVariable(
         attrs = {}
         dt_code = f"{datatype.kind}{datatype.itemsize}"
         use_fill = nc.default_fillvals[dt_code]
-        data = np.full(shape, fill_value=use_fill, dtype=datatype)
+        # Use a DASK array : this is perfectly reasonable, and won't chew up
+        # a lot of space.  The content is usually discarded anyway, when the
+        # variable data is assigned afterwards.
+        data = da.full(shape, fill_value=use_fill, dtype=datatype)
         xr_var = xr.Variable(dims=dimensions, data=data, attrs=attrs)
 
         original_varname = varname
diff --git a/lib/iris/fileformats/netcdf.py b/lib/iris/fileformats/netcdf.py
index 4efed43db9..73dd857ded 100644
--- a/lib/iris/fileformats/netcdf.py
+++ b/lib/iris/fileformats/netcdf.py
@@ -3017,12 +3017,27 @@ def _lazy_stream_data(data, fill_value, fill_warn, cf_var):
     if is_lazy_data(data):
 
-        def store(data, cf_var, fill_value):
-            # Store lazy data and check whether it is masked and contains
-            # the fill value
-            target = _FillValueMaskCheckAndStoreTarget(cf_var, fill_value)
-            da.store([data], [target])
-            return target.is_masked, target.contains_value
+        if getattr(cf_var, "_ACCEPTS_LAZY_DATA_WRITES", False):
+            # A special case : we can write a lazy array directly into the
+            # output target.
+            # In this case, instead of performing an actual streaming write,
+            # we simply assign a lazy array into the variable.
+            # CAVEAT: this does not actually write the file, nor perform the
+            # expected fill-value check ...
+            def store(data, cf_var, fill_value):
+                cf_var[:] = data
+                return False, False
+
+        else:
+
+            def store(data, cf_var, fill_value):
+                # Store lazy data and check whether it is masked and
+                # contains the fill value.
+                target = _FillValueMaskCheckAndStoreTarget(
+                    cf_var, fill_value
+                )
+                da.store([data], [target])
+                return target.is_masked, target.contains_value
 
     else:
diff --git a/lib/iris/tests/integration/experimental/test_xarray_dataset_wrapper.py b/lib/iris/tests/integration/experimental/test_xarray_dataset_wrapper.py
index 48d9878007..fd41c79c69 100644
--- a/lib/iris/tests/integration/experimental/test_xarray_dataset_wrapper.py
+++ b/lib/iris/tests/integration/experimental/test_xarray_dataset_wrapper.py
@@ -116,6 +116,11 @@ def test_1(self):
         iris.save(cubes, nc_faked_xr, saver="nc")
 
         ds = nc_faked_xr.to_xarray_dataset()
+        # Special "lazy streaming" ensures that the main cube arrays are LAZY.
+        for cube in cubes:
+            assert hasattr(ds.variables[cube.var_name].data, "compute")
+
+        # Save the netcdf version.
         xr_outpath = str(Path("tmp_xr.nc").absolute())
         ds.to_netcdf(xr_outpath)
@@ -143,3 +148,10 @@ def capture_dump_lines(filepath_str):
         # save, or via iris.save --> xarray.Dataset --> xarray.Data.to_netcdf()
         # Compare, omitting the first line with the filename
         assert lines_xr_save[1:] == lines_iris_save[1:]
+
+        # Also check that the *content* is identical.
+        iris_saved_reload = iris.load("tmp_iris.nc")
+        xr_saved_reload = iris.load("tmp_xr.nc")
+        iris_saved_reload = sorted(iris_saved_reload, key=lambda c: c.name())
+        xr_saved_reload = sorted(xr_saved_reload, key=lambda c: c.name())
+        assert xr_saved_reload == iris_saved_reload
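Reviewer note: the acceptance rules of the new _indexes_address_full_shape
method can be tricky to read from the diff alone. The sketch below is a
standalone, simplified rendering of the same key-checking logic (the
is_full_shape helper is hypothetical, written only for illustration, and uses
a plain tuple check instead of the Iterable check), shown for a 1-D variable
of length 10:

    def is_full_shape(keys, shape=(10,)):
        """Simplified rendering of the _indexes_address_full_shape rules."""
        if not isinstance(keys, tuple):
            keys = (keys,)
        if len(keys) > len(shape):
            return False
        # Trailing Ellipsis keys have no effect.
        while keys and keys[-1] is Ellipsis:
            keys = keys[:-1]
        for key, dimlen in zip(keys, shape):
            if key == 0:
                if dimlen != 1:
                    return False
            elif isinstance(key, slice):
                if key.step not in (1, None):
                    return False
                if key.start not in (0, None):
                    return False
                if key.stop is not None and key.stop < dimlen:
                    return False
            else:
                # e.g. newaxis / fancy indexing / non-final Ellipsis.
                return False
        return True

    assert is_full_shape(slice(None))               # var[:]
    assert is_full_shape(Ellipsis)                  # var[...]
    assert is_full_shape(slice(0, 10))              # var[0:10]
    assert not is_full_shape(slice(0, 5))           # var[0:5]  : partial
    assert not is_full_shape(slice(None, None, 2))  # var[::2]  : stepped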
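On the createVariable change: da.full only records a node in a dask graph, so
the placeholder content costs essentially no memory, unlike np.full, which
allocates the whole buffer immediately. A quick illustration (the fill-value
lookup matches the patch; the shape here is arbitrary):

    import dask.array as da
    import netCDF4 as nc

    use_fill = nc.default_fillvals["f8"]  # same lookup as in createVariable
    # Only a graph node is created here : no (10000 * 10000 * 8)-byte buffer.
    data = da.full((10000, 10000), fill_value=use_fill, dtype="f8")
    # Memory would only be allocated if this placeholder content were ever
    # computed, and it is normally replaced when real variable data is
    # assigned afterwards.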
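The netcdf.py change dispatches on a duck-typed marker attribute rather than
on a class check, so any write target can opt in. A minimal sketch of that
protocol, using a hypothetical LazyTarget stand-in for VariableMimic:

    import dask.array as da

    class LazyTarget:
        """Hypothetical stand-in for a VariableMimic-style write target."""

        # Marker recognised by the modified _lazy_stream_data.
        _ACCEPTS_LAZY_DATA_WRITES = True

        def __setitem__(self, keys, data):
            # A full-shape assignment keeps the dask array as-is : no compute.
            self.data = data

    target = LazyTarget()
    lazy = da.zeros((1000, 1000), chunks=(100, 100))
    if getattr(target, "_ACCEPTS_LAZY_DATA_WRITES", False):
        target[:] = lazy  # the equivalent of the new 'store' fast path
    assert isinstance(target.data, da.Array)  # still lazy, nothing streamed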