diff --git a/source/adapters/level_zero/v2/common.hpp b/source/adapters/level_zero/v2/common.hpp new file mode 100644 index 0000000000..cec7a4dc97 --- /dev/null +++ b/source/adapters/level_zero/v2/common.hpp @@ -0,0 +1,97 @@ +//===--------- common.hpp - Level Zero Adapter ---------------------------===// +// +// Copyright (C) 2024 Intel Corporation +// +// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM +// Exceptions. See LICENSE.TXT +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +#include "../common.hpp" + +namespace v2 { + +namespace raii { + +template +struct ze_handle_wrapper { + ze_handle_wrapper(bool ownZeHandle = true) + : handle(nullptr), ownZeHandle(ownZeHandle) {} + + ze_handle_wrapper(ZeHandleT handle, bool ownZeHandle = true) + : handle(handle), ownZeHandle(ownZeHandle) {} + + ze_handle_wrapper(const ze_handle_wrapper &) = delete; + ze_handle_wrapper &operator=(const ze_handle_wrapper &) = delete; + + ze_handle_wrapper(ze_handle_wrapper &&other) + : handle(other.handle), ownZeHandle(other.ownZeHandle) { + other.handle = nullptr; + } + + ze_handle_wrapper &operator=(ze_handle_wrapper &&other) { + if (this == &other) { + return *this; + } + + if (handle) { + reset(); + } + handle = other.handle; + ownZeHandle = other.ownZeHandle; + other.handle = nullptr; + return *this; + } + + ~ze_handle_wrapper() { + try { + reset(); + } catch (...) { + } + } + + void reset() { + if (!handle) { + return; + } + + auto zeResult = ZE_CALL_NOCHECK(destroy, (handle)); + // Gracefully handle the case that L0 was already unloaded. + if (zeResult && zeResult != ZE_RESULT_ERROR_UNINITIALIZED) + throw ze2urResult(zeResult); + + handle = nullptr; + } + + ZeHandleT release() { + auto handle = this->handle; + this->handle = nullptr; + return handle; + } + + ZeHandleT get() const { return handle; } + + ZeHandleT *ptr() { return &handle; } + +private: + ZeHandleT handle; + bool ownZeHandle; +}; + +using ze_kernel_handle_t = + ze_handle_wrapper<::ze_kernel_handle_t, zeKernelDestroy>; + +using ze_event_handle_t = + ze_handle_wrapper<::ze_event_handle_t, zeEventDestroy>; + +using ze_event_pool_handle_t = + ze_handle_wrapper<::ze_event_pool_handle_t, zeEventPoolDestroy>; + +} // namespace raii +} // namespace v2 diff --git a/source/adapters/level_zero/v2/event_provider_counter.cpp b/source/adapters/level_zero/v2/event_provider_counter.cpp index 14e33a5700..220974d405 100644 --- a/source/adapters/level_zero/v2/event_provider_counter.cpp +++ b/source/adapters/level_zero/v2/event_provider_counter.cpp @@ -33,12 +33,6 @@ provider_counter::provider_counter(ur_platform_handle_t platform, (ZEL_HANDLE_DEVICE, device->ZeDevice, (void **)&translatedDevice)); } -provider_counter::~provider_counter() { - for (auto &e : freelist) { - ZE_CALL_NOCHECK(zeEventDestroy, (e)); - } -} - event_allocation provider_counter::allocate() { if (freelist.empty()) { ZeStruct desc; @@ -54,11 +48,11 @@ event_allocation provider_counter::allocate() { freelist.emplace_back(handle); } - auto event = freelist.back(); + auto event = std::move(freelist.back()); freelist.pop_back(); return {event_type::EVENT_COUNTER, - event_borrowed(event, [this](ze_event_handle_t handle) { + event_borrowed(event.release(), [this](ze_event_handle_t handle) { freelist.push_back(handle); })}; } diff --git a/source/adapters/level_zero/v2/event_provider_counter.hpp b/source/adapters/level_zero/v2/event_provider_counter.hpp index 60a8107469..65fdcd5ff2 100644 --- a/source/adapters/level_zero/v2/event_provider_counter.hpp +++ b/source/adapters/level_zero/v2/event_provider_counter.hpp @@ -35,7 +35,6 @@ class provider_counter : public event_provider { public: provider_counter(ur_platform_handle_t platform, ur_context_handle_t, ur_device_handle_t); - ~provider_counter() override; event_allocation allocate() override; ur_device_handle_t device() override; @@ -48,7 +47,7 @@ class provider_counter : public event_provider { zexCounterBasedEventCreate eventCreateFunc; - std::vector freelist; + std::vector freelist; }; } // namespace v2 diff --git a/source/adapters/level_zero/v2/event_provider_normal.cpp b/source/adapters/level_zero/v2/event_provider_normal.cpp index 4e8287b36c..c5d3d61af6 100644 --- a/source/adapters/level_zero/v2/event_provider_normal.cpp +++ b/source/adapters/level_zero/v2/event_provider_normal.cpp @@ -38,7 +38,7 @@ provider_pool::provider_pool(ur_context_handle_t context, ZE2UR_CALL_THROWS(zeEventPoolCreate, (context->ZeContext, &desc, 1, const_cast(&device->ZeDevice), - &pool)); + pool.ptr())); freelist.resize(EVENTS_BURST); for (int i = 0; i < EVENTS_BURST; ++i) { @@ -46,25 +46,19 @@ provider_pool::provider_pool(ur_context_handle_t context, desc.index = i; desc.signal = 0; desc.wait = 0; - ZE2UR_CALL_THROWS(zeEventCreate, (pool, &desc, &freelist[i])); + ZE2UR_CALL_THROWS(zeEventCreate, (pool.get(), &desc, freelist[i].ptr())); } } -provider_pool::~provider_pool() { - for (auto e : freelist) { - ZE_CALL_NOCHECK(zeEventDestroy, (e)); - } - ZE_CALL_NOCHECK(zeEventPoolDestroy, (pool)); -} - event_borrowed provider_pool::allocate() { if (freelist.empty()) { return nullptr; } - ze_event_handle_t e = freelist.back(); + auto e = std::move(freelist.back()); freelist.pop_back(); - return event_borrowed( - e, [this](ze_event_handle_t handle) { freelist.push_back(handle); }); + return event_borrowed(e.release(), [this](ze_event_handle_t handle) { + freelist.push_back(handle); + }); } size_t provider_pool::nfree() const { return freelist.size(); } diff --git a/source/adapters/level_zero/v2/event_provider_normal.hpp b/source/adapters/level_zero/v2/event_provider_normal.hpp index 4ab72ccaed..ffc2373ce5 100644 --- a/source/adapters/level_zero/v2/event_provider_normal.hpp +++ b/source/adapters/level_zero/v2/event_provider_normal.hpp @@ -34,16 +34,13 @@ class provider_pool { public: provider_pool(ur_context_handle_t, ur_device_handle_t, event_type, queue_type); - ~provider_pool(); event_borrowed allocate(); size_t nfree() const; private: - // TODO: use a RAII wrapper for the pool handle - ze_event_pool_handle_t pool; - - std::vector freelist; + raii::ze_event_pool_handle_t pool; + std::vector freelist; }; class provider_normal : public event_provider {