Commit

Start replacing setErrorMessage with a helper class.

aarongreig committed Nov 26, 2024
1 parent 21406b2 commit c9b3996
Showing 13 changed files with 77 additions and 97 deletions.
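Every adapter change below follows the same pattern: the per-adapter thread_local ErrorMessage/ErrorMessageCode globals and the free setErrorMessage() function are replaced by a single thread_local ur::MessageHandler<256> object, whose getters feed urAdapterGetLastError. A minimal sketch of that pattern (the function names reportIgnoredHint and getLastError are invented here for illustration; the MessageHandler member calls are the ones added by this commit):

// Sketch only: reportIgnoredHint/getLastError are hypothetical names.
#include <ur_api.h>
#include <ur_util.hpp>

// One thread-local helper replaces the ErrorMessage/ErrorMessageCode globals.
thread_local ur::MessageHandler<256> MessageHandler;

ur_result_t reportIgnoredHint() {
  // Previously: setErrorMessage("...", UR_RESULT_SUCCESS);
  MessageHandler.setErrorMessage(
      "Prefetch hint ignored as device does not support "
      "concurrent managed access",
      UR_RESULT_SUCCESS);
  return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}

// urAdapterGetLastError then reads both fields back out of the helper.
ur_result_t getLastError(const char **ppMessage, int32_t *pError) {
  *ppMessage = MessageHandler.getErrorMessage();
  *pError = static_cast<int32_t>(MessageHandler.getErrorMessageCode());
  return UR_RESULT_SUCCESS;
}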
4 changes: 2 additions & 2 deletions source/adapters/cuda/adapter.cpp
@@ -92,8 +92,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetLastError(
ur_adapter_handle_t, const char **ppMessage, int32_t *pError) {
std::ignore = pError;
*ppMessage = ErrorMessage;
return ErrorMessageCode;
*ppMessage = MessageHandler.getErrorMessage();
return MessageHandler.getErrorMessageCode();
}

UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
19 changes: 2 additions & 17 deletions source/adapters/cuda/common.cpp
@@ -103,22 +103,7 @@ void detail::ur::assertion(bool Condition, const char *Message) {
}

// Global variables for ZER_EXT_RESULT_ADAPTER_SPECIFIC_ERROR
thread_local ur_result_t ErrorMessageCode = UR_RESULT_SUCCESS;
thread_local char ErrorMessage[MaxMessageSize]{};

// Utility function for setting a message and warning
[[maybe_unused]] void setErrorMessage(const char *pMessage,
ur_result_t ErrorCode) {
assert(strlen(pMessage) < MaxMessageSize);
// Copy at most MaxMessageSize - 1 bytes to ensure the resultant string is
// always null terminated.
#if defined(_WIN32)
strncpy_s(ErrorMessage, MaxMessageSize - 1, pMessage, strlen(pMessage));
#else
strncpy(ErrorMessage, pMessage, MaxMessageSize - 1);
#endif
ErrorMessageCode = ErrorCode;
}
thread_local ur::MessageHandler<256> MessageHandler;

void setPluginSpecificMessage(CUresult cu_res) {
const char *error_string;
@@ -130,6 +115,6 @@ void setPluginSpecificMessage(CUresult cu_res) {
strcat(message, "\n");
strcat(message, error_string);

setErrorMessage(message, UR_RESULT_ERROR_ADAPTER_SPECIFIC);
MessageHandler.setErrorMessage(message, UR_RESULT_ERROR_ADAPTER_SPECIFIC);
free(message);
}
9 changes: 2 additions & 7 deletions source/adapters/cuda/common.hpp
@@ -11,6 +11,7 @@

#include <cuda.h>
#include <ur/ur.hpp>
#include <ur_util.hpp>

ur_result_t mapErrorUR(CUresult Result);

@@ -32,13 +33,7 @@ void checkErrorUR(ur_result_t Result, const char *Function, int Line,

std::string getCudaVersionString();

constexpr size_t MaxMessageSize = 256;
extern thread_local ur_result_t ErrorMessageCode;
extern thread_local char ErrorMessage[MaxMessageSize];

// Utility function for setting a message and warning
[[maybe_unused]] void setErrorMessage(const char *pMessage,
ur_result_t ErrorCode);
extern thread_local ur::MessageHandler<256> MessageHandler;

void setPluginSpecificMessage(CUresult cu_res);

50 changes: 28 additions & 22 deletions source/adapters/cuda/enqueue.cpp
@@ -125,9 +125,10 @@ ur_result_t setCuMemAdvise(CUdeviceptr DevPtr, size_t Size,

for (auto &UnmappedFlag : UnmappedMemAdviceFlags) {
if (URAdviceFlags & UnmappedFlag) {
setErrorMessage("Memory advice ignored because the CUDA backend does not "
"support some of the specified flags",
UR_RESULT_SUCCESS);
MessageHandler.setErrorMessage(
"Memory advice ignored because the CUDA backend does not "
"support some of the specified flags",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}
}
@@ -283,8 +284,9 @@ setKernelParams([[maybe_unused]] const ur_context_handle_t Context,
}

if (LocalSize > static_cast<uint32_t>(Device->getMaxCapacityLocalMem())) {
setErrorMessage("Excessive allocation of local memory on the device",
UR_RESULT_ERROR_ADAPTER_SPECIFIC);
MessageHandler.setErrorMessage(
"Excessive allocation of local memory on the device",
UR_RESULT_ERROR_ADAPTER_SPECIFIC);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}

@@ -293,17 +295,18 @@ setKernelParams([[maybe_unused]] const ur_context_handle_t Context,
if (Device->getMaxChosenLocalMem() < 0) {
bool EnvVarHasURPrefix =
std::getenv("UR_CUDA_MAX_LOCAL_MEM_SIZE") != nullptr;
setErrorMessage(EnvVarHasURPrefix ? "Invalid value specified for "
"UR_CUDA_MAX_LOCAL_MEM_SIZE"
: "Invalid value specified for "
"SYCL_PI_CUDA_MAX_LOCAL_MEM_SIZE",
UR_RESULT_ERROR_ADAPTER_SPECIFIC);
MessageHandler.setErrorMessage(EnvVarHasURPrefix
? "Invalid value specified for "
"UR_CUDA_MAX_LOCAL_MEM_SIZE"
: "Invalid value specified for "
"SYCL_PI_CUDA_MAX_LOCAL_MEM_SIZE",
UR_RESULT_ERROR_ADAPTER_SPECIFIC);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}
if (LocalSize > static_cast<uint32_t>(Device->getMaxChosenLocalMem())) {
bool EnvVarHasURPrefix =
std::getenv("UR_CUDA_MAX_LOCAL_MEM_SIZE") != nullptr;
setErrorMessage(
MessageHandler.setErrorMessage(
EnvVarHasURPrefix
? "Local memory for kernel exceeds the amount requested using "
"UR_CUDA_MAX_LOCAL_MEM_SIZE. Try increasing the value of "
@@ -686,8 +689,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
return UR_RESULT_SUCCESS;
#else
[[maybe_unused]] auto _ = launchPropList;
setErrorMessage("This feature requires cuda 11.8 or later.",
UR_RESULT_ERROR_ADAPTER_SPECIFIC);
MessageHandler.setErrorMessage("This feature requires cuda 11.8 or later.",
UR_RESULT_ERROR_ADAPTER_SPECIFIC);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
#endif // CUDA_VERSION >= 11080
}
@@ -1616,18 +1619,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
// for managed memory. Therefore, ignore prefetch hint if concurrent managed
// memory access is not available.
if (!getAttribute(Device, CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS)) {
setErrorMessage("Prefetch hint ignored as device does not support "
"concurrent managed access",
UR_RESULT_SUCCESS);
MessageHandler.setErrorMessage(
"Prefetch hint ignored as device does not support "
"concurrent managed access",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}

unsigned int IsManaged;
UR_CHECK_ERROR(cuPointerGetAttribute(
&IsManaged, CU_POINTER_ATTRIBUTE_IS_MANAGED, (CUdeviceptr)pMem));
if (!IsManaged) {
setErrorMessage("Prefetch hint ignored as prefetch only works with USM",
UR_RESULT_SUCCESS);
MessageHandler.setErrorMessage(
"Prefetch hint ignored as prefetch only works with USM",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}

@@ -1678,9 +1683,10 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
(advice & UR_USM_ADVICE_FLAG_DEFAULT)) {
ur_device_handle_t Device = hQueue->getDevice();
if (!getAttribute(Device, CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS)) {
setErrorMessage("Mem advise ignored as device does not support "
"concurrent managed access",
UR_RESULT_SUCCESS);
MessageHandler.setErrorMessage(
"Mem advise ignored as device does not support "
"concurrent managed access",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}

@@ -1693,7 +1699,7 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
UR_CHECK_ERROR(cuPointerGetAttribute(
&IsManaged, CU_POINTER_ATTRIBUTE_IS_MANAGED, (CUdeviceptr)pMem));
if (!IsManaged) {
setErrorMessage(
MessageHandler.setErrorMessage(
"Memory advice ignored as memory advices only works with USM",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
9 changes: 5 additions & 4 deletions source/adapters/cuda/event.cpp
@@ -168,10 +168,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent,
// If the runtime owns the native handle, we have reference to the queue.
// Otherwise, the event handle comes from an interop API with no RT refs.
if (!hEvent->getQueue()) {
setErrorMessage("Command queue info cannot be queried for the event. The "
"event object was created from a native event and has no "
"valid reference to a command queue.",
UR_RESULT_ERROR_INVALID_VALUE);
MessageHandler.setErrorMessage(
"Command queue info cannot be queried for the event. The "
"event object was created from a native event and has no "
"valid reference to a command queue.",
UR_RESULT_ERROR_INVALID_VALUE);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}
return ReturnValue(hEvent->getQueue());
7 changes: 4 additions & 3 deletions source/adapters/cuda/image.cpp
@@ -238,9 +238,10 @@ ur_result_t urTextureCreate(ur_sampler_handle_t hSampler,
#if CUDA_VERSION >= 11060
ImageTexDesc.flags |= CU_TRSF_SEAMLESS_CUBEMAP;
#else
setErrorMessage("The UR_EXP_SAMPLER_CUBEMAP_FILTER_MODE_SEAMLESS "
"feature requires cuda 11.6 or later.",
UR_RESULT_ERROR_ADAPTER_SPECIFIC);
MessageHandler.setErrorMessage(
"The UR_EXP_SAMPLER_CUBEMAP_FILTER_MODE_SEAMLESS "
"feature requires cuda 11.6 or later.",
UR_RESULT_ERROR_ADAPTER_SPECIFIC);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
#endif
}
7 changes: 4 additions & 3 deletions source/adapters/cuda/kernel.cpp
@@ -379,9 +379,10 @@ urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex,
arrayDesc.Format != CU_AD_FORMAT_SIGNED_INT32 &&
arrayDesc.Format != CU_AD_FORMAT_HALF &&
arrayDesc.Format != CU_AD_FORMAT_FLOAT) {
setErrorMessage("PI CUDA kernels only support images with channel "
"types int32, uint32, float, and half.",
UR_RESULT_ERROR_ADAPTER_SPECIFIC);
MessageHandler.setErrorMessage(
"PI CUDA kernels only support images with channel "
"types int32, uint32, float, and half.",
UR_RESULT_ERROR_ADAPTER_SPECIFIC);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}
CUsurfObject CuSurf =
6 changes: 2 additions & 4 deletions source/adapters/level_zero/common.cpp
@@ -330,10 +330,8 @@ template <> zes_structure_type_t getZesStructureType<zes_mem_properties_t>() {
return ZES_STRUCTURE_TYPE_MEM_PROPERTIES;
}

// Global variables for ZER_EXT_RESULT_ADAPTER_SPECIFIC_ERROR
thread_local ur_result_t ErrorMessageCode = UR_RESULT_SUCCESS;
thread_local char ErrorMessage[MaxMessageSize]{};
thread_local int32_t ErrorAdapterNativeCode;
// Global variable for ZER_EXT_RESULT_ADAPTER_SPECIFIC_ERROR
thread_local ur::MessageHandler<256> MessageHandler;

// Utility function for setting a message and warning
[[maybe_unused]] void setErrorMessage(const char *pMessage,
12 changes: 2 additions & 10 deletions source/adapters/level_zero/common.hpp
@@ -519,15 +519,7 @@ constexpr char ZE_SUPPORTED_EXTENSIONS[] =
"cl_khr_il_program cl_khr_subgroups cl_intel_subgroups "
"cl_intel_subgroups_short cl_intel_required_subgroup_size ";

// Global variables for ZER_EXT_RESULT_ADAPTER_SPECIFIC_ERROR
constexpr size_t MaxMessageSize = 256;
extern thread_local ur_result_t ErrorMessageCode;
extern thread_local char ErrorMessage[MaxMessageSize];
extern thread_local int32_t ErrorAdapterNativeCode;

// Utility function for setting a message and warning
[[maybe_unused]] void setErrorMessage(const char *pMessage,
ur_result_t ErrorCode,
int32_t AdapterErrorCode);
// Global variable for ZER_EXT_RESULT_ADAPTER_SPECIFIC_ERROR
extern thread_local ur::MessageHandler<256> MessageHandler;

#define L0_DRIVER_INORDER_MIN_VERSION 29534
4 changes: 2 additions & 2 deletions source/adapters/opencl/adapter.cpp
@@ -110,8 +110,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {

UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetLastError(
ur_adapter_handle_t, const char **ppMessage, int32_t *pError) {
*ppMessage = cl_adapter::ErrorMessage;
*pError = cl_adapter::ErrorMessageCode;
*ppMessage = cl_adapter::MessageHandler.getErrorMessage();
*pError = cl_adapter::MessageHandler.getErrorMessageCode();

return UR_RESULT_SUCCESS;
}
18 changes: 2 additions & 16 deletions source/adapters/opencl/common.cpp
@@ -12,23 +12,9 @@
#include "logger/ur_logger.hpp"
namespace cl_adapter {

/* Global variables for urAdapterGetLastError() */
thread_local int32_t ErrorMessageCode = 0;
thread_local char ErrorMessage[MaxMessageSize]{};
/* Global variable for urAdapterGetLastError() */
thread_local ur::MessageHandler<256> MessageHandler;

[[maybe_unused]] void setErrorMessage(const char *Message, int32_t ErrorCode) {
assert(strlen(Message) < cl_adapter::MaxMessageSize);
// Copy at most MaxMessageSize - 1 bytes to ensure the resultant string is
// always null terminated.
#if defined(_WIN32)
strncpy_s(cl_adapter::ErrorMessage, MaxMessageSize - 1, Message,
strlen(Message));
#else
strncpy(cl_adapter::ErrorMessage, Message, MaxMessageSize - 1);
#endif

ErrorMessageCode = ErrorCode;
}
} // namespace cl_adapter

ur_result_t mapCLErrorToUR(cl_int Result) {
9 changes: 2 additions & 7 deletions source/adapters/opencl/common.hpp
@@ -13,6 +13,7 @@
#include <map>
#include <mutex>
#include <ur/ur.hpp>
#include <ur_util.hpp>

/**
* Call an OpenCL API and, if the result is not CL_SUCCESS, automatically map
@@ -149,13 +150,7 @@ inline const OpenCLVersion V3_0(3, 0);
} // namespace oclv

namespace cl_adapter {
constexpr size_t MaxMessageSize = 256;
extern thread_local int32_t ErrorMessageCode;
extern thread_local char ErrorMessage[MaxMessageSize];

// Utility function for setting a message and warning
[[maybe_unused]] void setErrorMessage(const char *Message,
ur_result_t ErrorCode);
extern thread_local ur::MessageHandler<256> MessageHandler;

[[noreturn]] void die(const char *Message);

20 changes: 20 additions & 0 deletions source/common/ur_util.hpp
@@ -18,6 +18,7 @@
#include <ur_api.h>

#include <atomic>
#include <cassert>
#include <cstring>
#include <iostream>
#include <map>
#include <optional>
@@ -501,4 +502,23 @@ static inline std::string groupDigits(Numeric numeric) {

template <typename T> Spinlock<Rc<T>> AtomicSingleton<T>::instance;

namespace ur {
template <size_t MaxMessageSize> class MessageHandler {
ur_result_t ErrorMessageCode = UR_RESULT_SUCCESS;
char ErrorMessage[MaxMessageSize]{};

public:
void setErrorMessage(const char *pMessage, ur_result_t ErrorCode) {
assert(strlen(pMessage) < MaxMessageSize);
// Copy at most MaxMessageSize - 1 bytes to ensure the resultant string is
// always null terminated.
strncpy(ErrorMessage, pMessage, MaxMessageSize - 1);
ErrorMessageCode = ErrorCode;
}

const char *getErrorMessage() const { return ErrorMessage; }

ur_result_t getErrorMessageCode() const { return ErrorMessageCode; }
};
} // namespace ur
#endif /* UR_UTIL_H */
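For reference, a hypothetical standalone snippet (not part of the commit) that exercises the new ur::MessageHandler helper in isolation; it assumes only the ur_util.hpp shown above and mirrors the set-then-get sequence the adapters now perform:

#include <cstdio>

#include <ur_util.hpp>

int main() {
  // A plain local instance; the adapters declare theirs thread_local instead.
  ur::MessageHandler<256> Handler;

  // The stored code defaults to UR_RESULT_SUCCESS and the buffer to "".
  std::printf("initial code: %d\n",
              static_cast<int>(Handler.getErrorMessageCode()));

  // Record an adapter-specific error, then read both fields back - the same
  // two calls urAdapterGetLastError makes in the adapters above.
  Handler.setErrorMessage("example adapter-specific failure",
                          UR_RESULT_ERROR_ADAPTER_SPECIFIC);
  std::printf("%s (%d)\n", Handler.getErrorMessage(),
              static_cast<int>(Handler.getErrorMessageCode()));
  return 0;
}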
