39
39
#include " oneapi/dal/table/row_accessor.hpp"
40
40
41
41
#include " oneapi/dal/detail/common.hpp"
42
+ #include " oneapi/dal/detail/profiler.hpp"
42
43
43
44
namespace oneapi ::dal::knn::backend {
44
45
@@ -72,13 +73,7 @@ class knn_callback_distr {
72
73
result_options_(results),
73
74
query_block_(query_block),
74
75
query_length_(query_length),
75
- k_neighbors_(k_neighbors) {
76
- if (result_options_.test (result_options::responses)) {
77
- this ->temp_resp_ = pr::ndarray<res_t , 2 >::empty (q,
78
- { query_block, k_neighbors },
79
- sycl::usm::alloc::device);
80
- }
81
- }
76
+ k_neighbors_(k_neighbors) {}
82
77
83
78
auto & set_euclidean_distance (bool is_euclidean_distance) {
84
79
this ->compute_sqrt_ = is_euclidean_distance;
@@ -196,50 +191,27 @@ class knn_callback_distr {
196
191
return *this ;
197
192
}
198
193
199
- sycl::event finalize (std::int64_t qb_id,
200
- pr::ndview<idx_t , 2 >& inp_indices,
201
- pr::ndview<Float, 2 >& inp_distances,
202
- const bk::event_vector& deps = {}) {
203
- sycl::event copy_indices, copy_distances, comp_responses;
204
-
205
- const auto bounds = this ->block_bounds (qb_id);
206
-
207
- if (result_options_.test (result_options::indices)) {
208
- copy_indices = this ->output_indices (bounds, inp_indices, deps);
209
- }
210
-
211
- if (result_options_.test (result_options::distances)) {
212
- copy_distances = this ->output_distances (bounds, inp_distances, deps);
213
- }
214
-
215
- if (result_options_.test (result_options::responses)) {
216
- using namespace bk ;
217
- const auto ndeps = deps + copy_indices + copy_distances;
218
- comp_responses = this ->output_responses (bounds, inp_indices, inp_distances, ndeps);
219
- }
220
-
221
- sycl::event::wait_and_throw ({ copy_indices, copy_distances, comp_responses });
222
- return sycl::event ();
223
- }
224
-
225
194
sycl::event operator ()(std::int64_t qb_id,
226
195
pr::ndview<idx_t , 2 >& inp_indices,
227
196
pr::ndview<Float, 2 >& inp_distances,
228
197
const bk::event_vector& deps = {}) {
198
+ ONEDAL_PROFILER_TASK (query_loop.callback , queue_);
229
199
sycl::event copy_actual_dist_event, copy_current_dist_event, copy_actual_indc_event,
230
200
copy_current_indc_event, copy_actual_resp_event, copy_current_resp_event;
231
- const auto & [first, last] = this ->block_bounds (qb_id);
201
+ const auto & bounds = this ->block_bounds (qb_id);
202
+ const auto & [first, last] = bounds;
232
203
const auto len = last - first;
233
204
ONEDAL_ASSERT (last > first);
234
205
ONEDAL_ASSERT (inp_indices.get_dimension (0 ) == len);
235
206
ONEDAL_ASSERT (inp_indices.get_dimension (1 ) == k_neighbors_);
236
207
ONEDAL_ASSERT (inp_distances.get_dimension (0 ) == len);
237
208
ONEDAL_ASSERT (inp_distances.get_dimension (1 ) == k_neighbors_);
238
209
239
- auto inp_responses = this ->temp_resp_ .get_row_slice (0 , len);
210
+ auto current_min_resp_dest = part_responses_.get_col_slice (k_neighbors_, 2 * k_neighbors_)
211
+ .get_row_slice (first, last);
240
212
241
- auto select_inp_resp_event =
242
- pr::select_indexed (queue_, inp_indices, train_responses_, inp_responses , deps);
213
+ copy_current_resp_event =
214
+ pr::select_indexed (queue_, inp_indices, train_responses_, current_min_resp_dest , deps);
243
215
244
216
const pr::ndshape<2 > typical_blocking (last - first, 2 * k_neighbors_);
245
217
auto select = selc_t (queue_, typical_blocking, k_neighbors_);
@@ -250,8 +222,10 @@ class knn_callback_distr {
250
222
251
223
// add global offset value to input indices
252
224
ONEDAL_ASSERT (global_index_offset_ != -1 );
253
- auto treat_event =
254
- pr::treat_indices (queue_, inp_indices, global_index_offset_, { select_inp_resp_event });
225
+ auto treat_event = pr::treat_indices (queue_,
226
+ inp_indices,
227
+ global_index_offset_,
228
+ { copy_current_resp_event });
255
229
256
230
auto actual_min_dist_copy_dest =
257
231
part_distances_.get_col_slice (0 , k_neighbors_).get_row_slice (first, last);
@@ -273,39 +247,48 @@ class knn_callback_distr {
273
247
274
248
auto actual_min_resp_copy_dest =
275
249
part_responses_.get_col_slice (0 , k_neighbors_).get_row_slice (first, last);
276
- auto current_min_resp_dest = part_responses_.get_col_slice (k_neighbors_, 2 * k_neighbors_)
277
- .get_row_slice (first, last);
278
250
copy_actual_resp_event =
279
251
pr::copy (queue_, actual_min_resp_copy_dest, min_resp_dest, { treat_event });
280
- copy_current_resp_event =
281
- pr::copy (queue_, current_min_resp_dest, inp_responses, { treat_event });
282
-
283
- auto kselect_block = part_distances_.get_row_slice (first, last);
284
- auto selt_event = select (queue_,
285
- kselect_block,
286
- k_neighbors_,
287
- min_dist_dest,
288
- min_indc_dest,
289
- { copy_actual_dist_event,
290
- copy_current_dist_event,
291
- copy_actual_indc_event,
292
- copy_current_indc_event,
293
- copy_actual_resp_event,
294
- copy_current_resp_event });
295
- auto resps_event = select_indexed (queue_,
296
- min_indc_dest,
297
- part_responses_.get_row_slice (first, last),
298
- min_resp_dest,
299
- { selt_event });
300
- auto final_event = select_indexed (queue_,
301
- min_indc_dest,
302
- part_indices_.get_row_slice (first, last),
303
- min_indc_dest,
304
- { resps_event });
252
+
253
+ sycl::event select_event;
254
+ {
255
+ ONEDAL_PROFILER_TASK (query_loop.selection , queue_);
256
+ auto kselect_block = part_distances_.get_row_slice (first, last);
257
+ select_event = select (queue_,
258
+ kselect_block,
259
+ k_neighbors_,
260
+ min_dist_dest,
261
+ min_indc_dest,
262
+ { copy_actual_dist_event,
263
+ copy_current_dist_event,
264
+ copy_actual_indc_event,
265
+ copy_current_indc_event,
266
+ copy_actual_resp_event,
267
+ copy_current_resp_event });
268
+ }
269
+ auto select_resp_event = select_indexed (queue_,
270
+ min_indc_dest,
271
+ part_responses_.get_row_slice (first, last),
272
+ min_resp_dest,
273
+ { select_event });
274
+ auto select_indc_event = select_indexed (queue_,
275
+ min_indc_dest,
276
+ part_indices_.get_row_slice (first, last),
277
+ min_indc_dest,
278
+ { select_resp_event });
305
279
if (last_iteration_) {
306
- final_event = finalize (qb_id, indices_, distances_, { final_event });
280
+ sycl::event copy_sqrt_event;
281
+ if (this ->compute_sqrt_ ) {
282
+ copy_sqrt_event =
283
+ copy_with_sqrt (queue_, min_dist_dest, min_dist_dest, { select_indc_event });
284
+ }
285
+ auto final_event = this ->output_responses (bounds,
286
+ indices_,
287
+ distances_,
288
+ { select_indc_event, copy_sqrt_event });
289
+ return final_event;
307
290
}
308
- return final_event ;
291
+ return select_indc_event ;
309
292
}
310
293
311
294
protected:
@@ -320,45 +303,6 @@ class knn_callback_distr {
320
303
return std::make_pair (first, last);
321
304
}
322
305
323
- sycl::event output_distances (const std::pair<idx_t , idx_t >& bnds,
324
- const pr::ndview<dst_t , 2 >& inp_dts,
325
- const bk::event_vector& deps = {}) {
326
- ONEDAL_ASSERT (inp_dts.has_data ());
327
- ONEDAL_ASSERT (this ->result_options_ .test (result_options::distances));
328
-
329
- const auto & [first, last] = bnds;
330
- ONEDAL_ASSERT (last > first);
331
- auto & queue = this ->queue_ ;
332
-
333
- auto out_dts = this ->distances_ .get_row_slice (first, last);
334
- ONEDAL_ASSERT ((last - first) == inp_dts.get_dimension (0 ));
335
- ONEDAL_ASSERT ((last - first) == out_dts.get_dimension (0 ));
336
-
337
- const bool & csqrt = this ->compute_sqrt_ ;
338
- if (!csqrt)
339
- return pr::copy (queue, out_dts, inp_dts, deps);
340
- else
341
- return copy_with_sqrt (queue, inp_dts, out_dts, deps);
342
- }
343
-
344
- sycl::event output_indices (const std::pair<idx_t , idx_t >& bnds,
345
- const pr::ndview<idx_t , 2 >& inp_ids,
346
- const bk::event_vector& deps = {}) {
347
- ONEDAL_ASSERT (inp_ids.has_data ());
348
- ONEDAL_ASSERT (this ->result_options_ .test (result_options::indices));
349
-
350
- const auto & [first, last] = bnds;
351
- ONEDAL_ASSERT (last > first);
352
- auto & queue = this ->queue_ ;
353
-
354
- auto out_ids = this ->indices_ .get_row_slice (first, last);
355
- ONEDAL_ASSERT ((last - first) == inp_ids.get_dimension (0 ));
356
- ONEDAL_ASSERT ((last - first) == out_ids.get_dimension (0 ));
357
- ONEDAL_ASSERT (inp_ids.get_shape () == out_ids.get_shape ());
358
-
359
- return pr::copy (queue, out_ids, inp_ids, deps);
360
- }
361
-
362
306
template <typename T = Task, typename = detail::enable_if_classification_t <T>>
363
307
sycl::event do_ucls (const std::pair<idx_t , idx_t >& bnds,
364
308
const pr::ndview<res_t , 2 >& tmp_rps,
@@ -481,7 +425,6 @@ class knn_callback_distr {
481
425
const result_option_id result_options_;
482
426
const std::int64_t query_block_, query_length_, k_neighbors_;
483
427
pr::ndview<res_t , 1 > train_responses_;
484
- pr::ndarray<res_t , 2 > temp_resp_;
485
428
pr::ndview<res_t , 1 > responses_;
486
429
pr::ndview<res_t , 2 > part_responses_;
487
430
pr::ndview<res_t , 2 > intermediate_responses_;
@@ -519,7 +462,7 @@ sycl::event bf_kernel_distr(sycl::queue& queue,
519
462
// Input arrays test section
520
463
ONEDAL_ASSERT (train.has_data ());
521
464
ONEDAL_ASSERT (query.has_data ());
522
- [[maybe_unused]] auto tcount = train.get_row_count ();
465
+ const auto tcount = train.get_row_count ();
523
466
const auto qcount = query.get_dimension (0 );
524
467
const auto fcount = train.get_column_count ();
525
468
const auto kcount = desc.get_neighbor_count ();
@@ -558,6 +501,13 @@ sycl::event bf_kernel_distr(sycl::queue& queue,
558
501
559
502
comm.allgather (tcount, node_sample_counts.flatten ()).wait ();
560
503
504
+ // TODO: implement max/min for ndarray
505
+ std::int64_t max_tcount = 0 ;
506
+ for (std::int64_t index = 0 ; index < node_sample_counts.get_count (); ++index ) {
507
+ max_tcount = std::max (node_sample_counts.at (index ), max_tcount);
508
+ }
509
+ block_size = std::min (max_tcount, block_size);
510
+
561
511
auto current_rank = comm.get_rank ();
562
512
auto prev_node = (current_rank - 1 + rank_count) % rank_count;
563
513
auto next_node = (current_rank + 1 ) % rank_count;
@@ -631,12 +581,14 @@ sycl::event bf_kernel_distr(sycl::queue& queue,
631
581
distance_impl->get_daal_distance_type () == daal_distance_t ::cosine;
632
582
const bool is_euclidean_distance =
633
583
is_minkowski_distance && (distance_impl->get_degree () == 2.0 );
584
+ ONEDAL_ASSERT (is_minkowski_distance ^ is_chebyshev_distance ^ is_cosine_distance);
634
585
635
586
const auto it = std::find (nodes.begin (), nodes.end (), current_rank);
636
- auto first_block_index = std::distance (nodes.begin (), it);
587
+ auto relative_block_offset = std::distance (nodes.begin (), it);
637
588
ONEDAL_ASSERT (it != nodes.end ());
638
589
639
- for (std::int64_t block_number = 0 ; block_number < block_count; ++block_number) {
590
+ for (std::int64_t relative_block_idx = 0 ; relative_block_idx < block_count;
591
+ ++relative_block_idx) {
640
592
auto current_block = train_block_queue.front ();
641
593
train_block_queue.pop_front ();
642
594
ONEDAL_ASSERT (current_block.has_data ());
@@ -646,19 +598,20 @@ sycl::event bf_kernel_distr(sycl::queue& queue,
646
598
pr::ndview<res_t , 1 >::wrap (current_tresps.get_data (), { current_tresps.get_count () });
647
599
tresps_queue.pop_front ();
648
600
649
- auto block_index = (block_number + first_block_index) % block_count;
650
- ONEDAL_ASSERT (block_index + 1 < bounds_size);
651
- auto actual_rows_in_block = boundaries.at (block_index + 1 ) - boundaries.at (block_index);
601
+ auto absolute_block_idx = (relative_block_idx + relative_block_offset) % block_count;
602
+ ONEDAL_ASSERT (absolute_block_idx + 1 < bounds_size);
603
+ auto actual_rows_in_block =
604
+ boundaries.at (absolute_block_idx + 1 ) - boundaries.at (absolute_block_idx);
652
605
653
606
auto sc = current_block.get_dimension (0 );
654
607
ONEDAL_ASSERT (sc >= actual_rows_in_block);
655
608
auto curr_k = std::min (actual_rows_in_block, kcount);
656
609
auto actual_current_block = current_block.get_row_slice (0 , actual_rows_in_block);
657
610
auto actual_current_tresps = current_tresps_1d.get_slice (0 , actual_rows_in_block);
658
611
659
- callback.set_global_index_offset (boundaries.at (block_index ));
612
+ callback.set_global_index_offset (boundaries.at (absolute_block_idx ));
660
613
callback.set_train_responses (actual_current_tresps);
661
- if (block_number == block_count - 1 ) {
614
+ if (relative_block_idx == block_count - 1 ) {
662
615
callback.set_last_iteration (true );
663
616
}
664
617
if (is_cosine_distance) {
@@ -699,26 +652,24 @@ sycl::event bf_kernel_distr(sycl::queue& queue,
699
652
next_event = search (query, callback, qbcount, curr_k, { next_event });
700
653
}
701
654
702
- auto send_count = current_block.get_count ();
703
- ONEDAL_ASSERT (send_count >= 0 );
704
- ONEDAL_ASSERT (send_count <= de::limits<int >::max ());
705
- // send recv replace
706
- comm.sendrecv_replace (array<Float>::wrap (queue,
707
- current_block.get_mutable_data (),
708
- send_count,
709
- { next_event }),
710
- prev_node,
711
- next_node)
712
- .wait ();
713
- train_block_queue.emplace_back (current_block);
714
- comm.sendrecv_replace (array<res_t >::wrap (queue,
715
- current_tresps.get_mutable_data (),
716
- current_tresps.get_count (),
717
- { next_event }),
718
- prev_node,
719
- next_node)
720
- .wait ();
721
- tresps_queue.emplace_back (current_tresps);
655
+ if (relative_block_idx < block_count - 1 ) {
656
+ ONEDAL_PROFILER_TASK (distributed_loop.sendrecv_replace , queue);
657
+ auto send_count = current_block.get_count ();
658
+ ONEDAL_ASSERT (send_count >= 0 );
659
+ ONEDAL_ASSERT (send_count <= de::limits<int >::max ());
660
+ auto send_train_block = array<Float>::wrap (queue,
661
+ current_block.get_mutable_data (),
662
+ send_count,
663
+ { next_event });
664
+ comm.sendrecv_replace (send_train_block, prev_node, next_node).wait ();
665
+ train_block_queue.emplace_back (current_block);
666
+ auto send_resps_block = array<res_t >::wrap (queue,
667
+ current_tresps.get_mutable_data (),
668
+ current_tresps.get_count (),
669
+ { next_event });
670
+ comm.sendrecv_replace (send_resps_block, prev_node, next_node).wait ();
671
+ tresps_queue.emplace_back (current_tresps);
672
+ }
722
673
}
723
674
724
675
return next_event;
0 commit comments