diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index ab6ea3b1631..29d442f1eac 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -375,10 +375,11 @@ def open_dataset(
         scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like
         objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF).
     engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", "cfgrib", \
-        "pseudonetcdf", "zarr"}, optional
+        "pseudonetcdf", "zarr"} or subclass of xarray.backends.BackendEntrypoint, optional
         Engine to use when reading files. If not provided, the default engine
         is chosen based on available dependencies, with a preference for
-        "netcdf4".
+        "netcdf4". A custom backend class (a subclass of ``BackendEntrypoint``)
+        can also be used.
     chunks : int or dict, optional
         If chunks is provided, it is used to load the new dataset into dask
         arrays. ``chunks=-1`` loads the dataset with dask using a single
diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py
index f9790cfaffb..23e83b0021e 100644
--- a/xarray/backends/plugins.py
+++ b/xarray/backends/plugins.py
@@ -5,7 +5,7 @@
 
 import pkg_resources
 
-from .common import BACKEND_ENTRYPOINTS
+from .common import BACKEND_ENTRYPOINTS, BackendEntrypoint
 
 STANDARD_BACKENDS_ORDER = ["netcdf4", "h5netcdf", "scipy"]
 
@@ -113,10 +113,22 @@ def guess_engine(store_spec):
 
 
 def get_backend(engine):
-    """Select open_dataset method based on current engine"""
-    engines = list_engines()
-    if engine not in engines:
-        raise ValueError(
-            f"unrecognized engine {engine} must be one of: {list(engines)}"
+    """Select open_dataset method based on current engine."""
+    if isinstance(engine, str):
+        engines = list_engines()
+        if engine not in engines:
+            raise ValueError(
+                f"unrecognized engine {engine} must be one of: {list(engines)}"
+            )
+        backend = engines[engine]
+    elif isinstance(engine, type) and issubclass(engine, BackendEntrypoint):
+        backend = engine
+    else:
+        raise TypeError(
+            (
+                "engine must be a string or a subclass of "
+                f"xarray.backends.BackendEntrypoint: {engine}"
+            )
         )
-    return engines[engine]
+
+    return backend
diff --git a/xarray/tests/test_backends_api.py b/xarray/tests/test_backends_api.py
index 340495d4564..4124d0d0b81 100644
--- a/xarray/tests/test_backends_api.py
+++ b/xarray/tests/test_backends_api.py
@@ -1,6 +1,9 @@
+import numpy as np
+
+import xarray as xr
 from xarray.backends.api import _get_default_engine
 
-from . import requires_netCDF4, requires_scipy
+from . import assert_identical, requires_netCDF4, requires_scipy
 
 
 @requires_netCDF4
@@ -14,3 +17,20 @@ def test__get_default_engine():
 
     engine_default = _get_default_engine("/example")
     assert engine_default == "netcdf4"
+
+
+def test_custom_engine():
+    expected = xr.Dataset(
+        dict(a=2 * np.arange(5)), coords=dict(x=("x", np.arange(5), dict(units="s")))
+    )
+
+    class CustomBackend(xr.backends.BackendEntrypoint):
+        def open_dataset(
+            filename_or_obj,
+            drop_variables=None,
+            **kwargs,
+        ):
+            return expected.copy(deep=True)
+
+    actual = xr.open_dataset("fake_filename", engine=CustomBackend)
+    assert_identical(expected, actual)