Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations #520

Merged
merged 18 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
- Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520)


## 0.9.1
Expand Down
7 changes: 7 additions & 0 deletions docs/source/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,13 @@ List of compatible Backends
- `Tensorflow <https://www.tensorflow.org/>`_ (all outputs differentiable w.r.t. inputs)
- `Cupy <https://cupy.dev/>`_ (no differentiation, GPU only)

The library automatically detects which backends are available for use. A backend
is instantiated lazily only when necessary to prevent unwarranted GPU memory allocations.
You can also disable the import of a specific backend library (e.g., to accelerate
the loading of the `ot` library) using the environment variable `POT_BACKEND_DISABLE_<NAME>` with <NAME> in (PYTORCH, TENSORFLOW, CUPY, JAX).
For instance, to disable TensorFlow, set `export POT_BACKEND_DISABLE_TENSORFLOW=1`.
It's important to note that the `numpy` backend cannot be disabled.


List of compatible modules
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
131 changes: 90 additions & 41 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,43 +87,67 @@
# License: MIT License

import numpy as np
import os
import scipy
import scipy.linalg
import scipy.special as special
from scipy.sparse import issparse, coo_matrix, csr_matrix
import warnings
import scipy.special as special
import time
import warnings


DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH'
DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX'
DISABLE_CUPY_KEY = 'POT_BACKEND_DISABLE_CUPY'
DISABLE_TF_KEY = 'POT_BACKEND_DISABLE_TENSORFLOW'


try:
import torch
torch_type = torch.Tensor
except ImportError:
if not os.environ.get(DISABLE_TORCH_KEY, False):
try:
import torch
torch_type = torch.Tensor
except ImportError:
torch = False
torch_type = float
else:
torch = False
torch_type = float

try:
import jax
import jax.numpy as jnp
import jax.scipy.special as jspecial
from jax.lib import xla_bridge
jax_type = jax.numpy.ndarray
except ImportError:
if not os.environ.get(DISABLE_JAX_KEY, False):
try:
import jax
import jax.numpy as jnp
import jax.scipy.special as jspecial
from jax.lib import xla_bridge
jax_type = jax.numpy.ndarray
except ImportError:
jax = False
jax_type = float
else:
jax = False
jax_type = float

try:
import cupy as cp
import cupyx
cp_type = cp.ndarray
except ImportError:
if not os.environ.get(DISABLE_CUPY_KEY, False):
try:
import cupy as cp
import cupyx
cp_type = cp.ndarray
except ImportError:
cp = False
cp_type = float
else:
cp = False
cp_type = float

try:
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
tf_type = tf.Tensor
except ImportError:
if not os.environ.get(DISABLE_TF_KEY, False):
try:
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
tf_type = tf.Tensor
except ImportError:
tf = False
tf_type = float
else:
tf = False
tf_type = float

Expand All @@ -132,26 +156,51 @@


# Mapping between argument types and the existing backend
_BACKENDS = []
_BACKEND_IMPLEMENTATIONS = []
_BACKENDS = {}


def register_backend(backend):
_BACKENDS.append(backend)
def _register_backend_implementation(backend_impl):
    # Record the backend *class* (not an instance): instantiation is deferred
    # to _get_backend_instance so that merely importing the package does not
    # trigger side effects such as GPU memory pre-allocation.
    _BACKEND_IMPLEMENTATIONS.append(backend_impl)


def get_backend_list():
"""Returns the list of available backends"""
return _BACKENDS
def _get_backend_instance(backend_impl):
    """Return the cached instance of *backend_impl*, creating it on first use.

    Instances are memoised in the module-level ``_BACKENDS`` dict, keyed by
    the implementation's class name, so each backend is instantiated at most
    once — and only when actually requested (lazy instantiation).
    """
    key = backend_impl.__name__
    try:
        return _BACKENDS[key]
    except KeyError:
        # First request for this backend: instantiate and cache it.
        instance = backend_impl()
        _BACKENDS[key] = instance
        return instance


def _check_args_backend(backend, args):
is_instance = set(isinstance(a, backend.__type__) for a in args)
def _check_args_backend(backend_impl, args):
is_instance = set(isinstance(arg, backend_impl.__type__) for arg in args)
# check that all arguments matched or not the type
if len(is_instance) == 1:
return is_instance.pop()

