Skip to content

Commit e1119ca

Browse files
committed
Fix distance (cpu)
1 parent 0c28121 commit e1119ca

File tree

2 files changed

+20
-20
lines changed

2 files changed

+20
-20
lines changed

Diff for: sgkit/distance/api.py

+3-20
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ def pairwise_distance(
109109
f"Invalid Device, expected one of {valid_devices}, got: {device}"
110110
)
111111
try:
112-
map_func_name = f"{metric}_map_{device}"
112+
map_pairwise_func_name = f"{metric}_map_pairwise_{device}"
113113
reduce_func_name = f"{metric}_reduce_{device}"
114-
map_func = getattr(metrics, map_func_name)
114+
map_pairwise_func = getattr(metrics, map_pairwise_func_name)
115115
reduce_func = getattr(metrics, reduce_func_name)
116116
n_map_param = metrics.N_MAP_PARAM[metric]
117117
except AttributeError:
@@ -123,24 +123,7 @@ def pairwise_distance(
123123
if x.ndim != 2:
124124
raise ValueError(f"2-dimensional array expected, got '{x.ndim}'")
125125

126-
# setting this variable outside of _pairwise to avoid its recreation
127-
# in every iteration, which eventually leads to increase in dask
128-
# graph serialisation/deserialisation time significantly
129-
metric_param = np.empty(n_map_param, dtype=x.dtype)
130-
131-
def _pairwise_cpu(f: ArrayLike, g: ArrayLike) -> ArrayLike:
132-
result: ArrayLike = map_func(f[:, None, :], g, metric_param)
133-
# Adding a new axis to help combine chunks along this axis in the
134-
# reduction step (see the _aggregate and _combine functions below).
135-
return result[..., np.newaxis]
136-
137-
def _pairwise_gpu(f: ArrayLike, g: ArrayLike) -> ArrayLike: # pragma: no cover
138-
result = map_func(f, g)
139-
return result[..., np.newaxis]
140-
141-
pairwise_func = _pairwise_cpu
142-
if device == "gpu":
143-
pairwise_func = _pairwise_gpu # pragma: no cover
126+
pairwise_func = map_pairwise_func
144127

145128
# concatenate in blockwise leads to high memory footprints, so instead
146129
# we perform blockwise without contraction followed by reduction.

Diff for: sgkit/distance/metrics.py

+17
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ def euclidean_map_cpu(
5959
out[:] = square_sum
6060

6161

62+
def euclidean_map_pairwise_cpu(f: ArrayLike, g: ArrayLike) -> ArrayLike:
63+
metric_param = np.empty(1, dtype=f.dtype)
64+
result: ArrayLike = euclidean_map_cpu(f[:, None, :], g, metric_param)
65+
# Adding a new axis to help combine chunks along this axis in the
66+
# reduction step (see the _aggregate and _combine functions in api.py).
67+
return result[..., np.newaxis]
68+
69+
6270
def euclidean_reduce_cpu(v: ArrayLike) -> ArrayLike: # pragma: no cover
6371
"""Corresponding "reduce" function for euclidean distance.
6472
@@ -138,6 +146,15 @@ def correlation_map_cpu(
138146
)
139147

140148

149+
def correlation_map_pairwise_cpu(f: ArrayLike, g: ArrayLike) -> ArrayLike:
150+
# TODO: allocating this array on every call goes against the advice formerly in api.py: create it once, outside the per-chunk function, to limit dask graph serialization time
151+
metric_param = np.empty(6, dtype=f.dtype)
152+
result: ArrayLike = correlation_map_cpu(f[:, None, :], g, metric_param)
153+
# Adding a new axis to help combine chunks along this axis in the
154+
# reduction step (see the _aggregate and _combine functions in api.py).
155+
return result[..., np.newaxis]
156+
157+
141158
@numba_guvectorize( # type: ignore
142159
[
143160
"void(float32[:, :], float32[:])",

0 commit comments

Comments
 (0)