Use Dask delayed to export large datasets to NetCDF #5391

Merged: 24 commits, Oct 4, 2023
Commits (24)
5710aff
Add PoC exporting row by row
jenshnielsen Sep 28, 2023
5e87a56
Add dask dependency
jenshnielsen Sep 29, 2023
cfc923e
move export logic to dataset
jenshnielsen Oct 2, 2023
e706f7f
remove list comprehension
jenshnielsen Oct 2, 2023
04513d7
use tempfile for export
jenshnielsen Oct 2, 2023
978f159
ensure partial files are closed so they can be removed
jenshnielsen Oct 2, 2023
4e1ff6d
document why we need dask
jenshnielsen Oct 2, 2023
137b51a
add basic test
jenshnielsen Oct 2, 2023
3b6b7a5
remove non-working warning since feature is actually useful
jenshnielsen Oct 3, 2023
e4672dd
use tqdm for both progress bars
jenshnielsen Oct 3, 2023
55989bc
fix type checking
jenshnielsen Oct 3, 2023
fc6534d
add support for complex
jenshnielsen Oct 3, 2023
95fb0d4
Fix type checking with pyright
jenshnielsen Oct 3, 2023
c1d9780
format
jenshnielsen Oct 3, 2023
be06343
Add link to pyright issue
jenshnielsen Oct 3, 2023
4d12b63
Add log messages with attrs to export
jenshnielsen Oct 3, 2023
c587030
test that data is as expected
jenshnielsen Oct 3, 2023
5c5dfaf
cleanup export logic and assert log messages
jenshnielsen Oct 4, 2023
e66a1c2
Dynamically calc index
jenshnielsen Oct 4, 2023
7bc8ec1
Add changelog for 5391
jenshnielsen Oct 4, 2023
5e47782
tqdm 4.59 first version with dask support
jenshnielsen Oct 4, 2023
c9eafef
Apply suggestions from code review
jenshnielsen Oct 4, 2023
2e4b979
Apply suggestions from code review
jenshnielsen Oct 4, 2023
9f43f65
add test to ensure that normal path is taken by default
jenshnielsen Oct 4, 2023
2 changes: 2 additions & 0 deletions docs/changes/newsfragments/5391.new
@@ -0,0 +1,2 @@
Large datasets are now exported to NetCDF4 using a Dask delayed writer.
This avoids allocating a large amount of memory to process the whole dataset at once.
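For context, a minimal sketch of the delayed-write pattern the newsfragment describes, assuming a dask-chunked xarray dataset (the dataset, chunk size, and file name are illustrative, and dask plus h5netcdf must be installed). With `compute=False`, `to_netcdf` returns a `dask.delayed` object instead of writing eagerly, and the data is streamed to disk chunk by chunk when `.compute()` is called:

```python
import numpy as np
import xarray as xr

# chunking backs the dataset with dask arrays so the write can be lazy
ds = xr.Dataset({"signal": ("x", np.random.rand(100_000))}).chunk({"x": 10_000})

# compute=False returns a dask.delayed job rather than writing immediately
write_job = ds.to_netcdf("signal.nc", engine="h5netcdf", compute=False)

write_job.compute()  # streams the chunks to disk one at a time
```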
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -40,14 +40,15 @@ dependencies = [
"ruamel.yaml>=0.16.0,!=0.16.6",
"tabulate>=0.8.0",
"typing_extensions>=4.1.1",
"tqdm>=4.32.2",
"tqdm>=4.59.0",
"uncertainties>=3.1.4",
"versioningit>=2.0.1",
"websockets>=9.1",
"wrapt>=1.13.2",
"xarray>=2022.06.0",
"cf_xarray>=0.8.4",
"opentelemetry-api>=1.15.0",
"dask>=2022.1.0", # we are making use of xarray features that requires dask implicitly
# transitive dependencies. We list these explicitly to
# ensure that we always use versions that do not have
# known security vulnerabilities
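A short illustration of the comment on the dask pin (an assumption about the failure mode, not part of the diff): xarray treats dask as an optional dependency, so the lazy features this PR relies on raise `ImportError` when dask is missing rather than degrading gracefully. The file names here are hypothetical.

```python
import xarray as xr

# Both of these xarray features need dask at runtime, even though
# xarray itself does not declare dask as a hard dependency:
data = xr.open_mfdataset(["ds_0.nc", "ds_1.nc"])  # lazily opens dask-backed arrays
job = data.to_netcdf("combined.nc", compute=False)  # returns a dask.delayed job
job.compute()
data.close()
```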
97 changes: 80 additions & 17 deletions qcodes/dataset/data_set.py
@@ -3,15 +3,19 @@
import importlib
import json
import logging
import sys
import tempfile
import time
import uuid
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass
from pathlib import Path
from queue import Queue
from threading import Thread
from typing import TYPE_CHECKING, Any

import numpy
from tqdm.auto import trange

import qcodes
from qcodes.dataset.data_set_protocol import (
@@ -76,7 +80,6 @@
)
from qcodes.utils import (
NumpyJSONEncoder,
QCoDeSDeprecationWarning,
deprecate,
issue_deprecation_warning,
)
@@ -93,6 +96,7 @@
from .exporters.export_to_xarray import (
load_to_xarray_dataarray_dict,
load_to_xarray_dataset,
xarray_to_h5netcdf_with_complex_numbers,
)
from .subscriber import _Subscriber

@@ -244,6 +248,7 @@ def __init__(
self._cache: DataSetCacheWithDBBackend = DataSetCacheWithDBBackend(self)
self._results: list[dict[str, VALUE]] = []
self._in_memory_cache = in_memory_cache
self._export_limit = 1000

if run_id is not None:
if not run_exists(self.conn, run_id):
@@ -859,7 +864,6 @@ def to_pandas_dataframe_dict(
a column and a indexed by a :py:class:`pandas.MultiIndex` formed
by the dependencies.
"""
self._warn_if_set(*params, start=start, end=end)
datadict = self.get_parameter_data(*params,
start=start,
end=end)
@@ -958,7 +962,6 @@ def to_pandas_dataframe(
Return a pandas DataFrame with
df = ds.to_pandas_dataframe()
"""
self._warn_if_set(*params, start=start, end=end)
datadict = self.get_parameter_data(*params,
start=start,
end=end)
@@ -1010,7 +1013,6 @@ def to_xarray_dataarray_dict(

dataarray_dict = ds.to_xarray_dataarray_dict()
"""
self._warn_if_set(*params, start=start, end=end)
data = self.get_parameter_data(*params,
start=start,
end=end)
@@ -1061,7 +1063,6 @@ def to_xarray_dataset(

xds = ds.to_xarray_dataset()
"""
self._warn_if_set(*params, start=start, end=end)
data = self.get_parameter_data(*params,
start=start,
end=end)
@@ -1457,19 +1458,81 @@ def _set_export_info(self, export_info: ExportInfo) -> None:

self._export_info = export_info

@staticmethod
def _warn_if_set(
*params: str | ParamSpec | ParameterBase,
start: int | None = None,
end: int | None,
) -> None:
if len(params) > 0 or start is not None or end is not None:
QCoDeSDeprecationWarning(
"Passing params, start or stop to to_xarray_... and "
"to_pandas_... methods is deprecated "
"This will be an error in the future. "
"If you need to sub-select use `dataset.get_parameter_data`"
def _export_as_netcdf(self, path: Path, file_name: str) -> Path:
"""Export data as netcdf to a given path with file prefix"""
import xarray as xr

if self._estimate_ds_size() > self._export_limit:
file_path = path / file_name
log.info(
"Dataset is expected to be larger that threshold. Using distributed export.",
extra={
"file_name": file_path,
"qcodes_guid": self.guid,
"ds_name": self.name,
"exp_name": self.exp_name,
},
)
print(
"Large dataset detected. Will write to individual files and combine to reduce memory overhead."
)
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
log.info(
"Writing individual files to temp dir.",
extra={
"file_name": file_path,
"qcodes_guid": self.guid,
"ds_name": self.name,
"exp_name": self.exp_name,
"temp_dir": temp_dir,
},
)
num_files = len(self)
num_digits = len(str(num_files))
file_name_template = f"ds_{{:0{num_digits}d}}.nc"
for i in trange(num_files, desc="Writing individual files"):
xarray_to_h5netcdf_with_complex_numbers(
self.to_xarray_dataset(start=i + 1, end=i + 1),
temp_path / file_name_template.format(i),
)
files = tuple(temp_path.glob("*.nc"))
data = xr.open_mfdataset(files)
try:
log.info(
"Combining temp files into one file.",
extra={
"file_name": file_path,
"qcodes_guid": self.guid,
"ds_name": self.name,
"exp_name": self.exp_name,
"temp_dir": temp_dir,
},
)
xarray_to_h5netcdf_with_complex_numbers(
data, file_path, compute=False
)
finally:
data.close()
else:
file_path = super()._export_as_netcdf(path=path, file_name=file_name)
return file_path

def _estimate_ds_size(self) -> float:
"""
Give an estimated size of the dataset as the size of a single row
times the length of the dataset. The result is returned in megabytes (MB).

Note that this does not take overhead into account so it works best
if the row size is "large".
"""
sample_data = self.get_parameter_data(start=1, end=1)
row_size = 0.0

for param_data in sample_data.values():
for array in param_data.values():
row_size += sys.getsizeof(array)
return row_size * len(self) / 1024 / 1024


# public api
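To make the control flow of `_export_as_netcdf` easier to follow, here is a self-contained sketch of the same strategy with plain xarray objects; `make_row_dataset` is a hypothetical stand-in for `self.to_xarray_dataset(start=i + 1, end=i + 1)`. The size gate works as the docstring describes: with `_export_limit = 1000` (MB), a row holding, say, two float64 arrays of 1,000 points each (about 16 kB) tips a dataset into the distributed path at around 66,000 rows.

```python
import tempfile
from pathlib import Path

import numpy as np
import xarray as xr


def make_row_dataset(i: int) -> xr.Dataset:
    # hypothetical stand-in for self.to_xarray_dataset(start=i + 1, end=i + 1)
    return xr.Dataset(
        {"signal": (("y", "x"), np.random.rand(1, 100))},
        coords={"y": [i], "x": np.arange(100)},
    )


num_rows = 50
final_path = Path("combined.nc")

with tempfile.TemporaryDirectory() as temp_dir:
    temp_path = Path(temp_dir)
    # zero-pad the names so the files sort (and concatenate) in row order
    width = len(str(num_rows))
    for i in range(num_rows):
        make_row_dataset(i).to_netcdf(
            temp_path / f"ds_{i:0{width}d}.nc", engine="h5netcdf"
        )
    # open_mfdataset combines the per-row files into one lazy, dask-backed dataset
    data = xr.open_mfdataset(sorted(temp_path.glob("*.nc")))
    try:
        # the delayed write streams the row files into the combined file, so
        # the full dataset is never held in memory at once; it must finish
        # before the temporary directory (and its source files) is removed
        data.to_netcdf(final_path, engine="h5netcdf", compute=False).compute()
    finally:
        data.close()
```

In QCoDeS itself this branch is reached through the dataset's regular export machinery once `_estimate_ds_size()` exceeds the limit; datasets under the threshold still take the plain single-file path via `super()._export_as_netcdf`.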
18 changes: 9 additions & 9 deletions qcodes/dataset/data_set_protocol.py
@@ -1,16 +1,8 @@
from __future__ import annotations

import sys

if sys.version_info >= (3, 10):
# new entrypoints api was added in 3.10
from importlib.metadata import entry_points
else:
# 3.9 and earlier
from importlib_metadata import entry_points

import logging
import os
import sys
import warnings
from collections.abc import Mapping, Sequence
from enum import Enum
@@ -39,6 +31,13 @@
from .exporters.export_to_xarray import xarray_to_h5netcdf_with_complex_numbers
from .sqlite.queries import raw_time_to_str_time

if sys.version_info >= (3, 10):
# new entrypoints api was added in 3.10
from importlib.metadata import entry_points
else:
# 3.9 and earlier
from importlib_metadata import entry_points

if TYPE_CHECKING:
import pandas as pd
import xarray as xr
@@ -242,6 +241,7 @@ def get_parameter_data(
) -> ParameterData:
...


def get_parameters(self) -> SPECS:
# used by plottr
...
31 changes: 23 additions & 8 deletions qcodes/dataset/exporters/export_to_xarray.py
@@ -1,12 +1,14 @@
from __future__ import annotations

import logging
import warnings
from collections.abc import Hashable, Mapping
from math import prod
from pathlib import Path
from typing import TYPE_CHECKING, cast

import numpy as np
from tqdm.dask import TqdmCallback

from qcodes.dataset.linked_datasets.links import links_to_str

@@ -23,6 +25,8 @@

from qcodes.dataset.data_set_protocol import DataSetProtocol, ParameterData

_LOG = logging.getLogger(__name__)


def _calculate_index_shape(idx: pd.Index | pd.MultiIndex) -> dict[Hashable, int]:
# heavily inspired by xarray.core.dataset.from_dataframe
@@ -207,7 +211,7 @@ def _paramspec_dict_with_extras(


def xarray_to_h5netcdf_with_complex_numbers(
xarray_dataset: xr.Dataset, file_path: str | Path
xarray_dataset: xr.Dataset, file_path: str | Path, compute: bool = True
) -> None:
import cf_xarray as cfxr
from pandas import MultiIndex
@@ -230,18 +234,29 @@ def xarray_to_h5netcdf_with_complex_numbers(
internal_ds.data_vars[data_var].dtype.kind for data_var in internal_ds.data_vars
]
coord_kinds = [internal_ds.coords[coord].dtype.kind for coord in internal_ds.coords]
if "c" in data_var_kinds or "c" in coord_kinds:
allow_invalid_netcdf = "c" in data_var_kinds or "c" in coord_kinds

with warnings.catch_warnings():
# see http://xarray.pydata.org/en/stable/howdoi.html
# for how to export complex numbers
with warnings.catch_warnings():
if allow_invalid_netcdf:
warnings.filterwarnings(
"ignore",
module="h5netcdf",
message="You are writing invalid netcdf features",
category=UserWarning,
)
internal_ds.to_netcdf(
path=file_path, engine="h5netcdf", invalid_netcdf=True
)
else:
internal_ds.to_netcdf(path=file_path, engine="h5netcdf")
maybe_write_job = internal_ds.to_netcdf(
path=file_path,
engine="h5netcdf",
invalid_netcdf=allow_invalid_netcdf,
compute=compute, # pyright: ignore
)
# https://github.com/microsoft/pyright/issues/6069
if not compute and maybe_write_job is not None:
with TqdmCallback(desc="Combining files"):
_LOG.info(
"Writing netcdf file using Dask delayed writer",
extra={"file_name": file_path},
)
maybe_write_job.compute()
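Finally, a minimal sketch of the `TqdmCallback` integration (available from tqdm 4.59, matching the bumped pin in pyproject.toml): it registers as a dask scheduler callback for the duration of the `with` block, so each dask task completed by the delayed write advances the progress bar. The dataset and file name are illustrative.

```python
import numpy as np
import xarray as xr
from tqdm.dask import TqdmCallback

ds = xr.Dataset({"signal": ("x", np.random.rand(100_000))}).chunk({"x": 10_000})
write_job = ds.to_netcdf("out.nc", engine="h5netcdf", compute=False)

with TqdmCallback(desc="Writing netcdf"):
    write_job.compute()  # one tick per completed dask task
```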