42
42
#include < cub/util_ptx.cuh>
43
43
#include < cub/util_type.cuh>
44
44
45
+ #include < cuda/std/type_traits>
46
+
45
47
#include < cstdint>
46
48
47
49
CUB_NAMESPACE_BEGIN
@@ -287,6 +289,82 @@ VectorizedCopy(int32_t thread_rank, void *dest, ByteOffsetT num_bytes, const voi
287
289
}
288
290
}
289
291
292
/**
 * @brief Warp-cooperative copy of @p num_bytes raw bytes from @p input_buffer to
 * @p output_buffer. Enabled only for the byte-wise memcpy specialization
 * (IsMemcpy == true), where both buffers are reinterpreted as char pointers.
 *
 * Delegates to VectorizedCopy, which issues uint4-vectorized loads/stores where
 * alignment permits. All LOGICAL_WARP_SIZE threads of the calling logical warp
 * must reach this call together.
 *
 * @param input_buffer  Pointer-like source, reinterpreted as const char*
 * @param output_buffer Pointer-like destination, reinterpreted as char*
 * @param num_bytes     Number of bytes to copy
 * @param offset        Byte offset applied to both source and destination
 */
template <bool IsMemcpy,
          uint32_t LOGICAL_WARP_SIZE,
          typename InputBufferT,
          typename OutputBufferT,
          typename OffsetT,
          typename ::cuda::std::enable_if<IsMemcpy, int>::type = 0>
__device__ __forceinline__ void copy_items(InputBufferT input_buffer,
                                           OutputBufferT output_buffer,
                                           OffsetT num_bytes,
                                           OffsetT offset = 0)
{
  // Rank of this thread within its logical warp
  const int32_t thread_rank = threadIdx.x % LOGICAL_WARP_SIZE;
  char *byte_dst            = reinterpret_cast<char *>(output_buffer) + offset;
  const char *byte_src      = reinterpret_cast<const char *>(input_buffer) + offset;
  VectorizedCopy<LOGICAL_WARP_SIZE, uint4>(thread_rank, byte_dst, num_bytes, byte_src);
}
308
+
309
/**
 * @brief Warp-cooperative copy of @p num_items items from @p input_buffer to
 * @p output_buffer. Enabled only for the generic (non-memcpy) specialization
 * (IsMemcpy == false), where the buffers are random-access iterators copied
 * item-by-item.
 *
 * The logical warp's threads stride over the items cooperatively: thread r
 * copies items r, r + LOGICAL_WARP_SIZE, r + 2 * LOGICAL_WARP_SIZE, ...
 *
 * @param input_buffer  Random-access source iterator
 * @param output_buffer Random-access destination iterator
 * @param num_items     Number of items to copy
 * @param offset        Item offset applied to both source and destination
 */
template <bool IsMemcpy,
          uint32_t LOGICAL_WARP_SIZE,
          typename InputBufferT,
          typename OutputBufferT,
          typename OffsetT,
          typename ::cuda::std::enable_if<!IsMemcpy, int>::type = 0>
__device__ __forceinline__ void copy_items(InputBufferT input_buffer,
                                           OutputBufferT output_buffer,
                                           OffsetT num_items,
                                           OffsetT offset = 0)
{
  // Rank of this thread within its logical warp
  const OffsetT lane = threadIdx.x % LOGICAL_WARP_SIZE;
  for (OffsetT item = lane; item < num_items; item += LOGICAL_WARP_SIZE)
  {
    output_buffer[offset + item] = input_buffer[offset + item];
  }
}
327
+
328
/**
 * @brief Reads one item of type AliasT from @p buffer_src at item index
 * @p offset. Enabled only for the byte-wise memcpy specialization
 * (IsMemcpy == true): the opaque source is reinterpreted as const AliasT*.
 *
 * @param buffer_src Pointer-like source buffer
 * @param offset     Index (in units of AliasT) of the item to read
 * @return The item at @p offset
 */
template <bool IsMemcpy,
          typename AliasT,
          typename InputIt,
          typename OffsetT,
          typename ::cuda::std::enable_if<IsMemcpy, int>::type = 0>
