[Vulkan][Refactor] Move ownership of per-CPU-thread objects to VulkanDeviceAPI (#8196)

* [Vulkan][Refactor] Moved VulkanStream ownership from VulkanThreadEntry to VulkanDevice

- Implemented ThreadMap, a container for per-thread objects.  Unlike
  dmlc::ThreadLocalStore, ThreadMap is intended for use as a
  non-static thread-specific lookup.

- Added ThreadMap<VulkanStream> as a member to VulkanDevice, updated
  all uses.
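
The difference from a static thread-local store can be illustrated with a reduced sketch. `PerThread` below is a hypothetical, simplified stand-in for ThreadMap (it uses a plain mutex rather than a shared mutex), and `Stream` and `Device` are placeholder types, not the TVM classes.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <utility>

// Hypothetical, simplified stand-in for ThreadMap: one value per thread,
// stored per container instance rather than in static storage.
template <typename T>
class PerThread {
 public:
  // Return the calling thread's value, constructing it on first use.
  template <typename... Params>
  T& GetOrMake(Params&&... params) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::unique_ptr<T>& slot = values_[std::this_thread::get_id()];
    if (!slot) {
      slot = std::make_unique<T>(std::forward<Params>(params)...);
    }
    return *slot;
  }

  // Number of threads that currently have a value in this instance.
  std::size_t Size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return values_.size();
  }

 private:
  mutable std::mutex mutex_;
  std::unordered_map<std::thread::id, std::unique_ptr<T>> values_;
};

// Placeholder for a per-thread resource such as a stream.
struct Stream {
  explicit Stream(int id) : device_id(id) {}
  int device_id;
};

// Placeholder device that owns its per-thread streams as a member.
struct Device {
  explicit Device(int device_id) : id(device_id) {}
  Stream& ThreadLocalStream() { return streams.GetOrMake(id); }
  int id;
  PerThread<Stream> streams;
};

// Returns true if a freshly spawned thread receives a different Stream
// object than the calling thread does for the same device.
inline bool OtherThreadGetsOwnStream(Device& device) {
  Stream* mine = &device.ThreadLocalStream();
  Stream* theirs = nullptr;
  std::thread worker([&] { theirs = &device.ThreadLocalStream(); });
  worker.join();
  assert(theirs != nullptr);  // the worker always creates or finds a value
  return theirs != mine;
}
```

Because the map is a member, two `Device` objects used on the same thread hold separate streams, which a single static thread-local slot could not express without an extra lookup keyed by device.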

* [Vulkan][Refactor] Pulled VulkanBuffer allocation/deallocation into constructor/destructor.

- VulkanBuffer owns the VkBuffer and VkDeviceMemory that it allocates,
  and deallocates on destruction.

- VulkanHostVisibleBuffer owns a VulkanBuffer, and additionally calls
  vkUnmapMemory on destruction.
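
The ownership pattern described above can be sketched without any Vulkan calls. Everything in this block is hypothetical: `FakeDriver` only counts live resources, and `Buffer` and `HostVisibleBuffer` are placeholders that mirror the shape of the two classes, not their actual implementation.

```cpp
#include <cassert>
#include <utility>

// Stand-in for a driver. These are NOT Vulkan calls; the struct only counts
// live resources so that ownership can be observed.
struct FakeDriver {
  int live_buffers = 0;
  int live_mappings = 0;
};

// Owns one "device buffer": acquired in the constructor, released in the
// destructor. Move-only, so exactly one object releases each resource.
class Buffer {
 public:
  explicit Buffer(FakeDriver* driver) : driver_(driver) { ++driver_->live_buffers; }
  ~Buffer() {
    if (driver_) --driver_->live_buffers;
  }

  Buffer(const Buffer&) = delete;
  Buffer& operator=(const Buffer&) = delete;

  Buffer(Buffer&& other) noexcept : driver_(other.driver_) { other.driver_ = nullptr; }
  Buffer& operator=(Buffer&& other) noexcept {
    std::swap(driver_, other.driver_);
    return *this;
  }

  FakeDriver* driver() const { return driver_; }

 private:
  FakeDriver* driver_;
};

// Owns a Buffer plus a host mapping of it. The destructor body undoes the
// mapping first; the Buffer member is then destroyed automatically.
class HostVisibleBuffer {
 public:
  explicit HostVisibleBuffer(FakeDriver* driver) : buf_(driver), mapped_(true) {
    ++driver->live_mappings;
  }
  ~HostVisibleBuffer() {
    if (mapped_) {
      assert(buf_.driver() != nullptr);  // a mapped object always owns a buffer
      --buf_.driver()->live_mappings;
    }
  }

  HostVisibleBuffer(HostVisibleBuffer&& other) noexcept
      : buf_(std::move(other.buf_)), mapped_(other.mapped_) {
    other.mapped_ = false;
  }
  HostVisibleBuffer& operator=(HostVisibleBuffer&& other) noexcept {
    std::swap(buf_, other.buf_);
    std::swap(mapped_, other.mapped_);
    return *this;
  }

 private:
  Buffer buf_;
  bool mapped_;
};
```

With this shape, callers never need a separate delete function: leaving scope unmaps and frees in the right order, and a moved-from object releases nothing.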

* [Vulkan][Refactor] Move the VulkanStagingBuffer to be owned by the VulkanDevice

- Previously, the VulkanStagingBuffer was owned by VulkanThreadEntry,
  so any use required looking up both the thread entry and the device.
  Now, thread-specific lookup is handled in the VulkanDevice class.

* [Vulkan][Refactor] Move ownership of per-thread uniform buffer to VulkanDevice

- Previously, VulkanUniformBuffer was owned by VulkanThreadEntry, so
  any use required looking up both the thread entry and the device.
  Now, thread-specific lookup is handled in the VulkanDevice class.

* [Vulkan][Refactor] Moved ownership of per-thread workspace pool to VulkanDeviceAPI

- Previously, the WorkspacePool was owned by VulkanThreadEntry, and
  required a lookup from VulkanDeviceAPI::AllocWorkspace.  As a
  result, non-global VulkanDeviceAPI instances would interact with
  each other.
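
A minimal sketch of that interaction, under stated assumptions: both classes below are hypothetical, and an integer allocation counter stands in for the workspace pool. The first keeps per-thread state in static thread-local storage; the second keeps it as a member keyed by thread id.

```cpp
#include <cassert>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <utility>

// Pattern A: per-thread state lives in static thread-local storage, so every
// instance used on a given thread sees the same counter.
class ApiWithGlobalState {
 public:
  int AllocWorkspace() { return ++ThreadState(); }

 private:
  static int& ThreadState() {
    thread_local int allocations = 0;
    return allocations;
  }
};

// Pattern B: per-thread state is a member, keyed by thread id, so each
// instance keeps its own count for each thread.
class ApiWithOwnedState {
 public:
  int AllocWorkspace() {
    std::lock_guard<std::mutex> lock(mutex_);
    return ++allocations_[std::this_thread::get_id()];
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::thread::id, int> allocations_;
};

// Use two separate instances on the calling thread and return the running
// count each one reports after a single allocation.
template <typename Api>
std::pair<int, int> UseTwoInstances() {
  Api first;
  Api second;
  int a = first.AllocWorkspace();
  int b = second.AllocWorkspace();
  assert(a >= 1 && b >= 1);  // each call records at least one allocation
  return std::make_pair(a, b);
}
```

With pattern A the second instance continues the first instance's count; with pattern B each instance starts from its own state.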

* [Vulkan][Refactor] Moved ownership of per-thread active device id to VulkanDeviceAPI

- Previously, the active device was owned by VulkanThreadEntry, so
  lookups to multiple global variables were required.  Now, everything
  goes from the VulkanDeviceAPI.

- Removed VulkanThreadEntry, as all functionality has been moved to
  either VulkanDevice or VulkanDeviceAPI.

Co-authored-by: Eric Lunderberg <[email protected]>
Lunderberg and Lunderberg authored Jun 11, 2021
1 parent 657af3a commit f906fa8
Showing 12 changed files with 689 additions and 374 deletions.
175 changes: 175 additions & 0 deletions src/runtime/thread_map.h
@@ -0,0 +1,175 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_RUNTIME_THREAD_MAP_H_
#define TVM_RUNTIME_THREAD_MAP_H_

#include <functional>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <thread>
#include <unordered_map>
#include <utility>

namespace tvm {
namespace runtime {

/*! \brief Container to hold one value per thread
*
* Similar to thread_local, but intended for use as a non-static or
* non-block variable, such as class member variables. All member
* functions are thread-safe to call. If only the current thread's
* value is accessed, no additional synchronization is required. If
* another thread's stored values are accessed, external
* synchronization may be required.
*
* Calls that only require access to already-existing values will not
* block each other. Calls that require constructing a new value will
* block any other calls.
*
* \tparam T The object type to be held. For instantiation of
* ThreadMap<T> and for calls to ThreadMap<T>::Get, only a forward
* declaration is required. For calls to ThreadMap<T>::GetOrMake, a
* full class definition is required.
*/
template <typename T>
class ThreadMap {
public:
ThreadMap() {}

/*! \brief Return the current thread's stored object, if it exists.
*
* \return If it exists, a pointer to the stored object. Otherwise,
* returns nullptr.
*/
const T* Get() const { return this->Get(std::this_thread::get_id()); }

/*! \brief Return the stored object for a given thread, if it exists.
*
* \param id The thread whose object should be returned.
*
* \return If it exists, a pointer to the stored object. Otherwise,
* returns nullptr.
*/
const T* Get(std::thread::id id) const {
std::shared_lock<std::shared_timed_mutex> lock(mutex_);
auto res = values_.find(id);
if (res == values_.end()) {
return nullptr;
} else {
return res->second.get();
}
}

/*! \brief Return the current thread's stored object, if it exists.
*
* \return If it exists, a pointer to the stored object. Otherwise,
* returns nullptr.
*/
T* Get() { return const_cast<T*>(const_cast<const ThreadMap<T>*>(this)->Get()); }

/*! \brief Return the stored object for a given thread, if it exists.
*
* \param id The thread whose object should be returned.
*
* \return If it exists, a pointer to the stored object. Otherwise,
* returns nullptr.
*/
T* Get(std::thread::id id) {
return const_cast<T*>(const_cast<const ThreadMap<T>*>(this)->Get(id));
}

/*! \brief Return the current thread's stored object, making it if
* necessary.
*
* Since this method can modify the stored map, there is no
* const version available.
*
* \tparam Params Types of the stored object's constructor arguments
*
* \return A reference to the stored object
*/
template <typename... Params>
T& GetOrMake(Params&&... params) {
return GetOrMake(std::this_thread::get_id(), std::forward<Params>(params)...);
}

/*! \brief Return the stored object for a given thread, making it if
* necessary
*
* Since this method can modify the stored map, there is no
* const version available.
*
* \tparam Params Types of the stored object's constructor arguments
*
* \param id The thread whose object should be returned.
*
* \param params Arguments to the stored object's constructor. Only
* used if the specified thread does not currently exist in the map.
*
* \return A reference to the stored object
*/
template <typename... Params>
T& GetOrMake(std::thread::id id, Params&&... params) {
// Try to get stored value first, which would only require shared
// access.
if (T* output = Get(id)) {
return *output;
}

// Not in map, need exclusive lock to write
std::unique_lock<std::shared_timed_mutex> lock(mutex_);

// Check again, in case another thread got the unique lock first
// and already constructed the object.
auto res = values_.find(id);
if (res != values_.end()) {
return *res->second;
}

// No value exists, make one and return it.
std::unique_ptr<T>& new_val = values_[id] =
std::make_unique<T>(std::forward<Params>(params)...);
return *new_val;
}

/*! \brief Clears all values held by the ThreadMap
*
* Calling Clear() invalidates any pointers/references previously
* returned by Get/GetOrMake.
*
*/
void Clear() {
std::unique_lock<std::shared_timed_mutex> lock(mutex_);
values_.clear();
}

private:
//! \brief Mutex to protect values_
mutable std::shared_timed_mutex mutex_;

//! \brief Map containing stored values
std::unordered_map<std::thread::id, std::unique_ptr<T>> values_;
};

} // namespace runtime
} // namespace tvm

#endif // TVM_RUNTIME_THREAD_MAP_H_
124 changes: 111 additions & 13 deletions src/runtime/vulkan/vulkan_buffer.cc
@@ -19,27 +19,125 @@

#include "vulkan_buffer.h"

#include <utility>

#include "vulkan_device_api.h"
#include "vulkan_thread_entry.h"

namespace tvm {
namespace runtime {
namespace vulkan {

- void DeleteHostVisibleBuffer(VulkanHostVisibleBuffer* buf) {
- if (buf && buf->vk_buf) {
- if (buf->host_addr != nullptr) {
- vkUnmapMemory(buf->device, buf->vk_buf->memory);
- }
- if (buf->vk_buf->memory != VK_NULL_HANDLE) {
- vkFreeMemory(buf->device, buf->vk_buf->memory, nullptr);
- }
- if (buf->vk_buf->buffer != VK_NULL_HANDLE) {
- vkDestroyBuffer(buf->device, buf->vk_buf->buffer, nullptr);
VkBufferCreateInfo MakeBufferCreateInfo(size_t nbytes, VkBufferUsageFlags usage) {
VkBufferCreateInfo info = {VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO};
info.size = nbytes;
// Since sharingMode is not VK_SHARING_MODE_CONCURRENT, no need to
// specify the queue families.
info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
info.usage = usage;
return info;
}

VulkanBuffer::VulkanBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage,
uint32_t mem_type_index)
: device_(device) {
// Create a buffer
VkBufferCreateInfo buffer_info = MakeBufferCreateInfo(nbytes, usage);
VULKAN_CALL(vkCreateBuffer(device, &buffer_info, nullptr, &buffer));

// Allocate memory
VkMemoryAllocateInfo mem_info = {VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO};
mem_info.allocationSize = buffer_info.size;
mem_info.memoryTypeIndex = mem_type_index;

VkMemoryDedicatedAllocateInfoKHR dedicated_info = {
VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR};

bool use_dedicated_allocation = UseDedicatedAllocation(device, buffer, &mem_info.allocationSize);
if (use_dedicated_allocation) {
dedicated_info.buffer = buffer;
mem_info.pNext = &dedicated_info;
}

VULKAN_CALL(vkAllocateMemory(device, &mem_info, nullptr, &memory));

// Bind the buffer to the allocated memory
VULKAN_CALL(vkBindBufferMemory(device, buffer, memory, 0));
}

VulkanBuffer::~VulkanBuffer() {
if (buffer) {
vkDestroyBuffer(device_, buffer, nullptr);
}
if (memory) {
vkFreeMemory(device_, memory, nullptr);
}
}

VulkanBuffer::VulkanBuffer(VulkanBuffer&& other)
: device_(other.device_), buffer(other.buffer), memory(other.memory) {
other.device_ = VK_NULL_HANDLE;
other.buffer = VK_NULL_HANDLE;
other.memory = VK_NULL_HANDLE;
}

VulkanBuffer& VulkanBuffer::operator=(VulkanBuffer&& other) {
std::swap(device_, other.device_);
std::swap(buffer, other.buffer);
std::swap(memory, other.memory);
return *this;
}

bool VulkanBuffer::UseDedicatedAllocation(const VulkanDevice& device, VkBuffer buffer,
VkDeviceSize* nbytes) {
if (device.get_buffer_memory_requirements_2_functions) {
// Which buffer to request information about
VkBufferMemoryRequirementsInfo2KHR req_info2 = {
VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR};
req_info2.buffer = buffer;

// What information to request
VkMemoryDedicatedRequirementsKHR dedicated_req;
dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
dedicated_req.pNext = 0;

VkMemoryRequirements2KHR req2 = {VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR};
req2.pNext = &dedicated_req;

device.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
device, &req_info2, &req2);
if (dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation) {
*nbytes = req2.memoryRequirements.size;
return true;
}
- buf->host_addr = nullptr;
- delete buf->vk_buf;
}

return false;
}

VulkanHostVisibleBuffer::VulkanHostVisibleBuffer(const VulkanDevice& device, size_t nbytes,
VkBufferUsageFlags usage, uint32_t mem_type_index)
: vk_buf(device, nbytes, usage, mem_type_index), size(nbytes) {
VULKAN_CALL(vkMapMemory(device, vk_buf.memory, 0, size, 0, &host_addr));
}

VulkanHostVisibleBuffer::~VulkanHostVisibleBuffer() {
if (host_addr) {
vkUnmapMemory(vk_buf.device_, vk_buf.memory);
}
}

VulkanHostVisibleBuffer::VulkanHostVisibleBuffer(VulkanHostVisibleBuffer&& other)
: vk_buf(std::move(other.vk_buf)), host_addr(other.host_addr), size(other.size) {
other.host_addr = nullptr;
other.size = 0;
}

VulkanHostVisibleBuffer& VulkanHostVisibleBuffer::operator=(VulkanHostVisibleBuffer&& other) {
std::swap(vk_buf, other.vk_buf);
std::swap(host_addr, other.host_addr);
std::swap(size, other.size);

return *this;
}

} // namespace vulkan