Skip to content

Commit

Permalink
[Vulkan][Refactor] Moved ownership of per-thread active device id to …
Browse files Browse the repository at this point in the history
…VulkanDeviceAPI

- Previously, the active device was owned by VulkanThreadEntry, so
  lookups to multiple global variables were required.  Now, everything
  goes from the VulkanDeviceAPI.

- Removed VulkanThreadEntry, as all functionality has been moved to
  either VulkanDevice or VulkanDeviceAPI.
  • Loading branch information
Lunderberg committed Jun 4, 2021
1 parent 606082f commit 551edf4
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 100 deletions.
1 change: 0 additions & 1 deletion src/runtime/vulkan/vulkan_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include <utility>

#include "vulkan_device_api.h"
#include "vulkan_thread_entry.h"

namespace tvm {
namespace runtime {
Expand Down
1 change: 0 additions & 1 deletion src/runtime/vulkan/vulkan_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include "vulkan_device.h"
#include "vulkan_device_api.h"
#include "vulkan_instance.h"
#include "vulkan_thread_entry.h"

namespace tvm {
namespace runtime {
Expand Down
1 change: 0 additions & 1 deletion src/runtime/vulkan/vulkan_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#ifndef TVM_RUNTIME_VULKAN_VULKAN_DEVICE_H_
#define TVM_RUNTIME_VULKAN_VULKAN_DEVICE_H_

#include <dmlc/thread_local.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/thread_map.h>

Expand Down
16 changes: 14 additions & 2 deletions src/runtime/vulkan/vulkan_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include <utility>

#include "vulkan_common.h"
#include "vulkan_thread_entry.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -55,7 +54,20 @@ VulkanDeviceAPI::VulkanDeviceAPI() {

VulkanDeviceAPI::~VulkanDeviceAPI() {}

void VulkanDeviceAPI::SetDevice(Device dev) { VulkanThreadEntry::ThreadLocal()->device = dev; }
void VulkanDeviceAPI::SetDevice(Device dev) {
ICHECK_EQ(dev.device_type, kDLVulkan)
<< "Active vulkan device cannot be set to non-vulkan device" << dev;

ICHECK_LE(dev.device_id, static_cast<int>(devices_.size()))
<< "Attempted to set active vulkan device to device_id==" << dev.device_id << ", but only "
<< devices_.size() << " devices present";

active_device_id_per_thread.GetOrMake(0) = dev.device_id;
}

int VulkanDeviceAPI::GetActiveDeviceID() { return active_device_id_per_thread.GetOrMake(0); }

VulkanDevice& VulkanDeviceAPI::GetActiveDevice() { return device(GetActiveDeviceID()); }

void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) {
size_t index = static_cast<size_t>(dev.device_id);
Expand Down
25 changes: 24 additions & 1 deletion src/runtime/vulkan/vulkan_device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
#include "../workspace_pool.h"
#include "vulkan_device.h"
#include "vulkan_instance.h"
#include "vulkan_thread_entry.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -70,6 +69,22 @@ class VulkanDeviceAPI final : public DeviceAPI {
// End of required methods for the DeviceAPI interface

public:
/*! \brief Return the currently active VulkanDevice
*
* The active device can be set using VulkanDeviceAPI::SetDevice.
* Each CPU thread has its own active device, mimicking the
* semantics of cudaSetDevice.
*/
VulkanDevice& GetActiveDevice();

/*! \brief Return the currently active VulkanDevice
*
* The active device can be set using VulkanDeviceAPI::SetDevice.
* Each CPU thread has its own active device, mimicking the
* semantics of cudaSetDevice.
*/
int GetActiveDeviceID();

/*! \brief Return the VulkanDevice associated with a specific device_id
*
* These are constructed during VulkanDeviceAPI initialization, so
Expand Down Expand Up @@ -113,6 +128,14 @@ class VulkanDeviceAPI final : public DeviceAPI {
* The memory pools must be destructed before devices_.
*/
ThreadMap<WorkspacePool> pool_per_thread;

/*! \brief The index of the active device for each CPU thread.
*
* To mimic the semantics of cudaSetDevice, each CPU thread can set
* the device on which functions should run. If unset, the active
* device defaults to device_id == 0.
*/
ThreadMap<int> active_device_id_per_thread;
};

} // namespace vulkan
Expand Down
40 changes: 0 additions & 40 deletions src/runtime/vulkan/vulkan_thread_entry.cc

This file was deleted.

51 changes: 0 additions & 51 deletions src/runtime/vulkan/vulkan_thread_entry.h

This file was deleted.

4 changes: 1 addition & 3 deletions src/runtime/vulkan/vulkan_wrapped_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

#include "../file_utils.h"
#include "vulkan_device_api.h"
#include "vulkan_thread_entry.h"

namespace tvm {
namespace runtime {
Expand All @@ -45,8 +44,7 @@ void VulkanWrappedFunc::Init(VulkanModuleNode* m, ObjectPtr<Object> sptr,

void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
const ArgUnion64* pack_args) const {
int device_id = VulkanThreadEntry::ThreadLocal()->device.device_id;
ICHECK_LT(device_id, kVulkanMaxNumDevice);
int device_id = VulkanDeviceAPI::Global()->GetActiveDeviceID();
auto& device = VulkanDeviceAPI::Global()->device(device_id);
if (!scache_[device_id]) {
scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_);
Expand Down

0 comments on commit 551edf4

Please sign in to comment.