Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 139 additions & 14 deletions sycl/plugins/opencl/pi_opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <vector>
Expand Down Expand Up @@ -71,19 +73,107 @@ CONSTFIX char clGetDeviceFunctionPointerName[] =

#undef CONSTFIX

// Function pointer type of the extension entry point that is paired with
// clGetDeviceFunctionPointerName in the extension function cache collection
// declared further down in this file.
typedef CL_API_ENTRY cl_int(CL_API_CALL *clGetDeviceFunctionPointer_fn)(
cl_device_id device, cl_program program, const char *FuncName,
cl_ulong *ret_ptr);

// Function pointer type of the extension entry point that is paired with
// clSetProgramSpecializationConstantName in the extension function cache
// collection declared further down in this file.
typedef CL_API_ENTRY cl_int(CL_API_CALL *clSetProgramSpecializationConstant_fn)(
cl_program program, cl_uint spec_id, size_t spec_size,
const void *spec_value);

// Cache of the pointers to a single extension function, keyed by context.
//
// For the time being, the cache is split into multiple maps of type
// `context -> function_type', one map per extension function.
// An alternative would be a single mapping from a context to a collection of
// function pointers. The current design, however, gives every extension
// function its own map and its own mutex, so lookups of different function
// pointers do not have to share one lock.
//
// FuncName is not used inside the struct; it only makes each extension
// function a distinct instantiation.
template <const char *FuncName, typename FuncT> struct ExtFuncCache {
// Function pointer cached for each context. getExtFuncFromContext stores
// nullptr for a context to remember that the extension is not available.
std::map<pi_context, FuncT> Cache;
// Guards Cache; users of this struct lock it before touching the map.
// FIXME Use spin-lock to make lock/unlock faster and w/o context switching
std::mutex Mtx;
};

// Forward declaration: detail::get below takes the collection by reference
// before the struct itself is defined.
struct ExtFuncCacheCollection;

namespace detail {
// Returns the cache stored in the collection for the <FuncName, FuncT> pair.
// Only the primary template is declared here; the definitions are the
// explicit specializations that follow the definition of the collection.
template <const char *FuncName, typename FuncT>
ExtFuncCache<FuncName, FuncT> &get(::ExtFuncCacheCollection &);
} // namespace detail

// Aggregates one ExtFuncCache per extension function known to the plugin.
struct ExtFuncCacheCollection {
  // Gives access to the member cache that corresponds to the
  // <FuncName, FuncT> pair by forwarding to the detail::get specialization
  // for that pair.
  template <const char *FuncName, typename FuncT>
  ExtFuncCache<FuncName, FuncT> &get() {
    return detail::get<FuncName, FuncT>(*this);
  }

// Declares the member `<t_pfx>_Cache'. Its key is the `<t_pfx>Name' string
// and its value type is `<t_pfx><t_fn_sfx>', i.e. `<t_pfx>INTEL_fn' for the
// INTEL extension functions and `<t_pfx>_fn' for the others.
#define EXT_FUNC_CACHE_MEMBER(t_pfx, t_fn_sfx)                                 \
  ExtFuncCache<t_pfx##Name, t_pfx##t_fn_sfx> t_pfx##_Cache

  EXT_FUNC_CACHE_MEMBER(clHostMemAlloc, INTEL_fn);
  EXT_FUNC_CACHE_MEMBER(clDeviceMemAlloc, INTEL_fn);
  EXT_FUNC_CACHE_MEMBER(clSharedMemAlloc, INTEL_fn);
  EXT_FUNC_CACHE_MEMBER(clCreateBufferWithProperties, INTEL_fn);
  EXT_FUNC_CACHE_MEMBER(clMemBlockingFree, INTEL_fn);
  EXT_FUNC_CACHE_MEMBER(clMemFree, INTEL_fn);
  EXT_FUNC_CACHE_MEMBER(clSetKernelArgMemPointer, INTEL_fn);
  EXT_FUNC_CACHE_MEMBER(clEnqueueMemset, INTEL_fn);
  EXT_FUNC_CACHE_MEMBER(clEnqueueMemcpy, INTEL_fn);
  EXT_FUNC_CACHE_MEMBER(clGetMemAllocInfo, INTEL_fn);
  EXT_FUNC_CACHE_MEMBER(clGetDeviceFunctionPointer, _fn);
  EXT_FUNC_CACHE_MEMBER(clSetProgramSpecializationConstant, _fn);

#undef EXT_FUNC_CACHE_MEMBER
};

