[MRG] Only Minkowski metrics can be used for emd_1d #670

Merged: 5 commits, Aug 21, 2024
1 change: 1 addition & 0 deletions RELEASES.md
@@ -14,6 +14,7 @@

#### Closed issues
- Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648)
- Fixed `ot.emd_1d` and `ot.emd2_1d` incorrectly allowing any metric (PR #670, Issue #669)

## 0.9.4
*June 2024*
11 changes: 5 additions & 6 deletions ot/lp/emd_wrap.pyx
@@ -141,10 +141,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
v : (nt,) ndarray, float64
Target dirac locations (on the real line)
metric: str, optional (default='sqeuclidean')
-Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
-Due to implementation details, this function runs faster when
-`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
-are used.
+Metric to be used. Only works with either of the strings
+`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
p: float, optional (default=1.0)
The p-norm to apply for if metric='minkowski'

@@ -182,8 +180,9 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
elif metric == 'minkowski':
m_ij = math.pow(math.fabs(u[i] - v[j]), p)
else:
-m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
-metric=metric)[0, 0]
+raise ValueError("Solver for EMD in 1d only supports metrics " +
+"from the following list: " +
+"`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`")
if w_i < w_j or j == m - 1:
cost += m_ij * w_i
G[cur_idx] = w_i
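
To illustrate the loop this hunk touches, here is a minimal pure-Python sketch of a 1D EMD solve on already-sorted samples, with the same fail-fast `ValueError` for unsupported metrics. The names `ground_cost_1d` and `emd_1d_sorted_sketch` are hypothetical. The `'sqeuclidean'`, `'cityblock'`, and `'euclidean'` branches are assumptions inferred from the metric names, since the hunk shows only the `'minkowski'` branch, and the second half of the mass-transfer loop is assumed by symmetry with the visible half. It is not the Cython implementation.

```python
SUPPORTED = ("sqeuclidean", "minkowski", "cityblock", "euclidean")


def ground_cost_1d(x, y, metric="sqeuclidean", p=1.0):
    """Cost between two points on the real line (hypothetical helper)."""
    diff = abs(x - y)
    if metric == "sqeuclidean":
        return diff ** 2
    if metric in ("cityblock", "euclidean"):
        # Assumption: both reduce to |x - y| for scalar inputs.
        return diff
    if metric == "minkowski":
        return diff ** p
    # Mirrors the new behaviour: no fallback to a generic distance routine.
    raise ValueError(
        "Solver for EMD in 1d only supports metrics from the following "
        "list: %s" % (list(SUPPORTED),)
    )


def emd_1d_sorted_sketch(u_weights, v_weights, u, v,
                         metric="sqeuclidean", p=1.0):
    """Greedy mass transfer between two sorted, weighted samples.

    Returns (plan, cost), where plan is a list of (i, j, mass) triples.
    Assumes both weight vectors are non-empty and sum to the same total.
    """
    n, m = len(u_weights), len(v_weights)
    i = j = 0
    w_i, w_j = u_weights[0], v_weights[0]
    cost = 0.0
    plan = []
    while True:
        m_ij = ground_cost_1d(u[i], v[j], metric, p)
        if w_i < w_j or j == m - 1:
            # Source point i is exhausted first: ship all of its mass.
            cost += m_ij * w_i
            plan.append((i, j, w_i))
            i += 1
            if i == n:
                break
            w_j -= w_i
            w_i = u_weights[i]
        else:
            # Target point j is filled first: ship what it still needs.
            cost += m_ij * w_j
            plan.append((i, j, w_j))
            j += 1
            if j == m:
                break
            w_i -= w_j
            w_j = v_weights[j]
    return plan, cost
```

With `u = [0.0]`, `v = [1.0, 3.0]`, weights `[1.0]` and `[0.5, 0.5]`, and `metric="cityblock"`, the sketch sends half the mass to each target, for a cost of 0.5 * 1 + 0.5 * 3 = 2.0. Passing `metric="cosine"` raises `ValueError` on the first pair instead of silently computing a cost.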
23 changes: 14 additions & 9 deletions ot/lp/solver_1d.py
@@ -152,7 +152,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
- x_a and x_b are the samples
- a and b are the sample weights

-When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
+This implementation only supports metrics
+of the form :math:`d(x, y) = |x - y|^p`.

Uses the algorithm detailed in [1]_

@@ -167,9 +168,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
b : (nt,) ndarray, float64, optional
Target histogram (default is uniform weight)
metric: str, optional (default='sqeuclidean')
-Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
-Due to implementation details, this function runs faster when
-`'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
+Metric to be used. Only works with either of the strings
+`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
p: float, optional (default=1.0)
The p-norm to apply for if metric='minkowski'
dense: boolean, optional (default=True)
@@ -234,6 +234,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
"emd_1d should only be used with monodimensional data"
assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \
"emd_1d should only be used with monodimensional data"
+if metric not in ['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']:
+raise ValueError(
+"Solver for EMD in 1d only supports metrics " +
+"from the following list: " +
+"`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
+)

# if empty array given then use uniform distributions
if a is None or a.ndim == 0 or len(a) == 0:
@@ -300,7 +306,8 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
- x_a and x_b are the samples
- a and b are the sample weights

-When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
+This implementation only supports metrics
+of the form :math:`d(x, y) = |x - y|^p`.

Uses the algorithm detailed in [1]_

Expand All @@ -315,10 +322,8 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
b : (nt,) ndarray, float64, optional
Target histogram (default is uniform weight)
metric: str, optional (default='sqeuclidean')
-Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
-Due to implementation details, this function runs faster when
-`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
-are used.
+Metric to be used. Only works with either of the strings
+`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
p: float, optional (default=1.0)
The p-norm to apply for if metric='minkowski'
dense: boolean, optional (default=True)
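
The revised docstrings state that only costs of the form `|x - y|^p` are supported. As a rough illustration of how the four accepted metric names fit that form, the sketch below maps each name to an exponent and rejects anything else up front, much like the guard added to `emd_1d`. The helper `minkowski_exponent` is hypothetical and not part of the library; the exponents for `'sqeuclidean'`, `'cityblock'`, and `'euclidean'` are assumptions based on what those metrics reduce to for scalar inputs.

```python
SUPPORTED_1D_METRICS = ("sqeuclidean", "minkowski", "cityblock", "euclidean")


def minkowski_exponent(metric, p=1.0):
    """Return the exponent q such that the cost is |x - y| ** q.

    Hypothetical helper: raises ValueError for any metric name outside
    the supported list, before any computation happens.
    """
    if metric not in SUPPORTED_1D_METRICS:
        raise ValueError(
            "Solver for EMD in 1d only supports metrics from the "
            "following list: %s" % (list(SUPPORTED_1D_METRICS),)
        )
    if metric == "sqeuclidean":
        return 2.0
    if metric == "minkowski":
        return float(p)
    # Assumption: 'cityblock' and 'euclidean' coincide on the real line.
    return 1.0


def cost_1d(x, y, metric="sqeuclidean", p=1.0):
    """Evaluate the cost |x - y| ** q for the chosen metric name."""
    return abs(x - y) ** minkowski_exponent(metric, p)
```

Validating the name once at the top of the wrapper means an unsupported metric such as `'cosine'` fails before any sorting or weight normalisation is done, which is the behaviour the new tests check for.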
6 changes: 6 additions & 0 deletions test/test_1d_solver.py
@@ -50,6 +50,12 @@ def test_emd_1d_emd2_1d_with_weights():
np.testing.assert_allclose(w_u, G.sum(1))
np.testing.assert_allclose(w_v, G.sum(0))

+# check that an error is raised if the metric is not a Minkowski one
+np.testing.assert_raises(ValueError, ot.emd_1d,
+u, v, w_u, w_v, metric='cosine')
+np.testing.assert_raises(ValueError, ot.emd2_1d,
+u, v, w_u, w_v, metric='cosine')


def test_wasserstein_1d(nx):
rng = np.random.RandomState(0)