From d40609a22960490832ee62b64cd1a0efeb36c6c0 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 27 Oct 2023 21:31:05 -0600 Subject: [PATCH] Use `opt_einsum` by default if installed. (#8373) * Use `opt_einsum` by default if installed. Closes #7764 Closes #8017 * docstring update * _ * _ Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> * Update xarray/core/computation.py Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> * Fix docs? * Add use_opt_einsum option. * mypy ignore * one more test ignore * Disable navigation_with_keys * remove intersphinx * One more skip --------- Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --- ci/install-upstream-wheels.sh | 3 ++- ci/requirements/environment.yml | 1 + doc/conf.py | 2 ++ doc/whats-new.rst | 2 ++ pyproject.toml | 3 ++- xarray/core/computation.py | 19 +++++++++++++++---- xarray/core/duck_array_ops.py | 12 +++++++++++- xarray/core/options.py | 6 ++++++ xarray/tests/test_units.py | 19 +++++++++++-------- 9 files changed, 52 insertions(+), 15 deletions(-) diff --git a/ci/install-upstream-wheels.sh b/ci/install-upstream-wheels.sh index 41507fce13e..97ae4c2bbca 100755 --- a/ci/install-upstream-wheels.sh +++ b/ci/install-upstream-wheels.sh @@ -45,4 +45,5 @@ python -m pip install \ git+https://github.com/intake/filesystem_spec \ git+https://github.com/SciTools/nc-time-axis \ git+https://github.com/xarray-contrib/flox \ - git+https://github.com/h5netcdf/h5netcdf + git+https://github.com/h5netcdf/h5netcdf \ + git+https://github.com/dgasmith/opt_einsum diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index dd73ef19658..6e93ab7a946 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -26,6 +26,7 @@ dependencies: - numbagg - numexpr - numpy + - opt_einsum - packaging - pandas - pint<0.21 diff --git a/doc/conf.py b/doc/conf.py index 23aed3aac46..295f161e545 100644 --- a/doc/conf.py 
+++ b/doc/conf.py @@ -237,6 +237,7 @@ use_repository_button=True, use_issues_button=True, home_page_in_toc=False, + navigation_with_keys=False, extra_footer="""

Xarray is a fiscally sponsored project of NumFOCUS, a nonprofit dedicated to supporting the open-source scientific computing community.
Theme by the Executable Book Project

""", @@ -327,6 +328,7 @@ "sparse": ("https://sparse.pydata.org/en/latest/", None), "cubed": ("https://tom-e-white.com/cubed/", None), "datatree": ("https://xarray-datatree.readthedocs.io/en/latest/", None), + # "opt_einsum": ("https://dgasmith.github.io/opt_einsum/", None), } diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c88f685b0ba..b24a19c9129 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,8 @@ v2023.10.2 (unreleased) New Features ~~~~~~~~~~~~ +- Use `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_ for :py:func:`xarray.dot` by default if installed. + By `Deepak Cherian <https://github.com/dcherian>`_. (:issue:`7764`, :pull:`8373`). Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/pyproject.toml b/pyproject.toml index e7fa7bec5c0..b16063e0370 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ source-code = "https://github.com/pydata/xarray" dask = "xarray.core.daskmanager:DaskManager" [project.optional-dependencies] -accel = ["scipy", "bottleneck", "numbagg", "flox"] +accel = ["scipy", "bottleneck", "numbagg", "flox", "opt_einsum"] complete = ["xarray[accel,io,parallel,viz]"] io = ["netCDF4", "h5netcdf", "scipy", 'pydap; python_version<"3.10"', "zarr", "fsspec", "cftime", "pooch"] parallel = ["dask[complete]"] @@ -106,6 +106,7 @@ module = [ "numbagg.*", "netCDF4.*", "netcdftime.*", + "opt_einsum.*", "pandas.*", "pooch.*", "PseudoNetCDF.*", diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 1b96043f1f5..f506bc97a2c 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1690,8 +1690,8 @@ def dot( dims: Dims = None, **kwargs: Any, ): - """Generalized dot product for xarray objects. Like np.einsum, but - provides a simpler interface based on array dimensions. + """Generalized dot product for xarray objects. Like ``np.einsum``, but + provides a simpler interface based on array dimension names. Parameters ---------- @@ -1701,13 +1701,24 @@ def dot( Which dimensions to sum over. Ellipsis ('...') sums over all dimensions. 
If not specified, then all the common dimensions are summed over. **kwargs : dict - Additional keyword arguments passed to numpy.einsum or - dask.array.einsum + Additional keyword arguments passed to ``numpy.einsum`` or + ``dask.array.einsum`` Returns ------- DataArray + See Also + -------- + numpy.einsum + dask.array.einsum + opt_einsum.contract + + Notes + ----- + We recommend installing the optional ``opt_einsum`` package, or alternatively passing ``optimize=True``, + which is passed through to ``np.einsum``, and works for most array backends. + Examples -------- >>> da_a = xr.DataArray(np.arange(3 * 2).reshape(3, 2), dims=["a", "b"]) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 51b6ff5f59b..b9f7db9737f 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -18,7 +18,6 @@ from numpy import any as array_any # noqa from numpy import ( # noqa around, # noqa - einsum, gradient, isclose, isin, @@ -48,6 +47,17 @@ def get_array_namespace(x): return np +def einsum(*args, **kwargs): + from xarray.core.options import OPTIONS + + if OPTIONS["use_opt_einsum"] and module_available("opt_einsum"): + import opt_einsum + + return opt_einsum.contract(*args, **kwargs) + else: + return np.einsum(*args, **kwargs) + + def _dask_or_eager_func( name, eager_module=np, diff --git a/xarray/core/options.py b/xarray/core/options.py index 118a67559ad..d116c350991 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -28,6 +28,7 @@ "warn_for_unclosed_files", "use_bottleneck", "use_numbagg", + "use_opt_einsum", "use_flox", ] @@ -52,6 +53,7 @@ class T_Options(TypedDict): use_bottleneck: bool use_flox: bool use_numbagg: bool + use_opt_einsum: bool OPTIONS: T_Options = { @@ -75,6 +77,7 @@ class T_Options(TypedDict): "use_bottleneck": True, "use_flox": True, "use_numbagg": True, + "use_opt_einsum": True, } _JOIN_OPTIONS = frozenset(["inner", "outer", "left", "right", "exact"]) @@ -102,6 +105,7 @@ def 
_positive_integer(value: int) -> bool: "keep_attrs": lambda choice: choice in [True, False, "default"], "use_bottleneck": lambda value: isinstance(value, bool), "use_numbagg": lambda value: isinstance(value, bool), + "use_opt_einsum": lambda value: isinstance(value, bool), "use_flox": lambda value: isinstance(value, bool), "warn_for_unclosed_files": lambda value: isinstance(value, bool), } @@ -237,6 +241,8 @@ class set_options: use_numbagg : bool, default: True Whether to use ``numbagg`` to accelerate reductions. Takes precedence over ``use_bottleneck`` when both are True. + use_opt_einsum : bool, default: True + Whether to use ``opt_einsum`` to accelerate dot products. warn_for_unclosed_files : bool, default: False Whether or not to issue a warning when unclosed files are deallocated. This is mostly useful for debugging. diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 7e1105e2e5d..14a7a10f734 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1502,10 +1502,11 @@ def test_dot_dataarray(dtype): data_array = xr.DataArray(data=array1, dims=("x", "y")) other = xr.DataArray(data=array2, dims=("y", "z")) - expected = attach_units( - xr.dot(strip_units(data_array), strip_units(other)), {None: unit_registry.m} - ) - actual = xr.dot(data_array, other) + with xr.set_options(use_opt_einsum=False): + expected = attach_units( + xr.dot(strip_units(data_array), strip_units(other)), {None: unit_registry.m} + ) + actual = xr.dot(data_array, other) assert_units_equal(expected, actual) assert_identical(expected, actual) @@ -2465,8 +2466,9 @@ def test_binary_operations(self, func, dtype): data_array = xr.DataArray(data=array) units = extract_units(func(array)) - expected = attach_units(func(strip_units(data_array)), units) - actual = func(data_array) + with xr.set_options(use_opt_einsum=False): + expected = attach_units(func(strip_units(data_array)), units) + actual = func(data_array) assert_units_equal(expected, actual) 
assert_identical(expected, actual) @@ -3829,8 +3831,9 @@ def test_computation(self, func, variant, dtype): if not isinstance(func, (function, method)): units.update(extract_units(func(array.reshape(-1)))) - expected = attach_units(func(strip_units(data_array)), units) - actual = func(data_array) + with xr.set_options(use_opt_einsum=False): + expected = attach_units(func(strip_units(data_array)), units) + actual = func(data_array) assert_units_equal(expected, actual) assert_identical(expected, actual)