@@ -109,9 +109,9 @@ def pairwise_distance(
109
109
f"Invalid Device, expected one of { valid_devices } , got: { device } "
110
110
)
111
111
try :
112
- map_func_name = f"{ metric } _map_ { device } "
112
+ map_pairwise_func_name = f"{ metric } _map_pairwise_ { device } "
113
113
reduce_func_name = f"{ metric } _reduce_{ device } "
114
- map_func = getattr (metrics , map_func_name )
114
+ map_pairwise_func = getattr (metrics , map_pairwise_func_name )
115
115
reduce_func = getattr (metrics , reduce_func_name )
116
116
n_map_param = metrics .N_MAP_PARAM [metric ]
117
117
except AttributeError :
@@ -123,24 +123,7 @@ def pairwise_distance(
123
123
if x .ndim != 2 :
124
124
raise ValueError (f"2-dimensional array expected, got '{ x .ndim } '" )
125
125
126
- # setting this variable outside of _pairwise to avoid it's recreation
127
- # in every iteration, which eventually leads to increase in dask
128
- # graph serialisation/deserialisation time significantly
129
- metric_param = np .empty (n_map_param , dtype = x .dtype )
130
-
131
- def _pairwise_cpu (f : ArrayLike , g : ArrayLike ) -> ArrayLike :
132
- result : ArrayLike = map_func (f [:, None , :], g , metric_param )
133
- # Adding a new axis to help combine chunks along this axis in the
134
- # reduction step (see the _aggregate and _combine functions below).
135
- return result [..., np .newaxis ]
136
-
137
- def _pairwise_gpu (f : ArrayLike , g : ArrayLike ) -> ArrayLike : # pragma: no cover
138
- result = map_func (f , g )
139
- return result [..., np .newaxis ]
140
-
141
- pairwise_func = _pairwise_cpu
142
- if device == "gpu" :
143
- pairwise_func = _pairwise_gpu # pragma: no cover
126
+ pairwise_func = map_pairwise_func
144
127
145
128
# concatenate in blockwise leads to high memory footprints, so instead
146
129
# we perform blockwise without contraction followed by reduction.
0 commit comments