__device__ __forceinline__ AliasT read_item(InputIt buffer_src, OffsetT offset)
{
  return reinterpret_cast<const AliasT *>(buffer_src)[offset];
}
337
+
338
/**
 * @brief Reads one item from @p buffer_src at item index @p offset. Enabled
 * only for the generic (non-memcpy) specialization (IsMemcpy == false): the
 * source is a random-access iterator dereferenced directly.
 *
 * @param buffer_src Random-access source iterator
 * @param offset     Index of the item to read
 * @return The item at @p offset
 */
template <bool IsMemcpy,
          typename AliasT,
          typename InputIt,
          typename OffsetT,
          typename ::cuda::std::enable_if<!IsMemcpy, int>::type = 0>
__device__ __forceinline__ AliasT read_item(InputIt buffer_src, OffsetT offset)
{
  return buffer_src[offset];
}
347
+
348
/**
 * @brief Writes @p value of type AliasT into @p buffer_dst at item index
 * @p offset. Enabled only for the byte-wise memcpy specialization
 * (IsMemcpy == true): the opaque destination is reinterpreted as AliasT*.
 *
 * @param buffer_dst Pointer-like destination buffer
 * @param offset     Index (in units of AliasT) of the item to write
 * @param value      The value to store
 */
template <bool IsMemcpy,
          typename AliasT,
          typename OutputIt,
          typename OffsetT,
          typename ::cuda::std::enable_if<IsMemcpy, int>::type = 0>
__device__ __forceinline__ void write_item(OutputIt buffer_dst, OffsetT offset, AliasT value)
{
  reinterpret_cast<AliasT *>(buffer_dst)[offset] = value;
}
357
+
358
/**
 * @brief Writes @p value into @p buffer_dst at item index @p offset. Enabled
 * only for the generic (non-memcpy) specialization (IsMemcpy == false): the
 * destination is a random-access iterator assigned through directly.
 *
 * @param buffer_dst Random-access destination iterator
 * @param offset     Index of the item to write
 * @param value      The value to store
 */
template <bool IsMemcpy,
          typename AliasT,
          typename OutputIt,
          typename OffsetT,
          typename ::cuda::std::enable_if<!IsMemcpy, int>::type = 0>
