Skip to content

Commit

Permalink
Merge pull request #2686 from igchor/fix_context_dtor
Browse files Browse the repository at this point in the history
[L0 v2] reorder context members
  • Loading branch information
pbalcer authored Feb 11, 2025
2 parents 6cd6446 + d246b15 commit 396f568
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 21 deletions.
4 changes: 2 additions & 2 deletions source/adapters/level_zero/v2/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ urCommandBufferCreateExp(ur_context_handle_t context, ur_device_handle_t device,
uint32_t queueGroupOrdinal =
device->QueueGroup[queue_group_type::Compute].ZeOrdinal;
v2::raii::command_list_unique_handle zeCommandList =
context->commandListCache.getRegularCommandList(device->ZeDevice, true,
queueGroupOrdinal, true);
context->getCommandListCache().getRegularCommandList(
device->ZeDevice, true, queueGroupOrdinal, true);

*commandBuffer = new ur_exp_command_buffer_handle_t_(
context, device, std::move(zeCommandList), commandBufferDesc);
Expand Down
2 changes: 1 addition & 1 deletion source/adapters/level_zero/v2/command_list_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ ur_command_list_manager::ur_command_list_manager(
v2::raii::command_list_unique_handle &&commandList, v2::event_flags_t flags,
ur_queue_t_ *queue)
: context(context), device(device),
eventPool(context->eventPoolCache.borrow(device->Id.value(), flags)),
eventPool(context->getEventPoolCache().borrow(device->Id.value(), flags)),
zeCommandList(std::move(commandList)), queue(queue) {
UR_CALL_THROWS(ur::level_zero::urContextRetain(context));
UR_CALL_THROWS(ur::level_zero::urDeviceRetain(device));
Expand Down
5 changes: 2 additions & 3 deletions source/adapters/level_zero/v2/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext,
uint32_t numDevices,
const ur_device_handle_t *phDevices,
bool ownZeContext)
: commandListCache(hContext),
: hContext(hContext, ownZeContext),
hDevices(phDevices, phDevices + numDevices), commandListCache(hContext),
eventPoolCache(this, phDevices[0]->Platform->getNumDevices(),
[context = this, platform = phDevices[0]->Platform](
DeviceId deviceId, v2::event_flags_t flags)
Expand All @@ -65,8 +66,6 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext,
nativeEventsPool(this, std::make_unique<v2::provider_normal>(
this, v2::QUEUE_IMMEDIATE,
v2::EVENT_FLAGS_PROFILING_ENABLED)),
hContext(hContext, ownZeContext),
hDevices(phDevices, phDevices + numDevices),
p2pAccessDevices(populateP2PDevices(
phDevices[0]->Platform->getNumDevices(), this->hDevices)),
defaultUSMPool(this, nullptr) {}
Expand Down
13 changes: 9 additions & 4 deletions source/adapters/level_zero/v2/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,31 @@ struct ur_context_handle_t_ : _ur_object {

inline ze_context_handle_t getZeHandle() const { return hContext.get(); }
ur_platform_handle_t getPlatform() const;

const std::vector<ur_device_handle_t> &getDevices() const;
ur_usm_pool_handle_t getDefaultUSMPool();

const std::vector<ur_device_handle_t> &
getP2PDevices(ur_device_handle_t hDevice) const;

v2::event_pool &getNativeEventsPool() { return nativeEventsPool; }
v2::event_pool_cache &getEventPoolCache() { return eventPoolCache; }
v2::command_list_cache_t &getCommandListCache() { return commandListCache; }

// Checks if Device is covered by this context.
// For that the Device or its root devices need to be in the context.
bool isValidDevice(ur_device_handle_t Device) const;

private:
const v2::raii::ze_context_handle_t hContext;
const std::vector<ur_device_handle_t> hDevices;
v2::command_list_cache_t commandListCache;
v2::event_pool_cache eventPoolCache;

// pool used for urEventCreateWithNativeHandle when native handle is NULL
// (uses non-counter based events to allow for signaling from host)
v2::event_pool nativeEventsPool;

private:
const v2::raii::ze_context_handle_t hContext;
const std::vector<ur_device_handle_t> hDevices;

// P2P devices for each device in the context, indexed by device id.
const std::vector<std::vector<ur_device_handle_t>> p2pAccessDevices;

Expand Down
6 changes: 3 additions & 3 deletions source/adapters/level_zero/v2/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,10 +391,10 @@ urEventCreateWithNativeHandle(ur_native_handle_t hNativeEvent,
const ur_event_native_properties_t *pProperties,
ur_event_handle_t *phEvent) try {
if (!hNativeEvent) {
assert((hContext->nativeEventsPool.getFlags() & v2::EVENT_FLAGS_COUNTER) ==
0);
assert((hContext->getNativeEventsPool().getFlags() &
v2::EVENT_FLAGS_COUNTER) == 0);

*phEvent = hContext->nativeEventsPool.allocate();
*phEvent = hContext->getNativeEventsPool().allocate();
ZE2UR_CALL(zeEventHostSignal, ((*phEvent)->getZeEvent()));
} else {
*phEvent = new ur_event_handle_t_(hContext, hNativeEvent, pProperties);
Expand Down
2 changes: 1 addition & 1 deletion source/adapters/level_zero/v2/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ void ur_integrated_mem_handle_t::unmapHostPtr(
static ur_result_t synchronousZeCopy(ur_context_handle_t hContext,
ur_device_handle_t hDevice, void *dst,
const void *src, size_t size) {
auto commandList = hContext->commandListCache.getImmediateCommandList(
auto commandList = hContext->getCommandListCache().getImmediateCommandList(
hDevice->ZeDevice, true,
hDevice
->QueueGroup[ur_device_handle_t_::queue_group_info_t::type::Compute]
Expand Down
2 changes: 1 addition & 1 deletion source/adapters/level_zero/v2/queue_immediate_in_order.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ ur_queue_immediate_in_order_t::ur_queue_immediate_in_order_t(
: hContext(hContext), hDevice(hDevice), flags(pProps ? pProps->flags : 0),
commandListManager(
hContext, hDevice,
hContext->commandListCache.getImmediateCommandList(
hContext->getCommandListCache().getImmediateCommandList(
hDevice->ZeDevice, true, getZeOrdinal(hDevice),
true /* always enable copy offload */,
ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS,
Expand Down
13 changes: 7 additions & 6 deletions test/adapters/level_zero/v2/command_list_cache_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,14 @@ TEST_P(CommandListCacheTest, CommandListsAreReusedByQueues) {
}
}

ASSERT_EQ(context->commandListCache.getNumImmediateCommandLists(), 0);
ASSERT_EQ(context->commandListCache.getNumRegularCommandLists(), 0);
ASSERT_EQ(context->getCommandListCache().getNumImmediateCommandLists(),
0);
ASSERT_EQ(context->getCommandListCache().getNumRegularCommandLists(), 0);
} // Queues scope

ASSERT_EQ(context->commandListCache.getNumImmediateCommandLists(),
ASSERT_EQ(context->getCommandListCache().getNumImmediateCommandLists(),
NumUniqueQueueTypes);
ASSERT_EQ(context->commandListCache.getNumRegularCommandLists(), 0);
ASSERT_EQ(context->getCommandListCache().getNumRegularCommandLists(), 0);
}
}

Expand All @@ -229,7 +230,7 @@ TEST_P(CommandListCacheTest, CommandListsCacheIsThreadSafe) {
ASSERT_EQ(urQueueCreate(context, device, &QueueProps, Queue.ptr()),
UR_RESULT_SUCCESS);

ASSERT_LE(context->commandListCache.getNumImmediateCommandLists(),
ASSERT_LE(context->getCommandListCache().getNumImmediateCommandLists(),
NumThreads);
}
});
Expand All @@ -239,6 +240,6 @@ TEST_P(CommandListCacheTest, CommandListsCacheIsThreadSafe) {
Thread.join();
}

ASSERT_LE(context->commandListCache.getNumImmediateCommandLists(),
ASSERT_LE(context->getCommandListCache().getNumImmediateCommandLists(),
NumThreads);
}

0 comments on commit 396f568

Please sign in to comment.