namespace detail {
// Emits the explicit specialization of get<> for one extension function. The
// specialization returns the `<t_pfx>_Cache' member of the collection; the
// value type is `<t_pfx><t_fn_sfx>', i.e. `<t_pfx>INTEL_fn' for the INTEL
// extension functions and `<t_pfx>_fn' for the others.
#define EXT_FUNC_CACHE_GETTER(t_pfx, t_fn_sfx)                                 \
  template <>                                                                  \
  ExtFuncCache<t_pfx##Name, t_pfx##t_fn_sfx>                                   \
      &get<t_pfx##Name, t_pfx##t_fn_sfx>(                                      \
          ::ExtFuncCacheCollection & Collection) {                             \
    return Collection.t_pfx##_Cache;                                           \
  }

EXT_FUNC_CACHE_GETTER(clHostMemAlloc, INTEL_fn)
EXT_FUNC_CACHE_GETTER(clDeviceMemAlloc, INTEL_fn)
EXT_FUNC_CACHE_GETTER(clSharedMemAlloc, INTEL_fn)
EXT_FUNC_CACHE_GETTER(clCreateBufferWithProperties, INTEL_fn)
EXT_FUNC_CACHE_GETTER(clMemBlockingFree, INTEL_fn)
EXT_FUNC_CACHE_GETTER(clMemFree, INTEL_fn)
EXT_FUNC_CACHE_GETTER(clSetKernelArgMemPointer, INTEL_fn)
EXT_FUNC_CACHE_GETTER(clEnqueueMemset, INTEL_fn)
EXT_FUNC_CACHE_GETTER(clEnqueueMemcpy, INTEL_fn)
EXT_FUNC_CACHE_GETTER(clGetMemAllocInfo, INTEL_fn)
EXT_FUNC_CACHE_GETTER(clGetDeviceFunctionPointer, _fn)
EXT_FUNC_CACHE_GETTER(clSetProgramSpecializationConstant, _fn)

#undef EXT_FUNC_CACHE_GETTER
} // namespace detail

// Global collection of the extension function caches. It is null until
// piPluginInit allocates it; piTearDown deletes it and resets the pointer to
// nullptr.
ExtFuncCacheCollection *ExtFuncCaches = nullptr;

