diff --git a/source/adapters/hip/device.cpp b/source/adapters/hip/device.cpp index 76bfefaa33..76505738da 100644 --- a/source/adapters/hip/device.cpp +++ b/source/adapters/hip/device.cpp @@ -58,11 +58,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, return ReturnValue(VendorId); } case UR_DEVICE_INFO_MAX_COMPUTE_UNITS: { - int ComputeUnits = 0; - UR_CHECK_ERROR(hipDeviceGetAttribute( - &ComputeUnits, hipDeviceAttributeMultiprocessorCount, hDevice->get())); - detail::ur::assertion(ComputeUnits >= 0); - return ReturnValue(static_cast(ComputeUnits)); + return ReturnValue(hDevice->getNumComputeUnits()); } case UR_DEVICE_INFO_MAX_WORK_ITEM_DIMENSIONS: { return ReturnValue(MaxWorkItemDimensions); diff --git a/source/adapters/hip/device.hpp b/source/adapters/hip/device.hpp index bd2b6002e0..fcd50f6614 100644 --- a/source/adapters/hip/device.hpp +++ b/source/adapters/hip/device.hpp @@ -26,6 +26,7 @@ struct ur_device_handle_t_ { ur_platform_handle_t Platform; hipEvent_t EvBase; // HIP event used as base counter uint32_t DeviceIndex; + uint32_t NumComputeUnits{0}; int MaxWorkGroupSize{0}; int MaxBlockDimX{0}; @@ -41,6 +42,9 @@ struct ur_device_handle_t_ { : HIPDevice(HipDevice), RefCount{1}, Platform(Platform), EvBase(EvBase), DeviceIndex(DeviceIndex) { + UR_CHECK_ERROR(hipDeviceGetAttribute( + reinterpret_cast(&NumComputeUnits), + hipDeviceAttributeMultiprocessorCount, HIPDevice)); UR_CHECK_ERROR(hipDeviceGetAttribute( &MaxWorkGroupSize, hipDeviceAttributeMaxThreadsPerBlock, HIPDevice)); UR_CHECK_ERROR(hipDeviceGetAttribute( @@ -85,6 +89,8 @@ struct ur_device_handle_t_ { int getManagedMemSupport() const noexcept { return ManagedMemSupport; }; + uint32_t getNumComputeUnits() const noexcept { return NumComputeUnits; }; + int getConcurrentManagedAccess() const noexcept { return ConcurrentManagedAccess; }; diff --git a/source/adapters/hip/kernel.cpp b/source/adapters/hip/kernel.cpp index a5aefb1293..47ded4597d 100644 --- a/source/adapters/hip/kernel.cpp +++ b/source/adapters/hip/kernel.cpp @@ -161,24 +161,57 @@ urKernelRelease(ur_kernel_handle_t hKernel) { return UR_RESULT_SUCCESS; } -// TODO(ur): Not implemented on hip atm. Also, need to add tests for this -// feature. -UR_APIEXPORT ur_result_t UR_APICALL -urKernelGetNativeHandle(ur_kernel_handle_t, ur_native_handle_t *) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; -} - UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, uint32_t workDim, const size_t *pLocalWorkSize, size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) { - std::ignore = hKernel; + UR_ASSERT(hKernel, UR_RESULT_ERROR_INVALID_KERNEL); + std::ignore = hDevice; - std::ignore = workDim; - std::ignore = pLocalWorkSize; - std::ignore = dynamicSharedMemorySize; - std::ignore = pGroupCountRet; - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + + size_t localWorkSize = pLocalWorkSize[0]; + localWorkSize *= (workDim >= 2 ? pLocalWorkSize[1] : 1); + localWorkSize *= (workDim == 3 ? pLocalWorkSize[2] : 1); + + // We need to set the active current device for this kernel explicitly here, + // because the occupancy querying API does not take device parameter. + ur_device_handle_t Device = hKernel->getProgram()->getDevice(); + ScopedDevice Active(Device); + try { + // We need to calculate max num of work-groups using per-device semantics. + + int MaxNumActiveGroupsPerCU{0}; + UR_CHECK_ERROR(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &MaxNumActiveGroupsPerCU, hKernel->get(), localWorkSize, + dynamicSharedMemorySize)); + detail::ur::assertion(MaxNumActiveGroupsPerCU >= 0); + // Handle the case where we can't have all work-group processors (WGPs) + // active with at least 1 group per WGP. In that case, the device is still + // able to run 1 work-group, hence we will manually check if it is possible + // with the available HW resources. + if (MaxNumActiveGroupsPerCU == 0) { + size_t MaxWorkGroupSize{}; + urKernelGetGroupInfo( + hKernel, Device, UR_KERNEL_GROUP_INFO_WORK_GROUP_SIZE, + sizeof(MaxWorkGroupSize), &MaxWorkGroupSize, nullptr); + size_t MaxLocalSizeBytes{}; + urDeviceGetInfo(Device, UR_DEVICE_INFO_LOCAL_MEM_SIZE, + sizeof(MaxLocalSizeBytes), &MaxLocalSizeBytes, nullptr); + if (localWorkSize > MaxWorkGroupSize || + dynamicSharedMemorySize > MaxLocalSizeBytes) + *pGroupCountRet = 0; + else + *pGroupCountRet = 1; + } else { + // Multiply by the number of WGPs (CUs = compute units) on the device in + // order to retrieve the total number of groups/blocks that can be + // launched. + *pGroupCountRet = Device->getNumComputeUnits() * MaxNumActiveGroupsPerCU; + } + } catch (ur_result_t Err) { + return Err; + } + return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(