Skip to content

Commit

Permalink
dist_sum nan workaround from src-d#112
Browse files Browse the repository at this point in the history
  • Loading branch information
Duanyll committed Jan 9, 2023
1 parent 7dcd28c commit 8a78ca4
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/kmeans.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ __global__ void kmeans_plus_plus(
centroids += (cc - 1) * d_features_size;
const uint32_t local_sample = sample + offset;
if (_eq(samples[local_sample], samples[local_sample])) {
dist = METRIC<M, F>::distance_t(
dist = METRIC<M, F>::distance_t_no_nan(
samples, centroids, d_samples_size, local_sample);
}
float prev_dist;
Expand Down Expand Up @@ -83,7 +83,7 @@ __global__ void kmeans_afkmc2_calc_q_dists(
c1[i] = samples[static_cast<uint64_t>(c1_index) * d_features_size + i];
}
__syncthreads();
dist = METRIC<M, F>::distance_t(samples, c1, d_samples_size, sample);
dist = METRIC<M, F>::distance_t_no_nan(samples, c1, d_samples_size, sample);
dist *= dist;
dists[sample] = dist;
}
Expand Down Expand Up @@ -167,7 +167,7 @@ __global__ void kmeans_afkmc2_min_dist(
}
float min_dist = FLT_MAX;
for (uint32_t c = 0; c < k; c++) {
float dist = METRIC<M, F>::distance_t(
float dist = METRIC<M, F>::distance_t_no_nan(
samples, centroids + c * d_features_size, d_samples_size, choices[chi]);
if (dist < min_dist) {
min_dist = dist;
Expand All @@ -194,7 +194,7 @@ __global__ void kmeans_afkmc2_min_dist_transposed(
for (uint32_t chi = 0; chi < m; chi++) {
float dist = FLT_MAX;
if (c < k) {
dist = METRIC<M, F>::distance_t(
dist = METRIC<M, F>::distance_t_no_nan(
samples, centroids + c * d_features_size, d_samples_size, choices[chi]);
}
float warp_min = warpReduceMin(dist);
Expand Down Expand Up @@ -469,7 +469,7 @@ __global__ void kmeans_yy_init(
// this may happen if the centroid is insane (NaN)
continue;
}
float dist = METRIC<M, F>::distance_t(
float dist = METRIC<M, F>::distance_t_no_nan(
samples, shared_centroids + (c - gc) * d_features_size,
d_samples_size, sample + offset);
if (c != nearest) {
Expand Down Expand Up @@ -569,7 +569,7 @@ __global__ void kmeans_yy_global_filter(
return;
}
upper_bound = 0;
upper_bound = METRIC<M, F>::distance_t(
upper_bound = METRIC<M, F>::distance_t_no_nan(
samples, centroids + cluster * d_features_size,
d_samples_size, sample + offset);
bounds[sample] = upper_bound;
Expand Down Expand Up @@ -640,7 +640,7 @@ __global__ void kmeans_yy_local_filter(
if (second_min_dist < lower_bound) {
continue;
}
float dist = METRIC<M, F>::distance_t(
float dist = METRIC<M, F>::distance_t_no_nan(
samples, shared_centroids + (c - gc) * d_features_size,
d_samples_size, sample + offset);
if (dist < min_dist) {
Expand Down Expand Up @@ -680,7 +680,7 @@ __global__ void kmeans_calc_average_distance(
float dist = 0;
if (sample < length) {
sample += offset;
dist = METRIC<M, F>::distance_t(
dist = METRIC<M, F>::distance_t_no_nan(
samples, centroids + assignments[sample] * d_features_size,
d_samples_size, sample);
}
Expand Down
14 changes: 14 additions & 0 deletions src/metric_abstraction.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ struct METRIC<kmcudaDistanceMetricL2, F> {
return _sqrt(_float(_fin(dist)));
}

FPATTR static float distance_t_no_nan(const F *__restrict__ v1,
const F *__restrict__ v2,
uint64_t v1_size, uint64_t v1_index) {
auto res = distance_t(v1, v2, v1_size, v1_index);
return _eq(res, res) ? res : 0;
}

FPATTR static float distance_tt(const F *__restrict__ v, uint64_t size,
uint64_t index1, uint64_t index2) {
// Kahan summation with inverted c
Expand Down Expand Up @@ -203,6 +210,13 @@ struct METRIC<kmcudaDistanceMetricCosine, F> {
return _float(distance(_const<F>(1), _const<F>(1), prod));
}

FPATTR static float distance_t_no_nan(const F *__restrict__ v1,
const F *__restrict__ v2,
uint64_t v1_size, uint64_t v1_index) {
auto res = distance_t(v1, v2, v1_size, v1_index);
return _eq(res, res) ? res : 0;
}

FPATTR static float distance_tt(const F *__restrict__ v, uint64_t size,
uint64_t index1, uint64_t index2) {
// Kahan summation with inverted c
Expand Down

0 comments on commit 8a78ca4

Please sign in to comment.