Skip to content

Commit 84518c1

Browse files
ianaylreble
andauthored
[UR][SYCL] Implement USM prefetch from device to host in SYCL runtime and UR (#19437)
Add the ability to control USM prefetch direction (host-to-device, device-to-host) in the enqueue_function extension: ```cpp sycl::ext::oneapi::experimental { enum class prefetch_type { device, host }; void prefetch(sycl::queue q, void* ptr, size_t numBytes, prefetch_type type = prefetch_type::device); void prefetch(sycl::handler &h, void* ptr, size_t numBytes, prefetch_type type = prefetch_type::device); } ``` **Note:** - ~~There is a test failure regarding a new ABI symbol: In order to not break the ABI, I added a new handler function to represent prefetch from device-to-host. Despite the precommit failures, this should not be an ABI-breaking change and thus be okay to merge.~~ I modified the ABI symbols tests as a part of this PR. --------- Co-authored-by: Pablo Reble <[email protected]>
1 parent 086eb34 commit 84518c1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+742
-164
lines changed

sycl/include/sycl/ext/oneapi/experimental/enqueue_functions.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <sycl/detail/common.hpp>
1414
#include <sycl/event.hpp>
15+
#include <sycl/ext/oneapi/experimental/enqueue_types.hpp>
1516
#include <sycl/ext/oneapi/experimental/graph.hpp>
1617
#include <sycl/ext/oneapi/properties/properties.hpp>
1718
#include <sycl/handler.hpp>
@@ -369,15 +370,17 @@ void fill(sycl::queue Q, T *Ptr, const T &Pattern, size_t Count,
369370
CodeLoc);
370371
}
371372