__device__ __forceinline__ void write_item(OutputIt buffer_dst, OffsetT offset, AliasT value)
{
  buffer_dst[offset] = value;
}
367
+
290
368
/* *
291
369
* @brief A helper class that allows threads to maintain multiple counters, where the counter that
292
370
* shall be incremented can be addressed dynamically without incurring register spillage.
@@ -431,7 +509,8 @@ template <typename AgentMemcpySmallBuffersPolicyT,
431
509
typename BlevBufferTileOffsetsOutItT,
432
510
typename BlockOffsetT,
433
511
typename BLevBufferOffsetTileState,
434
- typename BLevBlockOffsetTileState>
512
+ typename BLevBlockOffsetTileState,
513
+ bool IsMemcpy>
435
514
class AgentBatchMemcpy
436
515
{
437
516
private:
@@ -470,7 +549,14 @@ private:
470
549
// TYPE DECLARATIONS
471
550
// ---------------------------------------------------------------------
472
551
// / Internal load/store type. For byte-wise memcpy, a single-byte type
473
- using AliasT = char ;
552
+ using AliasT = typename ::cuda::std::conditional<
553
+ IsMemcpy,
554
+ std::iterator_traits<char *>,
555
+ std::iterator_traits<cub::detail::value_t <InputBufferIt>>>::type::value_type;
556
+
557
+ // / Types of the input and output buffers
558
+ using InputBufferT = cub::detail::value_t <InputBufferIt>;
559
+ using OutputBufferT = cub::detail::value_t <OutputBufferIt>;
474
560
475
561
// / Type that has to be sufficiently large to hold any of the buffers' sizes.
476
562
// / The BufferSizeIteratorT's value type must be convertible to this type.
@@ -775,17 +861,16 @@ private:
775
861
BlockBufferOffsetT num_wlev_buffers)
776
862
{
777
863
const int32_t warp_id = threadIdx .x / CUB_PTX_WARP_THREADS;
778
- const int32_t warp_lane = threadIdx .x % CUB_PTX_WARP_THREADS;
779
864
constexpr uint32_t WARPS_PER_BLOCK = BLOCK_THREADS / CUB_PTX_WARP_THREADS;
780
865
781
866
for (BlockBufferOffsetT buffer_offset = warp_id; buffer_offset < num_wlev_buffers;
782
867
buffer_offset += WARPS_PER_BLOCK)
783
868
{
784
869
const auto buffer_id = buffers_by_size_class[buffer_offset].buffer_id ;
785
- detail::VectorizedCopy< CUB_PTX_WARP_THREADS, uint4 >(warp_lane,
786
- tile_buffer_dsts [buffer_id],
787
- tile_buffer_sizes [buffer_id],
788
- tile_buffer_srcs [buffer_id]);
870
+ copy_items<IsMemcpy, CUB_PTX_WARP_THREADS, InputBufferT, OutputBufferT, BufferSizeT>(
871
+ tile_buffer_srcs [buffer_id],
872
+ tile_buffer_dsts [buffer_id],
873
+ tile_buffer_sizes [buffer_id]);
789
874
}
790
875
}
791
876
@@ -875,18 +960,18 @@ private:
875
960
#pragma unroll
876
961
for (int32_t i = 0 ; i < TLEV_BYTES_PER_THREAD; i++)
877
962
{
878
- src_byte[i] = reinterpret_cast < const AliasT * >(
879
- tile_buffer_srcs[zipped_byte_assignment[i].tile_buffer_id ])[zipped_byte_assignment[i]
880
- .buffer_byte_offset ] ;
963
+ src_byte[i] = read_item<IsMemcpy, AliasT, InputBufferT >(
964
+ tile_buffer_srcs[zipped_byte_assignment[i].tile_buffer_id ],
965
+ zipped_byte_assignment[i] .buffer_byte_offset ) ;
881
966
absolute_tlev_byte_offset += BLOCK_THREADS;
882
967
}
883
968
#pragma unroll
884
969
for (int32_t i = 0 ; i < TLEV_BYTES_PER_THREAD; i++)
885
970
{
886
- reinterpret_cast < AliasT * >(
887
- tile_buffer_dsts[zipped_byte_assignment[i].tile_buffer_id ])[zipped_byte_assignment[i]
888
- .buffer_byte_offset ] =
889
- src_byte[i];
971
+ write_item<IsMemcpy, AliasT, OutputBufferT >(
972
+ tile_buffer_dsts[zipped_byte_assignment[i].tile_buffer_id ],
973
+ zipped_byte_assignment[i] .buffer_byte_offset ,
974
+ src_byte[i]) ;
890
975
}
891
976
}
892
977
else
@@ -897,13 +982,13 @@ private:
897
982
{
898
983
if (absolute_tlev_byte_offset < num_total_tlev_bytes)
899
984
{
900
- const AliasT src_byte = reinterpret_cast < const AliasT * >(
901
- tile_buffer_srcs[zipped_byte_assignment[i].tile_buffer_id ])[zipped_byte_assignment[i]
902
- .buffer_byte_offset ] ;
903
- reinterpret_cast < AliasT * >(
904
- tile_buffer_dsts[zipped_byte_assignment[i].tile_buffer_id ])[zipped_byte_assignment[i]
905
- .buffer_byte_offset ] =
906
- src_byte;
985
+ const AliasT src_byte = read_item<IsMemcpy, AliasT, InputBufferT >(
986
+ tile_buffer_srcs[zipped_byte_assignment[i].tile_buffer_id ],
987
+ zipped_byte_assignment[i] .buffer_byte_offset ) ;
988
+ write_item<IsMemcpy, AliasT, OutputBufferT >(
989
+ tile_buffer_dsts[zipped_byte_assignment[i].tile_buffer_id ],
990
+ zipped_byte_assignment[i] .buffer_byte_offset ,
991
+ src_byte) ;
907
992
}
908
993
absolute_tlev_byte_offset += BLOCK_THREADS;
909
994
}
0 commit comments