
Commit f906fa8

[Vulkan][Refactor] Move ownership of per-CPU-thread objects to VulkanDeviceAPI (#8196)
* [Vulkan][Refactor] Moved VulkanStream ownership from VulkanThreadEntry to VulkanDevice

  - Implemented ThreadMap, a container for per-thread objects. Unlike dmlc::ThreadLocalStore, ThreadMap is intended for use as a non-static thread-specific lookup.
  - Added ThreadMap<VulkanStream> as a member to VulkanDevice, updated all uses.

* [Vulkan][Refactor] Pulled VulkanBuffer allocation/deallocation into constructor/destructor.

  - VulkanBuffer owns the VkBuffer and VkDeviceMemory that it allocates, and deallocates on destruction.
  - VulkanHostVisibleBuffer owns a VulkanBuffer, and additionally calls vkUnmapMemory on destruction.

* [Vulkan][Refactor] Move the VulkanStagingBuffer to be owned by the VulkanDevice

  - Previously, it was owned by VulkanThreadEntry, so any use required looking up both the thread entry and the device. Now, thread-specific lookup is handled in the VulkanDevice class.

* [Vulkan][Refactor] Move ownership of per-thread uniform buffer to VulkanDevice

  - Previously, VulkanUniformBuffer was owned by VulkanThreadEntry, so any use required looking up both the thread entry and the device. Now, thread-specific lookup is handled in the VulkanDevice class.

* [Vulkan][Refactor] Moved ownership of per-thread workspace pool to VulkanDeviceAPI

  - Previously, the WorkspacePool was owned by VulkanThreadEntry, and required a lookup from VulkanDeviceAPI::AllocWorkspace. As a result, non-global VulkanDeviceAPI instances would interact with each other.

* [Vulkan][Refactor] Moved ownership of per-thread active device id to VulkanDeviceAPI

  - Previously, the active device was owned by VulkanThreadEntry, so lookups to multiple global variables were required. Now, everything goes through the VulkanDeviceAPI.
  - Removed VulkanThreadEntry, as all functionality has been moved to either VulkanDevice or VulkanDeviceAPI.

Co-authored-by: Eric Lunderberg <[email protected]>
1 parent 657af3a commit f906fa8
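
The commit message notes that state held in a global per-thread entry let non-global VulkanDeviceAPI instances interact with each other. The sketch below illustrates that effect in isolation. It is not TVM code: `SharedStateApi`, `PerInstanceApi`, and `InstancesInteract` are hypothetical names, and a plain mutex-guarded map stands in for ThreadMap.

```cpp
#include <cassert>
#include <mutex>
#include <thread>
#include <unordered_map>

// Hypothetical API object whose per-thread state is a static thread_local.
// Every instance on the same thread reads and writes the same slot.
class SharedStateApi {
 public:
  void SetActiveDevice(int id) { ActiveDevice() = id; }
  int GetActiveDevice() const { return ActiveDevice(); }

 private:
  static int& ActiveDevice() {
    thread_local int active_device = 0;
    return active_device;
  }
};

// Same interface, but the per-thread state lives in a member map keyed by
// thread id, so two instances never observe each other's values.
class PerInstanceApi {
 public:
  void SetActiveDevice(int id) {
    std::lock_guard<std::mutex> lock(mutex_);
    active_device_[std::this_thread::get_id()] = id;
  }
  int GetActiveDevice() const {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = active_device_.find(std::this_thread::get_id());
    return it == active_device_.end() ? 0 : it->second;
  }

 private:
  mutable std::mutex mutex_;
  std::unordered_map<std::thread::id, int> active_device_;
};

// Returns true if setting the device on one instance leaks into another.
template <typename Api>
bool InstancesInteract() {
  Api a, b;
  a.SetActiveDevice(3);
  return b.GetActiveDevice() == 3;
}
```

Calling `InstancesInteract<SharedStateApi>()` and `InstancesInteract<PerInstanceApi>()` from the same thread shows the difference between the two ownership models.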

12 files changed, +689 -374 lines changed

src/runtime/thread_map.h

+175
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_RUNTIME_THREAD_MAP_H_
+#define TVM_RUNTIME_THREAD_MAP_H_
+
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <shared_mutex>
+#include <thread>
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace runtime {
+
+/*! \brief Container to hold one value per thread
+ *
+ * Similar to thread_local, but intended for use as a non-static or
+ * non-block variable, such as class member variables.  All member
+ * functions are thread-safe to call.  If only the current thread's
+ * value is accessed, no additional synchronization is required.  If
+ * another thread's stored values are accessed, external
+ * synchronization may be required.
+ *
+ * Calls that only require access to already-existing values will not
+ * block each other.  Calls that require constructing a new value will
+ * block any other calls.
+ *
+ * \tparam T The object type to be held.  For instantiation of
+ * ThreadMap<T> and for calls to ThreadMap<T>::Get, only a forward
+ * declaration is required.  For calls to ThreadMap<T>::GetOrMake, a
+ * full class definition is required.
+ */
+template <typename T>
+class ThreadMap {
+ public:
+  ThreadMap() {}
+
+  /*! \brief Return the current thread's stored object, if it exists.
+   *
+   * \return If it exists, a pointer to the stored object.  Otherwise,
+   * returns nullptr.
+   */
+  const T* Get() const { return this->Get(std::this_thread::get_id()); }
+
+  /*! \brief Return the stored object for a given thread, if it exists.
+   *
+   * \param id The thread whose object should be returned.
+   *
+   * \return If it exists, a pointer to the stored object.  Otherwise,
+   * returns nullptr.
+   */
+  const T* Get(std::thread::id id) const {
+    std::shared_lock<std::shared_timed_mutex> lock(mutex_);
+    auto res = values_.find(id);
+    if (res == values_.end()) {
+      return nullptr;
+    } else {
+      return res->second.get();
+    }
+  }
+
+  /*! \brief Return the current thread's stored object, if it exists.
+   *
+   * \return If it exists, a pointer to the stored object.  Otherwise,
+   * returns nullptr.
+   */
+  T* Get() { return const_cast<T*>(const_cast<const ThreadMap<T>*>(this)->Get()); }
+
+  /*! \brief Return the stored object for a given thread, if it exists.
+   *
+   * \param id The thread whose object should be returned.
+   *
+   * \return If it exists, a pointer to the stored object.  Otherwise,
+   * returns nullptr.
+   */
+  T* Get(std::thread::id id) {
+    return const_cast<T*>(const_cast<const ThreadMap<T>*>(this)->Get(id));
+  }
+
+  /*! \brief Return the current thread's stored object, making it if
+   * necessary.
+   *
+   * Since this method can modify the stored map, there is no
+   * const version available.
+   *
+   * \tparam Params Types of the stored object's constructor arguments
+   *
+   * \return A reference to the stored object
+   */
+  template <typename... Params>
+  T& GetOrMake(Params&&... params) {
+    return GetOrMake(std::this_thread::get_id(), std::forward<Params>(params)...);
+  }
+
+  /*! \brief Return the stored object for a given thread, making it if
+   * necessary
+   *
+   * Since this method can modify the stored map, there is no
+   * const version available.
+   *
+   * \tparam Params Types of the stored object's constructor arguments
+   *
+   * \param id The thread whose object should be returned.
+   *
+   * \param params Arguments to the stored object's constructor.  Only
+   * used if the specified thread does not currently exist in the map.
+   *
+   * \return A reference to the stored object
+   */
+  template <typename... Params>
+  T& GetOrMake(std::thread::id id, Params&&... params) {
+    // Try to get stored value first, which would only require shared
+    // access.
+    if (T* output = Get(id)) {
+      return *output;
+    }
+
+    // Not in map, need exclusive lock to write
+    std::unique_lock<std::shared_timed_mutex> lock(mutex_);
+
+    // Check again, in case another thread got the unique lock first
+    // and already constructed the object.
+    auto res = values_.find(id);
+    if (res != values_.end()) {
+      return *res->second;
+    }
+
+    // No value exists, make one and return it.
+    std::unique_ptr<T>& new_val = values_[id] =
+        std::make_unique<T>(std::forward<Params>(params)...);
+    return *new_val;
+  }
+
+  /*! \brief Clears all values held by the ThreadMap
+   *
+   * Calling Clear() invalidates any pointers/references previously
+   * returned by Get/GetOrMake.
+   *
+   */
+  void Clear() {
+    std::unique_lock<std::shared_timed_mutex> lock(mutex_);
+    values_.clear();
+  }
+
+ private:
+  //! \brief Mutex to protect values_
+  mutable std::shared_timed_mutex mutex_;
+
+  //! \brief Map containing stored values
+  std::unordered_map<std::thread::id, std::unique_ptr<T>> values_;
+};
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_THREAD_MAP_H_
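
The locking scheme in GetOrMake (a shared-lock lookup first, then an exclusive lock with a second check before constructing) can be exercised with a small standalone sketch. It assumes C++14. `MiniThreadMap`, `Scratch`, and `CountConstructions` are hypothetical names introduced only so the example compiles without the TVM tree, and `MiniThreadMap` is a trimmed stand-in, not the class from this commit.

```cpp
#include <atomic>
#include <cassert>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

// Trimmed, hypothetical stand-in for ThreadMap<T>.
template <typename T>
class MiniThreadMap {
 public:
  template <typename... Params>
  T& GetOrMake(Params&&... params) {
    const std::thread::id id = std::this_thread::get_id();
    {
      // Fast path: shared lock, so concurrent lookups do not block each other.
      std::shared_lock<std::shared_timed_mutex> lock(mutex_);
      auto it = values_.find(id);
      if (it != values_.end()) return *it->second;
    }
    // Slow path: exclusive lock, then re-check before constructing.
    std::unique_lock<std::shared_timed_mutex> lock(mutex_);
    std::unique_ptr<T>& slot = values_[id];
    if (!slot) slot = std::make_unique<T>(std::forward<Params>(params)...);
    return *slot;
  }

 private:
  std::shared_timed_mutex mutex_;
  std::unordered_map<std::thread::id, std::unique_ptr<T>> values_;
};

// Hypothetical per-thread object that records each construction.
struct Scratch {
  explicit Scratch(std::atomic<int>* constructed) { ++*constructed; }
  int uses = 0;
};

// Each thread calls GetOrMake several times and mutates only its own value.
// Returns how many Scratch objects were constructed in total.
int CountConstructions(int num_threads, int calls_per_thread) {
  std::atomic<int> constructed{0};
  MiniThreadMap<Scratch> map;
  std::vector<std::thread> workers;
  for (int i = 0; i < num_threads; ++i) {
    workers.emplace_back([&] {
      for (int c = 0; c < calls_per_thread; ++c) {
        // Only this thread touches its value, so no extra locking is needed.
        map.GetOrMake(&constructed).uses += 1;
      }
    });
  }
  for (auto& w : workers) w.join();
  return constructed.load();
}
```

Because the values are held through `std::unique_ptr`, a rehash of the map during a later insertion moves only the pointers, so references returned earlier stay valid until the entry is erased.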

src/runtime/vulkan/vulkan_buffer.cc

+111 -13
@@ -19,27 +19,125 @@
 
 #include "vulkan_buffer.h"
 
+#include <utility>
+
 #include "vulkan_device_api.h"
-#include "vulkan_thread_entry.h"
 
 namespace tvm {
 namespace runtime {
 namespace vulkan {
 
-void DeleteHostVisibleBuffer(VulkanHostVisibleBuffer* buf) {
-  if (buf && buf->vk_buf) {
-    if (buf->host_addr != nullptr) {
-      vkUnmapMemory(buf->device, buf->vk_buf->memory);
-    }
-    if (buf->vk_buf->memory != VK_NULL_HANDLE) {
-      vkFreeMemory(buf->device, buf->vk_buf->memory, nullptr);
-    }
-    if (buf->vk_buf->buffer != VK_NULL_HANDLE) {
-      vkDestroyBuffer(buf->device, buf->vk_buf->buffer, nullptr);
+VkBufferCreateInfo MakeBufferCreateInfo(size_t nbytes, VkBufferUsageFlags usage) {
+  VkBufferCreateInfo info = {VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO};
+  info.size = nbytes;
+  // Since sharingMode is not VK_SHARING_MODE_CONCURRENT, no need to
+  // specify the queue families.
+  info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
+  info.usage = usage;
+  return info;
+}
+
+VulkanBuffer::VulkanBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage,
+                           uint32_t mem_type_index)
+    : device_(device) {
+  // Create a buffer
+  VkBufferCreateInfo buffer_info = MakeBufferCreateInfo(nbytes, usage);
+  VULKAN_CALL(vkCreateBuffer(device, &buffer_info, nullptr, &buffer));
+
+  // Allocate memory
+  VkMemoryAllocateInfo mem_info = {VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO};
+  mem_info.allocationSize = buffer_info.size;
+  mem_info.memoryTypeIndex = mem_type_index;
+
+  VkMemoryDedicatedAllocateInfoKHR dedicated_info = {
+      VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR};
+
+  bool use_dedicated_allocation = UseDedicatedAllocation(device, buffer, &mem_info.allocationSize);
+  if (use_dedicated_allocation) {
+    dedicated_info.buffer = buffer;
+    mem_info.pNext = &dedicated_info;
+  }
+
+  VULKAN_CALL(vkAllocateMemory(device, &mem_info, nullptr, &memory));
+
+  // Bind the buffer to the allocated memory
+  VULKAN_CALL(vkBindBufferMemory(device, buffer, memory, 0));
+}
+
+VulkanBuffer::~VulkanBuffer() {
+  if (buffer) {
+    vkDestroyBuffer(device_, buffer, nullptr);
+  }
+  if (memory) {
+    vkFreeMemory(device_, memory, nullptr);
+  }
+}
+
+VulkanBuffer::VulkanBuffer(VulkanBuffer&& other)
+    : device_(other.device_), buffer(other.buffer), memory(other.memory) {
+  other.device_ = VK_NULL_HANDLE;
+  other.buffer = VK_NULL_HANDLE;
+  other.memory = VK_NULL_HANDLE;
+}
+
+VulkanBuffer& VulkanBuffer::operator=(VulkanBuffer&& other) {
+  std::swap(device_, other.device_);
+  std::swap(buffer, other.buffer);
+  std::swap(memory, other.memory);
+  return *this;
+}
+
+bool VulkanBuffer::UseDedicatedAllocation(const VulkanDevice& device, VkBuffer buffer,
+                                          VkDeviceSize* nbytes) {
+  if (device.get_buffer_memory_requirements_2_functions) {
+    // Which buffer to request information about
+    VkBufferMemoryRequirementsInfo2KHR req_info2 = {
+        VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR};
+    req_info2.buffer = buffer;
+
+    // What information to request
+    VkMemoryDedicatedRequirementsKHR dedicated_req;
+    dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
+    dedicated_req.pNext = 0;
+
+    VkMemoryRequirements2KHR req2 = {VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR};
+    req2.pNext = &dedicated_req;
+
+    device.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
+        device, &req_info2, &req2);
+    if (dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation) {
+      *nbytes = req2.memoryRequirements.size;
+      return true;
     }
-    buf->host_addr = nullptr;
-    delete buf->vk_buf;
   }
+
+  return false;
+}
+
+VulkanHostVisibleBuffer::VulkanHostVisibleBuffer(const VulkanDevice& device, size_t nbytes,
+                                                 VkBufferUsageFlags usage, uint32_t mem_type_index)
+    : vk_buf(device, nbytes, usage, mem_type_index), size(nbytes) {
+  VULKAN_CALL(vkMapMemory(device, vk_buf.memory, 0, size, 0, &host_addr));
+}
+
+VulkanHostVisibleBuffer::~VulkanHostVisibleBuffer() {
+  if (host_addr) {
+    vkUnmapMemory(vk_buf.device_, vk_buf.memory);
+  }
+}
+
+VulkanHostVisibleBuffer::VulkanHostVisibleBuffer(VulkanHostVisibleBuffer&& other)
+    : vk_buf(std::move(other.vk_buf)), host_addr(other.host_addr), size(other.size) {
+  other.host_addr = nullptr;
+  other.size = 0;
+}
+
+VulkanHostVisibleBuffer& VulkanHostVisibleBuffer::operator=(VulkanHostVisibleBuffer&& other) {
+  std::swap(vk_buf, other.vk_buf);
+  std::swap(host_addr, other.host_addr);
+  std::swap(size, other.size);
+
+  return *this;
 }
 
 }  // namespace vulkan
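
The ownership rules that VulkanBuffer adopts here (acquire in the constructor, release in the destructor, null out the source on move construction, swap on move assignment) can be checked without a Vulkan device. The sketch below is an illustration under stated assumptions: `fake_driver`, `OwnedHandle`, and `LiveAfterMoves` are hypothetical names, and an integer counter stands in for the driver's handle bookkeeping.

```cpp
#include <cassert>
#include <utility>

// Hypothetical stand-in for a driver API. It counts live handles so that
// leaks and double frees show up as a non-zero count.
namespace fake_driver {
int live = 0;
int next_id = 1;
int Create() {
  ++live;
  return next_id++;
}
void Destroy(int /*handle*/) { --live; }
}  // namespace fake_driver

// Move-only owner of one handle; 0 plays the role of VK_NULL_HANDLE.
class OwnedHandle {
 public:
  OwnedHandle() : handle_(fake_driver::Create()) {}
  ~OwnedHandle() {
    if (handle_ != 0) fake_driver::Destroy(handle_);
  }

  OwnedHandle(const OwnedHandle&) = delete;
  OwnedHandle& operator=(const OwnedHandle&) = delete;

  // Move construction steals the handle and leaves the source empty.
  OwnedHandle(OwnedHandle&& other) : handle_(other.handle_) { other.handle_ = 0; }

  // Move assignment swaps, so the previous handle is released when
  // `other` is destroyed rather than immediately.
  OwnedHandle& operator=(OwnedHandle&& other) {
    std::swap(handle_, other.handle_);
    return *this;
  }

  int get() const { return handle_; }

 private:
  int handle_;
};

// Exercises both move operations and reports how many handles remain live.
int LiveAfterMoves() {
  {
    OwnedHandle a;                // one live handle
    OwnedHandle b(std::move(a));  // still one; `a` is now empty
    OwnedHandle c;                // two live handles
    c = std::move(b);             // swap: `b` now holds c's old handle
  }  // every handle is destroyed exactly once here
  return fake_driver::live;
}
```

One consequence of swap-based move assignment is that the overwritten resource stays alive until the moved-from object goes out of scope, which matters when the resource is large or scarce.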
