Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,11 @@ def open_dataset(
scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like
objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF).
engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", "cfgrib", \
"pseudonetcdf", "zarr"}, optional
"pseudonetcdf", "zarr"} or subclass of xarray.backends.BackendEntrypoint, optional
Engine to use when reading files. If not provided, the default engine
is chosen based on available dependencies, with a preference for
"netcdf4".
"netcdf4". A custom backend class (a subclass of ``BackendEntrypoint``)
can also be used.
chunks : int or dict, optional
If chunks is provided, it is used to load the new dataset into dask
arrays. ``chunks=-1`` loads the dataset with dask using a single
Expand Down
26 changes: 19 additions & 7 deletions xarray/backends/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pkg_resources

from .common import BACKEND_ENTRYPOINTS
from .common import BACKEND_ENTRYPOINTS, BackendEntrypoint

STANDARD_BACKENDS_ORDER = ["netcdf4", "h5netcdf", "scipy"]

Expand Down Expand Up @@ -113,10 +113,22 @@ def guess_engine(store_spec):


def get_backend(engine):
"""Select open_dataset method based on current engine"""
engines = list_engines()
if engine not in engines:
raise ValueError(
f"unrecognized engine {engine} must be one of: {list(engines)}"
"""Select open_dataset method based on current engine."""
if isinstance(engine, str):
engines = list_engines()
if engine not in engines:
raise ValueError(
f"unrecognized engine {engine} must be one of: {list(engines)}"
)
backend = engines[engine]
elif isinstance(engine, type) and issubclass(engine, BackendEntrypoint):
backend = engine
else:
raise TypeError(
(
"engine must be a string or a subclass of "
f"xarray.backends.BackendEntrypoint: {engine}"
)
)
return engines[engine]

return backend
22 changes: 21 additions & 1 deletion xarray/tests/test_backends_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np

import xarray as xr
from xarray.backends.api import _get_default_engine

from . import requires_netCDF4, requires_scipy
from . import assert_identical, requires_netCDF4, requires_scipy


@requires_netCDF4
Expand All @@ -14,3 +17,20 @@ def test__get_default_engine():

engine_default = _get_default_engine("/example")
assert engine_default == "netcdf4"


def test_custom_engine():
expected = xr.Dataset(
dict(a=2 * np.arange(5)), coords=dict(x=("x", np.arange(5), dict(units="s")))
)

class CustomBackend(xr.backends.BackendEntrypoint):
def open_dataset(
filename_or_obj,
drop_variables=None,
**kwargs,
):
return expected.copy(deep=True)

actual = xr.open_dataset("fake_filename", engine=CustomBackend)
assert_identical(expected, actual)