Skip to content

Commit 4ff1b59

Browse files
authored
Distributed kNN scalability and optimizations (uxlfoundation#2558)
* Profiling additions for benchmarking
* dblock cap+last iter,split_table profile,var names
* trying revert of data_management
* custom max and split table event
* address some todos and cleanup finalize
* remove temp_resp_ + clang
* send recv replace debug
* updated debug
* extended profiling
* temporary for CI build
* cleanup and removal of unneeded profiling
* syncing data_management with master
* I_MPI_OFFLOAD condition for green bazel
* temporary conditionals add for bench
* for bench only
* detailed select_indexed profiling
* removing select_indexed_local calls
* restoring communicator (see uxlfoundation#2577)
* select_indexed debugging removals
* search_dpc debugging cleanup
* knn cleanup and clang
* single gpu/distributed unification
* addressing comments
* correction to previous
* clean up comments
* addressing some comments
* clang
1 parent a1075af commit 4ff1b59

File tree

5 files changed: +138 -177 lines changed

cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc.hpp

+3
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
#include "oneapi/dal/table/row_accessor.hpp"
3838

3939
#include "oneapi/dal/detail/common.hpp"
40+
#include "oneapi/dal/detail/profiler.hpp"
4041

4142
namespace oneapi::dal::knn::backend {
4243

@@ -166,6 +167,7 @@ class knn_callback {
166167
pr::ndview<idx_t, 2>& inp_indices,
167168
pr::ndview<Float, 2>& inp_distances,
168169
const bk::event_vector& deps = {}) {
170+
ONEDAL_PROFILER_TASK(query_loop.callback, queue_);
169171
sycl::event copy_indices, copy_distances, comp_responses;
170172

171173
const auto bounds = this->block_bounds(qb_id);
@@ -473,6 +475,7 @@ sycl::event bf_kernel(sycl::queue& queue,
473475
distance_impl->get_daal_distance_type() == daal_distance_t::cosine;
474476
const bool is_euclidean_distance =
475477
is_minkowski_distance && (distance_impl->get_degree() == 2.0);
478+
ONEDAL_ASSERT(is_minkowski_distance ^ is_chebyshev_distance ^ is_cosine_distance);
476479

477480
sycl::event search_event;
478481

cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc_distr.hpp

+87 -136
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@
3939
#include "oneapi/dal/table/row_accessor.hpp"
4040

4141
#include "oneapi/dal/detail/common.hpp"
42+
#include "oneapi/dal/detail/profiler.hpp"
4243

4344
namespace oneapi::dal::knn::backend {
4445

@@ -72,13 +73,7 @@ class knn_callback_distr {
7273
result_options_(results),
7374
query_block_(query_block),
7475
query_length_(query_length),
75-
k_neighbors_(k_neighbors) {
76-
if (result_options_.test(result_options::responses)) {
77-
this->temp_resp_ = pr::ndarray<res_t, 2>::empty(q,
78-
{ query_block, k_neighbors },
79-
sycl::usm::alloc::device);
80-
}
81-
}
76+
k_neighbors_(k_neighbors) {}
8277

8378
auto& set_euclidean_distance(bool is_euclidean_distance) {
8479
this->compute_sqrt_ = is_euclidean_distance;
@@ -196,50 +191,27 @@ class knn_callback_distr {
196191
return *this;
197192
}
198193

199-
sycl::event finalize(std::int64_t qb_id,
200-
pr::ndview<idx_t, 2>& inp_indices,
201-
pr::ndview<Float, 2>& inp_distances,
202-
const bk::event_vector& deps = {}) {
203-
sycl::event copy_indices, copy_distances, comp_responses;
204-
205-
const auto bounds = this->block_bounds(qb_id);
206-
207-
if (result_options_.test(result_options::indices)) {
208-
copy_indices = this->output_indices(bounds, inp_indices, deps);
209-
}
210-
211-
if (result_options_.test(result_options::distances)) {
212-
copy_distances = this->output_distances(bounds, inp_distances, deps);
213-
}
214-
215-
if (result_options_.test(result_options::responses)) {
216-
using namespace bk;
217-
const auto ndeps = deps + copy_indices + copy_distances;
218-
comp_responses = this->output_responses(bounds, inp_indices, inp_distances, ndeps);
219-
}
220-
221-
sycl::event::wait_and_throw({ copy_indices, copy_distances, comp_responses });
222-
return sycl::event();
223-
}
224-
225194
sycl::event operator()(std::int64_t qb_id,
226195
pr::ndview<idx_t, 2>& inp_indices,
227196
pr::ndview<Float, 2>& inp_distances,
228197
const bk::event_vector& deps = {}) {
198+
ONEDAL_PROFILER_TASK(query_loop.callback, queue_);
229199
sycl::event copy_actual_dist_event, copy_current_dist_event, copy_actual_indc_event,
230200
copy_current_indc_event, copy_actual_resp_event, copy_current_resp_event;
231-
const auto& [first, last] = this->block_bounds(qb_id);
201+
const auto& bounds = this->block_bounds(qb_id);
202+
const auto& [first, last] = bounds;
232203
const auto len = last - first;
233204
ONEDAL_ASSERT(last > first);
234205
ONEDAL_ASSERT(inp_indices.get_dimension(0) == len);
235206
ONEDAL_ASSERT(inp_indices.get_dimension(1) == k_neighbors_);
236207
ONEDAL_ASSERT(inp_distances.get_dimension(0) == len);
237208
ONEDAL_ASSERT(inp_distances.get_dimension(1) == k_neighbors_);
238209

239-
auto inp_responses = this->temp_resp_.get_row_slice(0, len);
210+
auto current_min_resp_dest = part_responses_.get_col_slice(k_neighbors_, 2 * k_neighbors_)
211+
.get_row_slice(first, last);
240212

241-
auto select_inp_resp_event =
242-
pr::select_indexed(queue_, inp_indices, train_responses_, inp_responses, deps);
213+
copy_current_resp_event =
214+
pr::select_indexed(queue_, inp_indices, train_responses_, current_min_resp_dest, deps);
243215

244216
const pr::ndshape<2> typical_blocking(last - first, 2 * k_neighbors_);
245217
auto select = selc_t(queue_, typical_blocking, k_neighbors_);
@@ -250,8 +222,10 @@ class knn_callback_distr {
250222

251223
// add global offset value to input indices
252224
ONEDAL_ASSERT(global_index_offset_ != -1);
253-
auto treat_event =
254-
pr::treat_indices(queue_, inp_indices, global_index_offset_, { select_inp_resp_event });
225+
auto treat_event = pr::treat_indices(queue_,
226+
inp_indices,
227+
global_index_offset_,
228+
{ copy_current_resp_event });
255229

256230
auto actual_min_dist_copy_dest =
257231
part_distances_.get_col_slice(0, k_neighbors_).get_row_slice(first, last);
@@ -273,39 +247,48 @@ class knn_callback_distr {
273247

274248
auto actual_min_resp_copy_dest =
275249
part_responses_.get_col_slice(0, k_neighbors_).get_row_slice(first, last);
276-
auto current_min_resp_dest = part_responses_.get_col_slice(k_neighbors_, 2 * k_neighbors_)
277-
.get_row_slice(first, last);
278250
copy_actual_resp_event =
279251
pr::copy(queue_, actual_min_resp_copy_dest, min_resp_dest, { treat_event });
280-
copy_current_resp_event =
281-
pr::copy(queue_, current_min_resp_dest, inp_responses, { treat_event });
282-
283-
auto kselect_block = part_distances_.get_row_slice(first, last);
284-
auto selt_event = select(queue_,
285-
kselect_block,
286-
k_neighbors_,
287-
min_dist_dest,
288-
min_indc_dest,
289-
{ copy_actual_dist_event,
290-
copy_current_dist_event,
291-
copy_actual_indc_event,
292-
copy_current_indc_event,
293-
copy_actual_resp_event,
294-
copy_current_resp_event });
295-
auto resps_event = select_indexed(queue_,
296-
min_indc_dest,
297-
part_responses_.get_row_slice(first, last),
298-
min_resp_dest,
299-
{ selt_event });
300-
auto final_event = select_indexed(queue_,
301-
min_indc_dest,
302-
part_indices_.get_row_slice(first, last),
303-
min_indc_dest,
304-
{ resps_event });
252+
253+
sycl::event select_event;
254+
{
255+
ONEDAL_PROFILER_TASK(query_loop.selection, queue_);
256+
auto kselect_block = part_distances_.get_row_slice(first, last);
257+
select_event = select(queue_,
258+
kselect_block,
259+
k_neighbors_,
260+
min_dist_dest,
261+
min_indc_dest,
262+
{ copy_actual_dist_event,
263+
copy_current_dist_event,
264+
copy_actual_indc_event,
265+
copy_current_indc_event,
266+
copy_actual_resp_event,
267+
copy_current_resp_event });
268+
}
269+
auto select_resp_event = select_indexed(queue_,
270+
min_indc_dest,
271+
part_responses_.get_row_slice(first, last),
272+
min_resp_dest,
273+
{ select_event });
274+
auto select_indc_event = select_indexed(queue_,
275+
min_indc_dest,
276+
part_indices_.get_row_slice(first, last),
277+
min_indc_dest,
278+
{ select_resp_event });
305279
if (last_iteration_) {
306-
final_event = finalize(qb_id, indices_, distances_, { final_event });
280+
sycl::event copy_sqrt_event;
281+
if (this->compute_sqrt_) {
282+
copy_sqrt_event =
283+
copy_with_sqrt(queue_, min_dist_dest, min_dist_dest, { select_indc_event });
284+
}
285+
auto final_event = this->output_responses(bounds,
286+
indices_,
287+
distances_,
288+
{ select_indc_event, copy_sqrt_event });
289+
return final_event;
307290
}
308-
return final_event;
291+
return select_indc_event;
309292
}
310293

311294
protected:
@@ -320,45 +303,6 @@ class knn_callback_distr {
320303
return std::make_pair(first, last);
321304
}
322305

323-
sycl::event output_distances(const std::pair<idx_t, idx_t>& bnds,
324-
const pr::ndview<dst_t, 2>& inp_dts,
325-
const bk::event_vector& deps = {}) {
326-
ONEDAL_ASSERT(inp_dts.has_data());
327-
ONEDAL_ASSERT(this->result_options_.test(result_options::distances));
328-
329-
const auto& [first, last] = bnds;
330-
ONEDAL_ASSERT(last > first);
331-
auto& queue = this->queue_;
332-
333-
auto out_dts = this->distances_.get_row_slice(first, last);
334-
ONEDAL_ASSERT((last - first) == inp_dts.get_dimension(0));
335-
ONEDAL_ASSERT((last - first) == out_dts.get_dimension(0));
336-
337-
const bool& csqrt = this->compute_sqrt_;
338-
if (!csqrt)
339-
return pr::copy(queue, out_dts, inp_dts, deps);
340-
else
341-
return copy_with_sqrt(queue, inp_dts, out_dts, deps);
342-
}
343-
344-
sycl::event output_indices(const std::pair<idx_t, idx_t>& bnds,
345-
const pr::ndview<idx_t, 2>& inp_ids,
346-
const bk::event_vector& deps = {}) {
347-
ONEDAL_ASSERT(inp_ids.has_data());
348-
ONEDAL_ASSERT(this->result_options_.test(result_options::indices));
349-
350-
const auto& [first, last] = bnds;
351-
ONEDAL_ASSERT(last > first);
352-
auto& queue = this->queue_;
353-
354-
auto out_ids = this->indices_.get_row_slice(first, last);
355-
ONEDAL_ASSERT((last - first) == inp_ids.get_dimension(0));
356-
ONEDAL_ASSERT((last - first) == out_ids.get_dimension(0));
357-
ONEDAL_ASSERT(inp_ids.get_shape() == out_ids.get_shape());
358-
359-
return pr::copy(queue, out_ids, inp_ids, deps);
360-
}
361-
362306
template <typename T = Task, typename = detail::enable_if_classification_t<T>>
363307
sycl::event do_ucls(const std::pair<idx_t, idx_t>& bnds,
364308
const pr::ndview<res_t, 2>& tmp_rps,
@@ -481,7 +425,6 @@ class knn_callback_distr {
481425
const result_option_id result_options_;
482426
const std::int64_t query_block_, query_length_, k_neighbors_;
483427
pr::ndview<res_t, 1> train_responses_;
484-
pr::ndarray<res_t, 2> temp_resp_;
485428
pr::ndview<res_t, 1> responses_;
486429
pr::ndview<res_t, 2> part_responses_;
487430
pr::ndview<res_t, 2> intermediate_responses_;
@@ -519,7 +462,7 @@ sycl::event bf_kernel_distr(sycl::queue& queue,
519462
// Input arrays test section
520463
ONEDAL_ASSERT(train.has_data());
521464
ONEDAL_ASSERT(query.has_data());
522-
[[maybe_unused]] auto tcount = train.get_row_count();
465+
const auto tcount = train.get_row_count();
523466
const auto qcount = query.get_dimension(0);
524467
const auto fcount = train.get_column_count();
525468
const auto kcount = desc.get_neighbor_count();
@@ -558,6 +501,13 @@ sycl::event bf_kernel_distr(sycl::queue& queue,
558501

559502
comm.allgather(tcount, node_sample_counts.flatten()).wait();
560503

504+
// TODO: implement max/min for ndarray
505+
std::int64_t max_tcount = 0;
506+
for (std::int64_t index = 0; index < node_sample_counts.get_count(); ++index) {
507+
max_tcount = std::max(node_sample_counts.at(index), max_tcount);
508+
}
509+
block_size = std::min(max_tcount, block_size);
510+
561511
auto current_rank = comm.get_rank();
562512
auto prev_node = (current_rank - 1 + rank_count) % rank_count;
563513
auto next_node = (current_rank + 1) % rank_count;
@@ -631,12 +581,14 @@ sycl::event bf_kernel_distr(sycl::queue& queue,
631581
distance_impl->get_daal_distance_type() == daal_distance_t::cosine;
632582
const bool is_euclidean_distance =
633583
is_minkowski_distance && (distance_impl->get_degree() == 2.0);
584+
ONEDAL_ASSERT(is_minkowski_distance ^ is_chebyshev_distance ^ is_cosine_distance);
634585

635586
const auto it = std::find(nodes.begin(), nodes.end(), current_rank);
636-
auto first_block_index = std::distance(nodes.begin(), it);
587+
auto relative_block_offset = std::distance(nodes.begin(), it);
637588
ONEDAL_ASSERT(it != nodes.end());
638589

639-
for (std::int64_t block_number = 0; block_number < block_count; ++block_number) {
590+
for (std::int64_t relative_block_idx = 0; relative_block_idx < block_count;
591+
++relative_block_idx) {
640592
auto current_block = train_block_queue.front();
641593
train_block_queue.pop_front();
642594
ONEDAL_ASSERT(current_block.has_data());
@@ -646,19 +598,20 @@ sycl::event bf_kernel_distr(sycl::queue& queue,
646598
pr::ndview<res_t, 1>::wrap(current_tresps.get_data(), { current_tresps.get_count() });
647599
tresps_queue.pop_front();
648600

649-
auto block_index = (block_number + first_block_index) % block_count;
650-
ONEDAL_ASSERT(block_index + 1 < bounds_size);
651-
auto actual_rows_in_block = boundaries.at(block_index + 1) - boundaries.at(block_index);
601+
auto absolute_block_idx = (relative_block_idx + relative_block_offset) % block_count;
602+
ONEDAL_ASSERT(absolute_block_idx + 1 < bounds_size);
603+
auto actual_rows_in_block =
604+
boundaries.at(absolute_block_idx + 1) - boundaries.at(absolute_block_idx);
652605

653606
auto sc = current_block.get_dimension(0);
654607
ONEDAL_ASSERT(sc >= actual_rows_in_block);
655608
auto curr_k = std::min(actual_rows_in_block, kcount);
656609
auto actual_current_block = current_block.get_row_slice(0, actual_rows_in_block);
657610
auto actual_current_tresps = current_tresps_1d.get_slice(0, actual_rows_in_block);
658611

659-
callback.set_global_index_offset(boundaries.at(block_index));
612+
callback.set_global_index_offset(boundaries.at(absolute_block_idx));
660613
callback.set_train_responses(actual_current_tresps);
661-
if (block_number == block_count - 1) {
614+
if (relative_block_idx == block_count - 1) {
662615
callback.set_last_iteration(true);
663616
}
664617
if (is_cosine_distance) {
@@ -699,26 +652,24 @@ sycl::event bf_kernel_distr(sycl::queue& queue,
699652
next_event = search(query, callback, qbcount, curr_k, { next_event });
700653
}
701654

702-
auto send_count = current_block.get_count();
703-
ONEDAL_ASSERT(send_count >= 0);
704-
ONEDAL_ASSERT(send_count <= de::limits<int>::max());
705-
// send recv replace
706-
comm.sendrecv_replace(array<Float>::wrap(queue,
707-
current_block.get_mutable_data(),
708-
send_count,
709-
{ next_event }),
710-
prev_node,
711-
next_node)
712-
.wait();
713-
train_block_queue.emplace_back(current_block);
714-
comm.sendrecv_replace(array<res_t>::wrap(queue,
715-
current_tresps.get_mutable_data(),
716-
current_tresps.get_count(),
717-
{ next_event }),
718-
prev_node,
719-
next_node)
720-
.wait();
721-
tresps_queue.emplace_back(current_tresps);
655+
if (relative_block_idx < block_count - 1) {
656+
ONEDAL_PROFILER_TASK(distributed_loop.sendrecv_replace, queue);
657+
auto send_count = current_block.get_count();
658+
ONEDAL_ASSERT(send_count >= 0);
659+
ONEDAL_ASSERT(send_count <= de::limits<int>::max());
660+
auto send_train_block = array<Float>::wrap(queue,
661+
current_block.get_mutable_data(),
662+
send_count,
663+
{ next_event });
664+
comm.sendrecv_replace(send_train_block, prev_node, next_node).wait();
665+
train_block_queue.emplace_back(current_block);
666+
auto send_resps_block = array<res_t>::wrap(queue,
667+
current_tresps.get_mutable_data(),
668+
current_tresps.get_count(),
669+
{ next_event });
670+
comm.sendrecv_replace(send_resps_block, prev_node, next_node).wait();
671+
tresps_queue.emplace_back(current_tresps);
672+
}
722673
}
723674

724675
return next_event;

0 commit comments

Comments
 (0)