372-
inline void prefetch(handler &CGH, void *Ptr, size_t NumBytes) {
373-
CGH.prefetch(Ptr, NumBytes);
373+
inline void prefetch(handler &CGH, void *Ptr, size_t NumBytes,
374+
prefetch_type Type = prefetch_type::device) {
375+
CGH.prefetch(Ptr, NumBytes, Type);
374376
}
375377

376378
inline void prefetch(queue Q, void *Ptr, size_t NumBytes,
379+
prefetch_type Type = prefetch_type::device,
377380
const sycl::detail::code_location &CodeLoc =
378381
sycl::detail::code_location::current()) {
379382
submit(
380-
std::move(Q), [&](handler &CGH) { prefetch(CGH, Ptr, NumBytes); },
383+
std::move(Q), [&](handler &CGH) { prefetch(CGH, Ptr, NumBytes, Type); },
381384
CodeLoc);
382385
}
383386

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//==--------------- enqueue_types.hpp ---- SYCL enqueue types --------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
11+
#include <string>
12+
13+
namespace sycl {
14+
inline namespace _V1 {
15+
namespace ext::oneapi::experimental {
16+
17+
/// @brief Indicates the destination device for USM data to be prefetched to.
18+
enum class prefetch_type { device, host };
19+
20+
inline std::string prefetchTypeToString(prefetch_type value) {
21+
switch (value) {
22+
case sycl::ext::oneapi::experimental::prefetch_type::device:
23+
return "prefetch_type::device";
24+
case sycl::ext::oneapi::experimental::prefetch_type::host:
25+
return "prefetch_type::host";
26+
default:
27+
return "prefetch_type::unknown";
28+
}
29+
}
30+
31+
} // namespace ext::oneapi::experimental
32+
} // namespace _V1
33+
} // namespace sycl

sycl/include/sycl/handler.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ namespace ext ::oneapi ::experimental {
149149
template <typename, typename> class work_group_memory;
150150
template <typename, typename> class dynamic_work_group_memory;
151151
struct image_descriptor;
152+
enum class prefetch_type;
153+
152154
__SYCL_EXPORT void async_free(sycl::handler &h, void *ptr);
153155
__SYCL_EXPORT void *async_malloc(sycl::handler &h, sycl::usm::alloc kind,
154156
size_t size);
@@ -2627,6 +2629,16 @@ class __SYCL_EXPORT handler {
26272629
/// \param Count is a number of bytes to be prefetched.
26282630
void prefetch(const void *Ptr, size_t Count);
26292631

2632+
/// Provides hints to the runtime library that data should be made available
2633+
/// on a device earlier than Unified Shared Memory would normally require it
2634+
/// to be available.
2635+
///
2636+
/// \param Ptr is a USM pointer to the memory to be prefetched to the device.
2637+
/// \param Count is a number of bytes to be prefetched.
2638+
/// \param Type is type of prefetch, i.e. fetch to device or fetch to host.
2639+
void prefetch(const void *Ptr, size_t Count,
2640+
ext::oneapi::experimental::prefetch_type Type);
2641+
26302642
/// Provides additional information to the underlying runtime about how
26312643
/// different allocations are used.
26322644
///

sycl/source/detail/cg.hpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,14 +397,19 @@ class CGFillUSM : public CG {
397397
class CGPrefetchUSM : public CG {
398398
void *MDst;
399399
size_t MLength;
400+
ext::oneapi::experimental::prefetch_type MPrefetchType;
400401

401402
public:
402403
CGPrefetchUSM(void *DstPtr, size_t Length, CG::StorageInitHelper CGData,
404+
ext::oneapi::experimental::prefetch_type PrefetchType,
403405
detail::code_location loc = {})
404406
: CG(CGType::PrefetchUSM, std::move(CGData), std::move(loc)),
405-
MDst(DstPtr), MLength(Length) {}
406-
void *getDst() { return MDst; }
407-
size_t getLength() { return MLength; }
407+
MDst(DstPtr), MLength(Length), MPrefetchType(PrefetchType) {}
408+
void *getDst() const { return MDst; }
409+
size_t getLength() const { return MLength; }
410+
ext::oneapi::experimental::prefetch_type getPrefetchType() const {
411+
return MPrefetchType;
412+
}
408413
};
409414

410415
/// "Advise USM" command group class.

sycl/source/detail/graph/node_impl.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
#include <sycl/detail/cg_types.hpp> // for CGType
1616
#include <sycl/detail/kernel_desc.hpp> // for kernel_param_kind_t
1717

18-
#include <sycl/ext/oneapi/experimental/graph/node.hpp> // for node
18+
#include <sycl/ext/oneapi/experimental/enqueue_types.hpp> // for prefetchType
19+
#include <sycl/ext/oneapi/experimental/graph/node.hpp> // for node
1920

2021
#include <cstring>
2122
#include <fstream>
@@ -655,7 +656,10 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
655656
sycl::detail::CGPrefetchUSM *Prefetch =
656657
static_cast<sycl::detail::CGPrefetchUSM *>(MCommandGroup.get());
657658
Stream << "Dst: " << Prefetch->getDst()
658-
<< " Length: " << Prefetch->getLength() << "\\n";
659+
<< " Length: " << Prefetch->getLength() << " PrefetchType: "
660+
<< sycl::ext::oneapi::experimental::prefetchTypeToString(
661+
Prefetch->getPrefetchType())
662+
<< "\\n";
659663
}
660664
break;
661665
case sycl::detail::CGType::AdviseUSM:

sycl/source/detail/handler_impl.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <detail/cg.hpp>
1313
#include <detail/kernel_bundle_impl.hpp>
1414
#include <memory>
15+
#include <sycl/ext/oneapi/experimental/enqueue_types.hpp>
1516

1617
namespace sycl {
1718
inline namespace _V1 {
@@ -91,6 +92,10 @@ class handler_impl {
9192
/// property.
9293
bool MIsDeviceImageScoped = false;
9394

95+
/// Direction of USM prefetch / destination device.
96+
sycl::ext::oneapi::experimental::prefetch_type MPrefetchType =
97+
sycl::ext::oneapi::experimental::prefetch_type::device;
98+
9499
// Program scope pipe information.
95100

96101
// Pipe name that uniquely identifies a pipe.

sycl/source/detail/memory_manager.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -925,13 +925,18 @@ void MemoryManager::fill_usm(void *Mem, queue_impl &Queue, size_t Length,
925925
DepEvents.size(), DepEvents.data(), OutEvent);
926926
}
927927

928-
void MemoryManager::prefetch_usm(void *Mem, queue_impl &Queue, size_t Length,
929-
std::vector<ur_event_handle_t> DepEvents,
930-
ur_event_handle_t *OutEvent) {
928+
void MemoryManager::prefetch_usm(
929+
void *Mem, queue_impl &Queue, size_t Length,
930+
std::vector<ur_event_handle_t> DepEvents, ur_event_handle_t *OutEvent,
931+
sycl::ext::oneapi::experimental::prefetch_type Dest) {
931932
adapter_impl &Adapter = Queue.getAdapter();
932-
Adapter.call<UrApiKind::urEnqueueUSMPrefetch>(Queue.getHandleRef(), Mem,
933-
Length, 0u, DepEvents.size(),
934-
DepEvents.data(), OutEvent);
933+
ur_usm_migration_flags_t MigrationFlag =
934+
(Dest == sycl::ext::oneapi::experimental::prefetch_type::device)
935+
? UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE
936+
: UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST;
937+
Adapter.call<UrApiKind::urEnqueueUSMPrefetch>(
938+
Queue.getHandleRef(), Mem, Length, MigrationFlag, DepEvents.size(),
939+
DepEvents.data(), OutEvent);
935940
}
936941

937942
void MemoryManager::advise_usm(const void *Mem, queue_impl &Queue,
@@ -1542,11 +1547,16 @@ void MemoryManager::ext_oneapi_prefetch_usm_cmd_buffer(
15421547
sycl::detail::context_impl *Context,
15431548
ur_exp_command_buffer_handle_t CommandBuffer, void *Mem, size_t Length,
15441549
std::vector<ur_exp_command_buffer_sync_point_t> Deps,
1545-
ur_exp_command_buffer_sync_point_t *OutSyncPoint) {
1550+
ur_exp_command_buffer_sync_point_t *OutSyncPoint,
1551+
sycl::ext::oneapi::experimental::prefetch_type Dest) {
15461552
adapter_impl &Adapter = Context->getAdapter();
1553+
ur_usm_migration_flags_t MigrationFlag =
1554+
(Dest == sycl::ext::oneapi::experimental::prefetch_type::device)
1555+
? UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE
1556+
: UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST;
15471557
Adapter.call<UrApiKind::urCommandBufferAppendUSMPrefetchExp>(
1548-
CommandBuffer, Mem, Length, ur_usm_migration_flags_t(0), Deps.size(),
1549-
Deps.data(), 0u, nullptr, OutSyncPoint, nullptr, nullptr);
1558+
CommandBuffer, Mem, Length, MigrationFlag, Deps.size(), Deps.data(), 0,
1559+
nullptr, OutSyncPoint, nullptr, nullptr);
15501560
}
15511561

15521562
void MemoryManager::ext_oneapi_advise_usm_cmd_buffer(

sycl/source/detail/memory_manager.hpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <detail/sycl_mem_obj_i.hpp>
1212
#include <sycl/access/access.hpp>
1313
#include <sycl/detail/export.hpp>
14+
#include <sycl/ext/oneapi/experimental/enqueue_types.hpp> // for prefetch_type
1415
#include <sycl/id.hpp>
1516
#include <sycl/property_list.hpp>
1617
#include <sycl/range.hpp>
@@ -146,9 +147,12 @@ class MemoryManager {
146147
std::vector<ur_event_handle_t> DepEvents,
147148
ur_event_handle_t *OutEvent);
148149

149-
static void prefetch_usm(void *Ptr, queue_impl &Queue, size_t Len,
150-
std::vector<ur_event_handle_t> DepEvents,
151-
ur_event_handle_t *OutEvent);
150+
static void
151+
prefetch_usm(void *Ptr, queue_impl &Queue, size_t Len,
152+
std::vector<ur_event_handle_t> DepEvents,
153+
ur_event_handle_t *OutEvent,
154+
sycl::ext::oneapi::experimental::prefetch_type Dest =
155+
sycl::ext::oneapi::experimental::prefetch_type::device);
152156

153157
static void advise_usm(const void *Ptr, queue_impl &Queue, size_t Len,
154158
ur_usm_advice_flags_t Advice,
@@ -245,7 +249,9 @@ class MemoryManager {
245249
sycl::detail::context_impl *Context,
246250
ur_exp_command_buffer_handle_t CommandBuffer, void *Mem, size_t Length,
247251
std::vector<ur_exp_command_buffer_sync_point_t> Deps,
248-
ur_exp_command_buffer_sync_point_t *OutSyncPoint);
252+
ur_exp_command_buffer_sync_point_t *OutSyncPoint,
253+
sycl::ext::oneapi::experimental::prefetch_type Dest =
254+
sycl::ext::oneapi::experimental::prefetch_type::device);
249255

250256
static void ext_oneapi_advise_usm_cmd_buffer(
251257
sycl::detail::context_impl *Context,

sycl/source/detail/scheduler/commands.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2980,7 +2980,8 @@ ur_result_t ExecCGCommand::enqueueImpCommandBuffer() {
29802980
if (auto Result = callMemOpHelper(
29812981
MemoryManager::ext_oneapi_prefetch_usm_cmd_buffer,
29822982
&MQueue->getContextImpl(), MCommandBuffer, Prefetch->getDst(),
2983-
Prefetch->getLength(), std::move(MSyncPointDeps), &OutSyncPoint);
2983+
Prefetch->getLength(), std::move(MSyncPointDeps), &OutSyncPoint,
2984+
Prefetch->getPrefetchType());
29842985
Result != UR_RESULT_SUCCESS)
29852986
return Result;
29862987

@@ -3291,7 +3292,8 @@ ur_result_t ExecCGCommand::enqueueImpQueue() {
32913292
CGPrefetchUSM *Prefetch = (CGPrefetchUSM *)MCommandGroup.get();
32923293
if (auto Result = callMemOpHelper(
32933294
MemoryManager::prefetch_usm, Prefetch->getDst(), *MQueue,
3294-
Prefetch->getLength(), std::move(RawEvents), Event);
3295+
Prefetch->getLength(), std::move(RawEvents), Event,
3296+
Prefetch->getPrefetchType());
32953297
Result != UR_RESULT_SUCCESS)
32963298
return Result;
32973299

sycl/source/handler.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include <sycl/stream.hpp>
3737

3838
#include <sycl/ext/oneapi/bindless_images_memory.hpp>
39+
#include <sycl/ext/oneapi/experimental/enqueue_types.hpp>
3940
#include <sycl/ext/oneapi/experimental/graph.hpp>
4041
#include <sycl/ext/oneapi/experimental/work_group_memory.hpp>
4142
#include <sycl/ext/oneapi/memcpy2d.hpp>
@@ -745,8 +746,9 @@ event handler::finalize() {
745746
MCodeLoc));
746747
break;
747748
case detail::CGType::PrefetchUSM:
748-
CommandGroup.reset(new detail::CGPrefetchUSM(
749-
MDstPtr, MLength, std::move(impl->CGData), MCodeLoc));
749+
CommandGroup.reset(
750+
new detail::CGPrefetchUSM(MDstPtr, MLength, std::move(impl->CGData),
751+
impl->MPrefetchType, MCodeLoc));
750752
break;
751753
case detail::CGType::AdviseUSM:
752754
CommandGroup.reset(new detail::CGAdviseUSM(MDstPtr, MLength, impl->MAdvice,
@@ -1499,6 +1501,16 @@ void handler::prefetch(const void *Ptr, size_t Count) {
14991501
throwIfActionIsCreated();
15001502
MDstPtr = const_cast<void *>(Ptr);
15011503
MLength = Count;
1504+
impl->MPrefetchType = ext::oneapi::experimental::prefetch_type::device;
1505+
setType(detail::CGType::PrefetchUSM);
1506+
}
1507+
1508+
void handler::prefetch(const void *Ptr, size_t Count,
1509+
ext::oneapi::experimental::prefetch_type Type) {
1510+
throwIfActionIsCreated();
1511+
MDstPtr = const_cast<void *>(Ptr);
1512+
MLength = Count;
1513+
impl->MPrefetchType = Type;
15021514
setType(detail::CGType::PrefetchUSM);
15031515
}
15041516

0 commit comments

Comments
 (0)