Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Debug convolutional methods that compute barycenters to work with different devices. #533

Merged
merged 12 commits into from
Oct 18, 2023
2 changes: 2 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
+ Added support for [Nearest Brenier Potentials (SSNB)](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) (PR #526)
+ Tweaked `get_backend` to ignore `None` inputs (PR #525)
+ Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)
+ The `linspace` method of the backends now has the `type_as` argument to convert to the same dtype and device. (PR #533)
+ The `convolutional_barycenter2d` and `convolutional_barycenter2d_debiased` functions now work with different devices.. (PR #533)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
Expand Down
50 changes: 34 additions & 16 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@
#
# License: MIT License

import numpy as np
import os
import scipy
import scipy.linalg
from scipy.sparse import issparse, coo_matrix, csr_matrix
import scipy.special as special
import time
import warnings

import numpy as np
import scipy
import scipy.linalg
import scipy.special as special
from scipy.sparse import coo_matrix, csr_matrix, issparse

DISABLE_TORCH_KEY = 'POT_BACKEND_DISABLE_PYTORCH'
DISABLE_JAX_KEY = 'POT_BACKEND_DISABLE_JAX'
Expand Down Expand Up @@ -650,7 +650,7 @@ def std(self, a, axis=None):
"""
raise NotImplementedError()

def linspace(self, start, stop, num):
def linspace(self, start, stop, num, type_as=None):
r"""
Returns a specified number of evenly spaced values over a given interval.

Expand Down Expand Up @@ -1208,8 +1208,11 @@ def median(self, a, axis=None):
def std(self, a, axis=None):
return np.std(a, axis=axis)

def linspace(self, start, stop, num):
return np.linspace(start, stop, num)
def linspace(self, start, stop, num, type_as=None):
if type_as is None:
return np.linspace(start, stop, num)
else:
return np.linspace(start, stop, num, dtype=type_as.dtype)

def meshgrid(self, a, b):
return np.meshgrid(a, b)
Expand Down Expand Up @@ -1579,8 +1582,11 @@ def median(self, a, axis=None):
def std(self, a, axis=None):
return jnp.std(a, axis=axis)

def linspace(self, start, stop, num):
return jnp.linspace(start, stop, num)
def linspace(self, start, stop, num, type_as=None):
if type_as is None:
return jnp.linspace(start, stop, num)
else:
return self._change_device(jnp.linspace(start, stop, num, dtype=type_as.dtype), type_as)

def meshgrid(self, a, b):
return jnp.meshgrid(a, b)
Expand Down Expand Up @@ -1986,6 +1992,7 @@ def concatenate(self, arrays, axis=0):

def zero_pad(self, a, pad_width, value=0):
from torch.nn.functional import pad

# pad_width is an array of ndim tuples indicating how many 0 before and after
# we need to add. We first need to make it compliant with torch syntax, that
# starts with the last dim, then second last, etc.
Expand All @@ -2006,6 +2013,7 @@ def mean(self, a, axis=None):

def median(self, a, axis=None):
from packaging import version

# Since version 1.11.0, interpolation is available
if version.parse(torch.__version__) >= version.parse("1.11.0"):
if axis is not None:
Expand All @@ -2026,8 +2034,11 @@ def std(self, a, axis=None):
else:
return torch.std(a, unbiased=False)

def linspace(self, start, stop, num):
return torch.linspace(start, stop, num, dtype=torch.float64)
def linspace(self, start, stop, num, type_as=None):
if type_as is None:
return torch.linspace(start, stop, num)
else:
return torch.linspace(start, stop, num, dtype=type_as.dtype, device=type_as.device)

def meshgrid(self, a, b):
try:
Expand Down Expand Up @@ -2427,8 +2438,12 @@ def median(self, a, axis=None):
def std(self, a, axis=None):
return cp.std(a, axis=axis)

def linspace(self, start, stop, num):
return cp.linspace(start, stop, num)
def linspace(self, start, stop, num, type_as=None):
if type_as is None:
return cp.linspace(start, stop, num)
else:
with cp.cuda.Device(type_as.device):
return cp.linspace(start, stop, num, dtype=type_as.dtype)

def meshgrid(self, a, b):
return cp.meshgrid(a, b)
Expand Down Expand Up @@ -2834,8 +2849,11 @@ def median(self, a, axis=None):
def std(self, a, axis=None):
return tnp.std(a, axis=axis)

def linspace(self, start, stop, num):
return tnp.linspace(start, stop, num)
def linspace(self, start, stop, num, type_as=None):
if type_as is None:
return tnp.linspace(start, stop, num)
else:
return tnp.linspace(start, stop, num, dtype=type_as.dtype)

def meshgrid(self, a, b):
return tnp.meshgrid(a, b)
Expand Down
19 changes: 10 additions & 9 deletions ot/bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import numpy as np
from scipy.optimize import fmin_l_bfgs_b

from ot.utils import unif, dist, list_to_array
from ot.utils import dist, list_to_array, unif

from .backend import get_backend


Expand Down Expand Up @@ -2217,11 +2218,11 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,

# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, A.shape[1])
t = nx.linspace(0, 1, A.shape[1], type_as=A)
[Y, X] = nx.meshgrid(t, t)
K1 = nx.exp(-(X - Y) ** 2 / reg)

t = nx.linspace(0, 1, A.shape[2])
t = nx.linspace(0, 1, A.shape[2], type_as=A)
[Y, X] = nx.meshgrid(t, t)
K2 = nx.exp(-(X - Y) ** 2 / reg)

Expand Down Expand Up @@ -2295,11 +2296,11 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
err = 1
# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, width)
t = nx.linspace(0, 1, width, type_as=A)
[Y, X] = nx.meshgrid(t, t)
M1 = - (X - Y) ** 2 / reg

t = nx.linspace(0, 1, height)
t = nx.linspace(0, 1, height, type_as=A)
[Y, X] = nx.meshgrid(t, t)
M2 = - (X - Y) ** 2 / reg

Expand Down Expand Up @@ -2452,11 +2453,11 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,

# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, width)
t = nx.linspace(0, 1, width, type_as=A)
[Y, X] = nx.meshgrid(t, t)
K1 = nx.exp(-(X - Y) ** 2 / reg)

t = nx.linspace(0, 1, height)
t = nx.linspace(0, 1, height, type_as=A)
[Y, X] = nx.meshgrid(t, t)
K2 = nx.exp(-(X - Y) ** 2 / reg)

Expand Down Expand Up @@ -2532,11 +2533,11 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10
err = 1
# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, width)
t = nx.linspace(0, 1, width, type_as=A)
[Y, X] = nx.meshgrid(t, t)
M1 = - (X - Y) ** 2 / reg

t = nx.linspace(0, 1, height)
t = nx.linspace(0, 1, height, type_as=A)
[Y, X] = nx.meshgrid(t, t)
M2 = - (X - Y) ** 2 / reg

Expand Down
12 changes: 5 additions & 7 deletions test/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@
#
# License: MIT License

import ot
import ot.backend
from ot.backend import torch, jax, tf

import pytest

import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal_nulp

from ot.backend import get_backend, get_backend_list, to_numpy
import ot
import ot.backend
from ot.backend import get_backend, get_backend_list, jax, tf, to_numpy, torch


def test_get_backend_list():
Expand Down Expand Up @@ -507,6 +504,7 @@ def test_func_backends(nx):
lst_name.append('std')

A = nx.linspace(0, 1, 50)
A = nx.linspace(0, 1, 50, type_as=Mb)
lst_b.append(nx.to_numpy(A))
lst_name.append('linspace')

Expand Down
Loading