// USM helper function to get an extension function pointer
template <const char *FuncName, typename T>
static pi_result getExtFuncFromContext(pi_context context, T *fptr) {
// TODO
// Potentially redo caching as PI interface changes.
thread_local static std::map<pi_context, T> FuncPtrs;
ExtFuncCache<FuncName, T> &Cache = ExtFuncCaches->get<FuncName, T>();

std::lock_guard<std::mutex> CacheLock{Cache.Mtx};

auto It = Cache.Cache.find(context);

// if cached, return cached FuncPtr
if (auto F = FuncPtrs[context]) {
if (It != Cache.Cache.end()) {
// if cached that extension is not available return nullptr and
// PI_INVALID_VALUE
*fptr = F;
return F ? PI_SUCCESS : PI_INVALID_VALUE;
*fptr = It->second;
return It->second ? PI_SUCCESS : PI_INVALID_VALUE;
}

cl_uint deviceCount;
Expand Down Expand Up @@ -117,12 +207,12 @@ static pi_result getExtFuncFromContext(pi_context context, T *fptr) {

if (!FuncPtr) {
// Cache that the extension is not available
FuncPtrs[context] = nullptr;
Cache.Cache[context] = nullptr;
return PI_INVALID_VALUE;
}

*fptr = FuncPtr;
FuncPtrs[context] = FuncPtr;
Cache.Cache[context] = FuncPtr;

return cast<pi_result>(ret_err);
}
Expand Down Expand Up @@ -561,9 +651,6 @@ static bool is_in_separated_string(const std::string &str, char delimiter,
return false;
}

typedef CL_API_ENTRY cl_int(CL_API_CALL *clGetDeviceFunctionPointer_fn)(
cl_device_id device, cl_program program, const char *FuncName,
cl_ulong *ret_ptr);
pi_result piextGetDeviceFunctionPointer(pi_device device, pi_program program,
const char *func_name,
pi_uint64 *function_pointer_ret) {
Expand Down Expand Up @@ -1304,10 +1391,6 @@ pi_result piKernelSetExecInfo(pi_kernel kernel, pi_kernel_exec_info param_name,
}
}

typedef CL_API_ENTRY cl_int(CL_API_CALL *clSetProgramSpecializationConstant_fn)(
cl_program program, cl_uint spec_id, size_t spec_size,
const void *spec_value);

pi_result piextProgramSetSpecializationConstant(pi_program prog,
pi_uint32 spec_id,
size_t spec_size,
Expand Down Expand Up @@ -1383,9 +1466,49 @@ pi_result piextKernelGetNativeHandle(pi_kernel kernel,
// pi_level_zero.cpp for reference) Currently this only destroys the extension
// function pointer caches.
pi_result piTearDown(void *PluginParameter) {
  // The parameter is not used by this plugin.
  static_cast<void>(PluginParameter);

  // Detach the cache collection from the global before destroying it, so the
  // global is never left pointing at freed memory.
  ExtFuncCacheCollection *const Caches = ExtFuncCaches;
  ExtFuncCaches = nullptr;
  delete Caches;

  return PI_SUCCESS;
}

pi_result piContextRelease(pi_context Context) {
#define RELEASE_EXT_FUNCS_CACHE_INTEL(t_pfx) \
{ \
ExtFuncCache<t_pfx##Name, t_pfx##INTEL_fn> &Cache = \
ExtFuncCaches->get<t_pfx##Name, t_pfx##INTEL_fn>(); \
std::lock_guard<std::mutex> CacheLock{Cache.Mtx}; \
auto It = Cache.Cache.find(Context); \
if (It != Cache.Cache.end()) \
Cache.Cache.erase(It); \
}
#define RELEASE_EXT_FUNCS_CACHE(t_pfx) \
{ \
ExtFuncCache<t_pfx##Name, t_pfx##_fn> &Cache = \
ExtFuncCaches->get<t_pfx##Name, t_pfx##_fn>(); \
std::lock_guard<std::mutex> CacheLock{Cache.Mtx}; \
auto It = Cache.Cache.find(Context); \
if (It != Cache.Cache.end()) \
Cache.Cache.erase(It); \
}

RELEASE_EXT_FUNCS_CACHE_INTEL(clHostMemAlloc);
RELEASE_EXT_FUNCS_CACHE_INTEL(clDeviceMemAlloc);
RELEASE_EXT_FUNCS_CACHE_INTEL(clSharedMemAlloc);
RELEASE_EXT_FUNCS_CACHE_INTEL(clCreateBufferWithProperties);
RELEASE_EXT_FUNCS_CACHE_INTEL(clMemBlockingFree);
RELEASE_EXT_FUNCS_CACHE_INTEL(clMemFree);
RELEASE_EXT_FUNCS_CACHE_INTEL(clSetKernelArgMemPointer);
RELEASE_EXT_FUNCS_CACHE_INTEL(clEnqueueMemset);
RELEASE_EXT_FUNCS_CACHE_INTEL(clEnqueueMemcpy);
RELEASE_EXT_FUNCS_CACHE_INTEL(clGetMemAllocInfo);
RELEASE_EXT_FUNCS_CACHE(clGetDeviceFunctionPointer);
RELEASE_EXT_FUNCS_CACHE(clSetProgramSpecializationConstant);
#undef RELEASE_EXT_FUNCS_CACHE
#undef RELEASE_EXT_FUNCS_CACHE_INTEL

return cast<pi_result>(clReleaseContext(cast<cl_context>(Context)));
}

pi_result piPluginInit(pi_plugin *PluginInit) {
int CompareVersions = strcmp(PluginInit->PiVersion, SupportedVersion);
if (CompareVersions < 0) {
Expand All @@ -1397,6 +1520,8 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
// PI interface supports higher version or the same version.
strncpy(PluginInit->PluginVersion, SupportedVersion, 4);

ExtFuncCaches = new ExtFuncCacheCollection;

#define _PI_CL(pi_api, ocl_api) \
(PluginInit->PiFunctionTable).pi_api = (decltype(&::pi_api))(&ocl_api);

Expand All @@ -1420,7 +1545,7 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
_PI_CL(piContextCreate, piContextCreate)
_PI_CL(piContextGetInfo, clGetContextInfo)
_PI_CL(piContextRetain, clRetainContext)
_PI_CL(piContextRelease, clReleaseContext)
_PI_CL(piContextRelease, piContextRelease)
_PI_CL(piextContextGetNativeHandle, piextContextGetNativeHandle)
_PI_CL(piextContextCreateWithNativeHandle, piextContextCreateWithNativeHandle)
// Queue
Expand Down
1 change: 1 addition & 0 deletions sycl/test/abi/pi_opencl_symbol_check.dump
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# UNSUPPORTED: libcxx

piContextCreate
piContextRelease
piDeviceGetInfo
piDevicesGet
piEnqueueMemBufferMap
Expand Down