Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically standardize dtypes for int, datetime64, and object variables #117

Merged
merged 3 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 44 additions & 11 deletions ndpyramid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def set_zarr_encoding(
codec_config: dict | None = None,
float_dtype: npt.DTypeLike | None = None,
int_dtype: npt.DTypeLike | None = None,
datetime_dtype: npt.DTypeLike | None = None,
object_dtype: npt.DTypeLike | None = None,
) -> xr.Dataset:
"""Set zarr encoding for each variable in the dataset

Expand All @@ -77,11 +79,26 @@ def set_zarr_encoding(
The default is {'id': 'zlib', 'level': 1}
float_dtype : str or dtype, optional
Dtype to cast floating point variables to
int_dtype : str or dtype, optional
Dtype to cast integer variables to
object_dtype : str or dtype, optional
Dtype to cast object variables to.
datetime_dtype : str or dtype, optional
Dtype to encode numpy.datetime64 variables as.
Time coordinates are encoded as 'int32' if cf_xarray
is able to identify the coordinates representing time,
even if `datetime_dtype` is None.


Returns
-------
ds : xr.Dataset
Output dataset with updated variable encodings

Notes
-----
The *_dtype parameters can be used to coerce variables into data types
readable by Zarr implementations in other languages.
"""
import numcodecs

Expand All @@ -93,26 +110,30 @@ def set_zarr_encoding(

time_vars = ds.cf.axes.get('T', []) + ds.cf.bounds.get('T', [])
for varname, da in ds.variables.items():
# maybe cast float type
# remove old encoding
da.encoding.clear()

# maybe cast data type
if np.issubdtype(da.dtype, np.floating) and float_dtype is not None:
da = da.astype(float_dtype)

if np.issubdtype(da.dtype, np.integer) and int_dtype is not None:
da.encoding['dtype'] = str(float_dtype)
elif np.issubdtype(da.dtype, np.integer) and int_dtype is not None:
da = da.astype(int_dtype)

# remove old encoding
da.encoding.clear()
da.encoding['dtype'] = str(int_dtype)
elif da.dtype == 'O' and object_dtype is not None:
da = da.astype(object_dtype)
da.encoding['dtype'] = str(object_dtype)
elif np.issubdtype(da.dtype, np.datetime64) and datetime_dtype is not None:
da.encoding['dtype'] = str(datetime_dtype)
elif varname in time_vars:
da.encoding['dtype'] = 'int32'

# update with new encoding
da.encoding['compressor'] = compressor
with contextlib.suppress(KeyError):
del da.attrs['_FillValue']
da.encoding['_FillValue'] = default_fillvals.get(da.dtype.str[-2:], None)

# TODO: handle date/time types
# set encoding for time and time_bnds
if varname in time_vars:
da.encoding['dtype'] = 'int32'
ds[varname] = da

return ds
Expand Down Expand Up @@ -145,6 +166,13 @@ def add_metadata_and_zarr_encoding(
-------
dt.DataTree
Updated data pyramid with metadata / encoding set

Notes
-----
The variables within the pyramid are coerced into data types readable by
`@carbonplan/maps`. See https://ndpyramid.readthedocs.io/en/latest/schema.html
for more information. Raise an issue in https://github.com/carbonplan/ndpyramid
if more flexibility is needed.
'''
chunks = {'x': pixels_per_tile, 'y': pixels_per_tile}
if other_chunks is not None:
Expand All @@ -160,7 +188,12 @@ def add_metadata_and_zarr_encoding(

# set dataset encoding
pyramid[slevel].ds = set_zarr_encoding(
pyramid[slevel].ds, codec_config={'id': 'zlib', 'level': 1}, float_dtype='float32'
pyramid[slevel].ds,
codec_config={'id': 'zlib', 'level': 1},
float_dtype='float32',
int_dtype='int32',
datetime_dtype='int32',
object_dtype='str',
)

# set global metadata
Expand Down
21 changes: 10 additions & 11 deletions notebooks/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,13 @@
" .squeeze()\n",
" .reset_coords([\"band\"], drop=True)\n",
")\n",
"ds2[\"climate\"] = ds2[\"climate\"].astype(\"float32\")\n",
"ds2[\"climate\"].values[ds2[\"climate\"].values == ds2[\"climate\"].values[0, 0]] = ds1[\"climate\"].values[\n",
" 0, 0\n",
"]\n",
"ds = xr.concat([ds1, ds2], pd.Index([\"tavg\", \"prec\"], name=\"band\"))\n",
"ds[\"band\"] = ds[\"band\"].astype(\"str\")\n",
"\n",
"# create the pyramid\n",
"dt = pyramid_reproject(ds, levels=LEVELS, other_chunks={'band': 2}, clear_attrs=True)\n",
"dt.ds.attrs\n",
"\n",
"# write the pyramid to zarr\n",
"dt.to_zarr(store_3d, consolidated=True)"
Expand Down Expand Up @@ -189,11 +186,9 @@
" )\n",
" ds_all.append(ds)\n",
"ds = xr.concat(ds_all, pd.Index(months, name=\"month\"))\n",
"ds[\"month\"] = ds[\"month\"].astype(\"int32\")\n",
"\n",
"# create the pyramid\n",
"dt = pyramid_reproject(ds, levels=LEVELS, other_chunks={'month': 12}, clear_attrs=True)\n",
"dt.ds.attrs\n",
"\n",
"# write the pyramid to zarr\n",
"dt.to_zarr(store_3d_1var, consolidated=True)"
Expand Down Expand Up @@ -243,14 +238,10 @@
" )\n",
" ds2_all.append(ds)\n",
"ds2 = xr.concat(ds2_all, pd.Index(months, name=\"month\"))\n",
"ds1[\"month\"] = ds1[\"month\"].astype(\"int32\")\n",
"ds2[\"month\"] = ds2[\"month\"].astype(\"int32\")\n",
"ds2[\"climate\"] = ds2[\"climate\"].astype(\"float32\")\n",
"ds2[\"climate\"].values[ds2[\"climate\"].values == ds2[\"climate\"].values[0, 0, 0]] = ds1[\n",
" \"climate\"\n",
"].values[0, 0, 0]\n",
"ds = xr.concat([ds1, ds2], pd.Index([\"tavg\", \"prec\"], name=\"band\"))\n",
"ds[\"band\"] = ds[\"band\"].astype(\"str\")\n",
"\n",
"# create the pyramid\n",
"dt = pyramid_reproject(\n",
Expand All @@ -259,8 +250,16 @@
"dt.ds.attrs\n",
"\n",
"# write the pyramid to zarr\n",
"dt.to_zarr(store_4d, consolidated=True)"
"dt.to_zarr(store_4d, consolidated=True, mode=\"w\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -274,7 +273,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.12.2"
}
},
"nbformat": 4,
Expand Down
Loading