Skip to content

Commit

Permalink
[RUNTIME] Add device query for AMD GcnArch (#4341)
Browse files Browse the repository at this point in the history
* add gcnArch query

* kGcnArch query for cuda is a no-op
  • Loading branch information
petrex authored and masahi committed Nov 15, 2019
1 parent 1e2c525 commit 0235d28
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 8 deletions.
3 changes: 2 additions & 1 deletion include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ enum DeviceAttrKind : int {
kDeviceName = 5,
kMaxClockRate = 6,
kMultiProcessorCount = 7,
kMaxThreadDimensions = 8
kMaxThreadDimensions = 8,
kGcnArch = 9
};

/*! \brief Number of bytes each allocation must align to */
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/llvm/codegen_amdgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ inline int DetectROCMComputeVersion(const std::string& target) {
TVMRetValue val;
api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
if (val.operator int() == 1) {
tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kGcnArch, &val);
return val.operator int();
}
}
Expand Down
1 change: 1 addition & 0 deletions src/runtime/cuda/cuda_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class CUDADeviceAPI final : public DeviceAPI {
*rv = ss.str();
return;
}
case kGcnArch: return;
}
*rv = value;
}
Expand Down
1 change: 1 addition & 0 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
case kMultiProcessorCount: return;
case kMaxThreadDimensions: return;
case kExist: break;
case kGcnArch: return;
}
}

Expand Down
1 change: 1 addition & 0 deletions src/runtime/opencl/opencl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ void OpenCLWorkspace::GetAttr(
*rv = ss.str();
break;
}
case kGcnArch: return;
}
}

Expand Down
1 change: 1 addition & 0 deletions src/runtime/opengl/opengl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ void OpenGLWorkspace::GetAttr(
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kMaxThreadDimensions: return;
case kGcnArch: return;
}
}

Expand Down
14 changes: 8 additions & 6 deletions src/runtime/rocm/rocm_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@

#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <hip/hip_runtime_api.h>
#include <hsa/hsa.h>
#include <tvm/runtime/registry.h>
#include "../../../include/tvm/runtime/device_api.h"
#include "rocm_common.h"

namespace tvm {
Expand Down Expand Up @@ -62,16 +63,17 @@ class ROCMDeviceAPI final : public DeviceAPI {
break;
}
case kMaxSharedMemoryPerBlock: return;
case kComputeVersion: {
case kComputeVersion:
case kDeviceName: return;
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kMaxThreadDimensions: return;
case kGcnArch: {
hipDeviceProp_t prop;
ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
*rv = prop.gcnArch;
return;
}
case kDeviceName: return;
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kMaxThreadDimensions: return;
}
*rv = value;
}
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/vulkan/vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,8 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
break;
case kMaxThreadDimensions:
break;
case kGcnArch:
return;
}
}

Expand Down

0 comments on commit 0235d28

Please sign in to comment.