-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Vulkan][Refactor] Move ownership of per-CPU-thread objects to Vulkan…
…DeviceAPI (#8196) * [Vulkan][Refactor] Moved VulkanStream ownership from VulkanThreadEntry to VulkanDevice - Implemented ThreadMap, a container for per-thread objects. Unlike dmlc::ThreadLocalStore, ThreadMap is intended for use as a non-static thread-specific lookup. - Added ThreadMap<VulkanStream> as a member to VulkanDevice, updated all uses. * [Vulkan][Refactor] Pulled VulkanBuffer allocation/deallocation into constructor/destructor. - VulkanBuffer owns the VkBuffer and VkDeviceMemory that it allocates, and deallocates on destruction. - VulkanHostVisibleBuffer owns a VulkanBuffer, and additionally calls vkUnmapMemory on destruction. * [Vulkan][Refactor] Move the VulkanStagingBuffer to be owned by the VulkanDevice - Previously, it was owned by VulkanThreadEntry, so any use required looking up both the thread entry and the device. Now, thread-specific lookup is handled in the VulkanDevice class. * [Vulkan][Refactor] Move ownership of per-thread uniform buffer to VulkanDevice - Previously, VulkanUniformBuffer was owned by VulkanThreadEntry, so any use required looking up both the thread entry and the device. Now, thread-specific lookup is handled in the VulkanDevice class. * [Vulkan][Refactor] Moved ownership of per-thread workspace pool to VulkanDeviceAPI - Previously, the WorkspacePool was owned by VulkanThreadEntry, and required a lookup from VulkanDeviceAPI::AllocWorkspace. As a result, non-global VulkanDeviceAPI instances would interact with each other. * [Vulkan][Refactor] Moved ownership of per-thread active device id to VulkanDeviceAPI - Previously, the active device was owned by VulkanThreadEntry, so lookups to multiple global variables were required. Now, everything goes from the VulkanDeviceAPI. - Removed VulkanThreadEntry, as all functionality has been moved to either VulkanDevice or VulkanDeviceAPI. Co-authored-by: Eric Lunderberg <[email protected]>
- Loading branch information
1 parent
657af3a
commit f906fa8
Showing
12 changed files
with
689 additions
and
374 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
#ifndef TVM_RUNTIME_THREAD_MAP_H_ | ||
#define TVM_RUNTIME_THREAD_MAP_H_ | ||
|
||
#include <functional> | ||
#include <memory> | ||
#include <mutex> | ||
#include <shared_mutex> | ||
#include <thread> | ||
#include <unordered_map> | ||
#include <utility> | ||
|
||
namespace tvm { | ||
namespace runtime { | ||
|
||
/*! \brief Container to hold one value per thread | ||
* | ||
* Similar to thread_local, but intended for use as a non-static or | ||
* non-block variable, such as class member variables. All member | ||
* functions are thread-safe to call. If only the current thread's | ||
* value is accessed, no additional synchronization is required. If | ||
* another thread's stored values are accessed, external | ||
* synchronization may be required. | ||
* | ||
* Calls that only require access to already-existing values will not | ||
* block each other. Calls that require constructing a new value will | ||
* block any other calls. | ||
* | ||
* \tparam T The object type to be held. For instantiation of | ||
* ThreadMap<T> and for calls to ThreadMap<T>::Get, only a forward | ||
* declaration is required. For calls to ThreadMap<T>::GetOrMake, a | ||
* full class definition is required. | ||
*/ | ||
template <typename T>
class ThreadMap {
 public:
  ThreadMap() {}

  /*! \brief Return the current thread's stored object, if it exists.
   *
   * Const overload; forwards to Get(std::thread::id) const with the
   * calling thread's id.
   *
   * \return If it exists, a pointer to the stored object.  Otherwise,
   * returns nullptr.
   */
  const T* Get() const { return this->Get(std::this_thread::get_id()); }

  /*! \brief Return the stored object for a given thread, if it exists.
   *
   * Const overload.  Takes a shared lock on the internal mutex for
   * the duration of the lookup.
   *
   * \param id The thread whose object should be returned.
   *
   * \return If it exists, a pointer to the stored object.  Otherwise,
   * returns nullptr.
   */
  const T* Get(std::thread::id id) const {
    std::shared_lock<std::shared_timed_mutex> lock(mutex_);
    auto res = values_.find(id);
    if (res == values_.end()) {
      return nullptr;
    } else {
      return res->second.get();
    }
  }

  /*! \brief Return the current thread's stored object, if it exists.
   *
   * Non-const overload; implemented in terms of the const overload.
   *
   * \return If it exists, a pointer to the stored object.  Otherwise,
   * returns nullptr.
   */
  T* Get() { return const_cast<T*>(const_cast<const ThreadMap<T>*>(this)->Get()); }

  /*! \brief Return the stored object for a given thread, if it exists.
   *
   * Non-const overload; implemented in terms of the const overload.
   *
   * \param id The thread whose object should be returned.
   *
   * \return If it exists, a pointer to the stored object.  Otherwise,
   * returns nullptr.
   */
  T* Get(std::thread::id id) {
    return const_cast<T*>(const_cast<const ThreadMap<T>*>(this)->Get(id));
  }

  /*! \brief Return the current thread's stored object, making it if
   * necessary.
   *
   * Since this method can modify the stored map, there is no const
   * version available.
   *
   * \tparam Params Types of the stored object's constructor arguments
   *
   * \param params Arguments to the stored object's constructor.  Only
   * used if the current thread does not currently exist in the map.
   *
   * \return A reference to the stored object
   */
  template <typename... Params>
  T& GetOrMake(Params&&... params) {
    return GetOrMake(std::this_thread::get_id(), std::forward<Params>(params)...);
  }

  /*! \brief Return the stored object for a given thread, making it if
   * necessary
   *
   * Since this method can modify the stored map, there is no const
   * version available.
   *
   * If constructing the new object throws, the exception propagates
   * to the caller and no entry is added to the map for the thread.
   *
   * \tparam Params Types of the stored object's constructor arguments
   *
   * \param id The thread whose object should be returned.
   *
   * \param params Arguments to the stored object's constructor.  Only
   * used if the specified thread does not currently exist in the map.
   *
   * \return A reference to the stored object
   */
  template <typename... Params>
  T& GetOrMake(std::thread::id id, Params&&... params) {
    // Try to get stored value first, which would only require shared
    // access.
    if (T* output = Get(id)) {
      return *output;
    }

    // Not in map, need exclusive lock to write
    std::unique_lock<std::shared_timed_mutex> lock(mutex_);

    // Check again, in case another thread got the unique lock first
    // and already constructed the object.
    auto res = values_.find(id);
    if (res != values_.end()) {
      return *res->second;
    }

    // No value exists.  Construct the object before touching the map:
    // writing `values_[id] = make_unique<T>(...)` may insert an empty
    // entry first (operand order is unspecified before C++17), and a
    // throwing constructor would then leave a null pointer behind for
    // the lookup above to dereference on a later call.
    std::unique_ptr<T> new_val = std::make_unique<T>(std::forward<Params>(params)...);
    T& output = *new_val;
    values_.emplace(id, std::move(new_val));
    return output;
  }

  /*! \brief Clears all values held by the ThreadMap
   *
   * Calling Clear() invalidates any pointers/references previously
   * returned by Get/GetOrMake.
   *
   */
  void Clear() {
    std::unique_lock<std::shared_timed_mutex> lock(mutex_);
    values_.clear();
  }

 private:
  //! \brief Mutex to protect values_
  mutable std::shared_timed_mutex mutex_;

  //! \brief Map containing stored values
  std::unordered_map<std::thread::id, std::unique_ptr<T>> values_;
};
|
||
} // namespace runtime | ||
} // namespace tvm | ||
|
||
#endif // TVM_RUNTIME_THREAD_MAP_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.