Skip to content

Commit a28356f

Browse files
[NFCI][SYCL] Keep raw ptr/ref to devices/platforms in context_impl
Similar to intel#19613. Refactoring has started in intel#18143 intel#18251 This didn't work back then due to some unittests failing (because `global_handler` shutdown is different with unittests) but it doesn't seem to be an issue any more, probably due to other refactoring PRs that prefer raw ptr/refs over `std::shared_ptr` elsewhere.
1 parent e0760a3 commit a28356f

File tree

2 files changed

+59
-64
lines changed

2 files changed

+59
-64
lines changed

sycl/source/detail/context_impl.cpp

Lines changed: 50 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,15 @@ namespace sycl {
2929
inline namespace _V1 {
3030
namespace detail {
3131

32-
context_impl::context_impl(const std::vector<sycl::device> Devices,
33-
async_handler AsyncHandler,
32+
context_impl::context_impl(devices_range Devices, async_handler AsyncHandler,
3433
const property_list &PropList, private_tag)
3534
: MOwnedByRuntime(true), MAsyncHandler(std::move(AsyncHandler)),
36-
MDevices(std::move(Devices)), MContext(nullptr),
37-
MPlatform(detail::getSyclObjImpl(MDevices[0].get_platform())),
38-
MPropList(PropList), MKernelProgramCache(*this),
39-
MSupportBufferLocationByDevices(NotChecked) {
35+
MDevices(Devices.to<std::vector<device_impl *>>()), MContext(nullptr),
36+
MPlatform(MDevices[0]->getPlatformImpl()), MPropList(PropList),
37+
MKernelProgramCache(*this), MSupportBufferLocationByDevices(NotChecked) {
4038
verifyProps(PropList);
4139
std::vector<ur_device_handle_t> DeviceIds;
42-
for (const auto &D : MDevices) {
40+
for (device_impl &D : devices_range{MDevices}) {
4341
if (D.has(aspect::ext_oneapi_is_composite)) {
4442
// Component devices are considered to be descendent devices from a
4543
// composite device and therefore context created for a composite
@@ -52,7 +50,7 @@ context_impl::context_impl(const std::vector<sycl::device> Devices,
5250
DeviceIds.push_back(getSyclObjImpl(CD)->getHandleRef());
5351
}
5452

55-
DeviceIds.push_back(getSyclObjImpl(D)->getHandleRef());
53+
DeviceIds.push_back(D.getHandleRef());
5654
}
5755

5856
getAdapter().call<UrApiKind::urContextCreate>(
@@ -61,39 +59,42 @@ context_impl::context_impl(const std::vector<sycl::device> Devices,
6159

6260
context_impl::context_impl(ur_context_handle_t UrContext,
6361
async_handler AsyncHandler, adapter_impl &Adapter,
64-
const std::vector<sycl::device> &DeviceList,
65-
bool OwnedByRuntime, private_tag)
62+
devices_range DeviceList, bool OwnedByRuntime,
63+
private_tag)
6664
: MOwnedByRuntime(OwnedByRuntime), MAsyncHandler(std::move(AsyncHandler)),
67-
MDevices(DeviceList), MContext(UrContext), MPlatform(),
65+
MDevices([&]() {
66+
if (!DeviceList.empty())
67+
return DeviceList.to<std::vector<device_impl *>>();
68+
69+
std::vector<ur_device_handle_t> DeviceIds;
70+
uint32_t DevicesNum = 0;
71+
// TODO catch an exception and put it to list of asynchronous
72+
// exceptions.
73+
Adapter.call<UrApiKind::urContextGetInfo>(
74+
UrContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(DevicesNum),
75+
&DevicesNum, nullptr);
76+
DeviceIds.resize(DevicesNum);
77+
// TODO catch an exception and put it to list of asynchronous
78+
// exceptions.
79+
Adapter.call<UrApiKind::urContextGetInfo>(
80+
UrContext, UR_CONTEXT_INFO_DEVICES,
81+
sizeof(ur_device_handle_t) * DevicesNum, &DeviceIds[0], nullptr);
82+
83+
if (DeviceIds.empty())
84+
throw exception(
85+
make_error_code(errc::invalid),
86+
"No devices in the provided device list and native context.");
87+
88+
platform_impl &Platform =
89+
platform_impl::getPlatformFromUrDevice(DeviceIds[0], Adapter);
90+
std::vector<device_impl *> Devices;
91+
for (ur_device_handle_t Dev : DeviceIds)
92+
Devices.emplace_back(&Platform.getOrMakeDeviceImpl(Dev));
93+
94+
return Devices;
95+
}()),
96+
MContext(UrContext), MPlatform(MDevices[0]->getPlatformImpl()),
6897
MKernelProgramCache(*this), MSupportBufferLocationByDevices(NotChecked) {
69-
if (!MDevices.empty()) {
70-
MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform());
71-
} else {
72-
std::vector<ur_device_handle_t> DeviceIds;
73-
uint32_t DevicesNum = 0;
74-
// TODO catch an exception and put it to list of asynchronous exceptions
75-
Adapter.call<UrApiKind::urContextGetInfo>(
76-
MContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(DevicesNum), &DevicesNum,
77-
nullptr);
78-
DeviceIds.resize(DevicesNum);
79-
// TODO catch an exception and put it to list of asynchronous exceptions
80-
Adapter.call<UrApiKind::urContextGetInfo>(
81-
MContext, UR_CONTEXT_INFO_DEVICES,
82-
sizeof(ur_device_handle_t) * DevicesNum, &DeviceIds[0], nullptr);
83-
84-
if (DeviceIds.empty())
85-
throw exception(
86-
make_error_code(errc::invalid),
87-
"No devices in the provided device list and native context.");
88-
89-
platform_impl &Platform =
90-
platform_impl::getPlatformFromUrDevice(DeviceIds[0], Adapter);
91-
for (ur_device_handle_t Dev : DeviceIds) {
92-
MDevices.emplace_back(
93-
createSyclObjFromImpl<device>(Platform.getOrMakeDeviceImpl(Dev)));
94-
}
95-
MPlatform = Platform.shared_from_this();
96-
}
9798
// TODO catch an exception and put it to list of asynchronous exceptions
9899
// getAdapter() will be the same as the Adapter passed. This should be taken
99100
// care of when creating device object.
@@ -144,12 +145,12 @@ uint32_t context_impl::get_info<info::context::reference_count>() const {
144145
this->getAdapter());
145146
}
146147
template <> platform context_impl::get_info<info::context::platform>() const {
147-
return createSyclObjFromImpl<platform>(*MPlatform);
148+
return createSyclObjFromImpl<platform>(MPlatform);
148149
}
149150
template <>
150151
std::vector<sycl::device>
151152
context_impl::get_info<info::context::devices>() const {
152-
return MDevices;
153+
return devices_range{MDevices}.to<std::vector<sycl::device>>();
153154
}
154155
template <>
155156
std::vector<sycl::memory_order>
@@ -219,7 +220,7 @@ context_impl::get_backend_info<info::platform::version>() const {
219220
"the info::platform::version info descriptor can "
220221
"only be queried with an OpenCL backend");
221222
}
222-
return MDevices[0].get_platform().get_info<info::platform::version>();
223+
return MDevices[0]->get_platform().get_info<info::platform::version>();
223224
}
224225
#endif
225226

@@ -271,17 +272,17 @@ KernelProgramCache &context_impl::getKernelProgramCache() const {
271272
}
272273

273274
bool context_impl::hasDevice(const detail::device_impl &Device) const {
274-
for (auto D : MDevices)
275-
if (getSyclObjImpl(D).get() == &Device)
275+
for (device_impl *D : MDevices)
276+
if (D == &Device)
276277
return true;
277278
return false;
278279
}
279280

280281
device_impl *
281282
context_impl::findMatchingDeviceImpl(ur_device_handle_t &DeviceUR) const {
282-
for (device D : MDevices)
283-
if (getSyclObjImpl(D)->getHandleRef() == DeviceUR)
284-
return getSyclObjImpl(D).get();
283+
for (device_impl *D : MDevices)
284+
if (D->getHandleRef() == DeviceUR)
285+
return D;
285286

286287
return nullptr;
287288
}
@@ -301,8 +302,8 @@ bool context_impl::isBufferLocationSupported() const {
301302
return MSupportBufferLocationByDevices == Supported ? true : false;
302303
// Check that devices within context have support of buffer location
303304
MSupportBufferLocationByDevices = Supported;
304-
for (auto &Device : MDevices) {
305-
if (!Device.has_extension("cl_intel_mem_alloc_buffer_location")) {
305+
for (device_impl *Device : MDevices) {
306+
if (!Device->has_extension("cl_intel_mem_alloc_buffer_location")) {
306307
MSupportBufferLocationByDevices = NotSupported;
307308
break;
308309
}

sycl/source/detail/context_impl.hpp

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,8 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
4747
/// \param DeviceList is a list of SYCL device instances.
4848
/// \param AsyncHandler is an instance of async_handler.
4949
/// \param PropList is an instance of property_list.
50-
context_impl(const std::vector<sycl::device> DeviceList,
51-
async_handler AsyncHandler, const property_list &PropList,
52-
private_tag);
50+
context_impl(devices_range DeviceList, async_handler AsyncHandler,
51+
const property_list &PropList, private_tag);
5352

5453
/// Construct a context_impl using plug-in interoperability handle.
5554
///
@@ -62,9 +61,8 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
6261
/// \param OwnedByRuntime is the flag if ownership is kept by user or
6362
/// transferred to runtime
6463
context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler,
65-
adapter_impl &Adapter,
66-
const std::vector<sycl::device> &DeviceList, bool OwnedByRuntime,
67-
private_tag);
64+
adapter_impl &Adapter, devices_range DeviceList,
65+
bool OwnedByRuntime, private_tag);
6866

6967
context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler,
7068
adapter_impl &Adapter, private_tag tag)
@@ -94,10 +92,10 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
9492
const async_handler &get_async_handler() const;
9593

9694
/// \return the Adapter associated with the platform of this context.
97-
adapter_impl &getAdapter() const { return MPlatform->getAdapter(); }
95+
adapter_impl &getAdapter() const { return MPlatform.getAdapter(); }
9896

9997
/// \return the PlatformImpl associated with this context.
100-
platform_impl &getPlatformImpl() const { return *MPlatform; }
98+
platform_impl &getPlatformImpl() const { return MPlatform; }
10199

102100
/// Queries this context for information.
103101
///
@@ -191,10 +189,7 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
191189
}
192190

193191
// Returns the backend of this context
194-
backend getBackend() const {
195-
assert(MPlatform && "MPlatform must be not null");
196-
return MPlatform->getBackend();
197-
}
192+
backend getBackend() const { return MPlatform.getBackend(); }
198193

199194
/// Given a UR device, returns the matching shared_ptr<device_impl>
200195
/// within this context. May return nullptr if no match discovered.
@@ -262,10 +257,9 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
262257
private:
263258
bool MOwnedByRuntime;
264259
async_handler MAsyncHandler;
265-
std::vector<device> MDevices;
260+
std::vector<device_impl *> MDevices;
266261
ur_context_handle_t MContext;
267-
// TODO: Make it a reference instead, but that needs a bit more refactoring:
268-
std::shared_ptr<platform_impl> MPlatform;
262+
platform_impl &MPlatform;
269263
property_list MPropList;
270264
CachedLibProgramsT MCachedLibPrograms;
271265
std::mutex MCachedLibProgramsMutex;

0 commit comments

Comments
 (0)