# Oterwise return an error
raise ValueError(str_type_error.format([type(a) for a in args]))
# Otherwise return an error
raise ValueError(str_type_error.format([type(arg) for arg in args]))


def get_backend_list():
    """Return instances of all available backends.

    Note that this function forces every detected backend implementation
    to be instantiated, even if a given backend was never used before.
    Be careful, as instantiating a backend may have side effects, such as
    GPU memory pre-allocation. See the documentation for more details.
    If you only need to know which implementations are available, use
    :py:func:`ot.backend.get_available_backend_implementations` instead,
    which does not force backend instances to be created.
    """
    return [
        _get_backend_instance(backend_impl)
        for backend_impl
        in get_available_backend_implementations()
    ]


def get_available_backend_implementations():
    """Return the list of available backend implementation classes.

    Unlike :py:func:`get_backend_list`, this does not instantiate any
    backend, so it has no side effects (e.g. no GPU memory pre-allocation).
    """
    return _BACKEND_IMPLEMENTATIONS


def get_backend(*args):
Expand All @@ -167,9 +216,9 @@ def get_backend(*args):
if not len(args) > 0:
raise ValueError(" The function takes at least one (non-None) parameter")

for backend in _BACKENDS:
if _check_args_backend(backend, args):
return backend
for backend_impl in _BACKEND_IMPLEMENTATIONS:
if _check_args_backend(backend_impl, args):
return _get_backend_instance(backend_impl)

raise ValueError("Unknown type of non implemented backend.")

Expand Down Expand Up @@ -1341,7 +1390,7 @@ def matmul(self, a, b):
return np.matmul(a, b)


register_backend(NumpyBackend())
_register_backend_implementation(NumpyBackend)


class JaxBackend(Backend):
Expand Down Expand Up @@ -1710,7 +1759,7 @@ def matmul(self, a, b):

if jax:
# Only register jax backend if it is installed
register_backend(JaxBackend())
_register_backend_implementation(JaxBackend)


class TorchBackend(Backend):
Expand Down Expand Up @@ -2193,7 +2242,7 @@ def matmul(self, a, b):

if torch:
# Only register torch backend if it is installed
register_backend(TorchBackend())
_register_backend_implementation(TorchBackend)


class CupyBackend(Backend): # pragma: no cover
Expand Down Expand Up @@ -2586,7 +2635,7 @@ def matmul(self, a, b):

if cp:
# Only register cp backend if it is installed
register_backend(CupyBackend())
_register_backend_implementation(CupyBackend)


class TensorflowBackend(Backend):
Expand Down Expand Up @@ -3006,4 +3055,4 @@ def matmul(self, a, b):

if tf:
# Only register tensorflow backend if it is installed
register_backend(TensorflowBackend())
_register_backend_implementation(TensorflowBackend)
20 changes: 17 additions & 3 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,33 @@

# License: MIT License

import pytest
from ot.backend import jax, tf
from ot.backend import get_backend_list
import functools
import os
import pytest

from ot.backend import get_backend_list, jax, tf


if jax:
    # Ask jax/XLA not to pre-allocate the bulk of GPU memory for the test
    # session; presumably so other backends can share the device — confirm.
    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
    from jax.config import config
    # Enable 64-bit floats in jax (it defaults to 32-bit).
    # NOTE(review): assumed this is to match numpy-based expected values
    # in the tests — confirm.
    config.update("jax_enable_x64", True)

if tf:
    # make sure TF doesn't allocate entire GPU
    import tensorflow as tf
    physical_devices = tf.config.list_physical_devices('GPU')
    for device in physical_devices:
        try:
            tf.config.experimental.set_memory_growth(device, True)
        except Exception:
            # Best effort: deliberately ignore failures (e.g. if the device
            # was already initialized) rather than abort the test session.
            pass

    # allow numpy API for TF
    from tensorflow.python.ops.numpy_ops import np_config
    np_config.enable_numpy_behavior()


# Instantiates every available backend once for parametrized tests.
backend_list = get_backend_list()


Expand Down