Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions sycl/source/detail/kernel_program_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ class KernelProgramCache {
using KernelByNameT = std::map<std::string, KernelWithBuildStateT>;
using KernelCacheT = std::map<RT::PiProgram, KernelByNameT>;

using KernelFastCacheKeyT =
std::tuple<SerializedObj, OSModuleHandle, RT::PiDevice, std::string,
std::string>;
using KernelFastCacheValT =
std::tuple<RT::PiKernel, std::mutex *, RT::PiProgram>;
using KernelFastCacheT = std::map<KernelFastCacheKeyT, KernelFastCacheValT>;

~KernelProgramCache();

void setContextPtr(const ContextPtr &AContext) { MParentContext = AContext; }
Expand All @@ -102,13 +109,34 @@ class KernelProgramCache {
BR.MBuildCV.notify_all();
}

template <typename KeyT>
KernelFastCacheValT tryToGetKernelFast(KeyT &&CacheKey) {
std::unique_lock<std::mutex> Lock(MKernelFastCacheMutex);
auto It = MKernelFastCache.find(CacheKey);
if (It != MKernelFastCache.end()) {
return It->second;
}
return std::make_tuple(nullptr, nullptr, nullptr);
}

template <typename KeyT, typename ValT>
void saveKernel(KeyT &&CacheKey, ValT &&CacheVal) {
std::unique_lock<std::mutex> Lock(MKernelFastCacheMutex);
// if no insertion took place, thus some other thread has already inserted
// smth in the cache
MKernelFastCache.emplace(CacheKey, CacheVal);
}

private:
std::mutex MProgramCacheMutex;
std::mutex MKernelsPerProgramCacheMutex;

ProgramCacheT MCachedPrograms;
KernelCacheT MKernelsPerProgramCache;
ContextPtr MParentContext;

std::mutex MKernelFastCacheMutex;
KernelFastCacheT MKernelFastCache;
};
} // namespace detail
} // namespace sycl
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/program_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ RT::PiKernel program_impl::get_pi_kernel(const std::string &KernelName) const {
RT::PiKernel Kernel = nullptr;

if (is_cacheable()) {
std::tie(Kernel, std::ignore) =
std::tie(Kernel, std::ignore, std::ignore) =
ProgramManager::getInstance().getOrCreateKernel(
MProgramModuleHandle, get_context(), get_devices()[0], KernelName,
this);
Expand Down
34 changes: 27 additions & 7 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,17 +519,17 @@ RT::PiProgram ProgramManager::getBuiltPIProgram(OSModuleHandle M,
return BuildResult->Ptr.load();
}

std::pair<RT::PiKernel, std::mutex *> ProgramManager::getOrCreateKernel(
OSModuleHandle M, const context &Context, const device &Device,
const std::string &KernelName, const program_impl *Prg) {
std::tuple<RT::PiKernel, std::mutex *, RT::PiProgram>
ProgramManager::getOrCreateKernel(OSModuleHandle M, const context &Context,
const device &Device,
const std::string &KernelName,
const program_impl *Prg) {
if (DbgProgMgr > 0) {
std::cerr << ">>> ProgramManager::getOrCreateKernel(" << M << ", "
<< getRawSyclObjImpl(Context) << ", " << getRawSyclObjImpl(Device)
<< ", " << KernelName << ")\n";
}

RT::PiProgram Program =
getBuiltPIProgram(M, Context, Device, KernelName, Prg);
const ContextImplPtr Ctx = getSyclObjImpl(Context);

using PiKernelT = KernelProgramCache::PiKernelT;
Expand All @@ -538,6 +538,24 @@ std::pair<RT::PiKernel, std::mutex *> ProgramManager::getOrCreateKernel(

KernelProgramCache &Cache = Ctx->getKernelProgramCache();

std::string CompileOpts, LinkOpts;
SerializedObj SpecConsts;
if (Prg) {
CompileOpts = Prg->get_build_options();
Prg->stableSerializeSpecConstRegistry(SpecConsts);
}
applyOptionsFromEnvironment(CompileOpts, LinkOpts);
const RT::PiDevice PiDevice = detail::getSyclObjImpl(Device)->getHandleRef();

auto key = std::make_tuple(std::move(SpecConsts), M, PiDevice,
CompileOpts + LinkOpts, KernelName);
auto ret_tuple = Cache.tryToGetKernelFast(key);
if (std::get<0>(ret_tuple))
return ret_tuple;

RT::PiProgram Program =
getBuiltPIProgram(M, Context, Device, KernelName, Prg);

auto AcquireF = [](KernelProgramCache &Cache) {
return Cache.acquireKernelsPerProgramCache();
};
Expand All @@ -564,8 +582,10 @@ std::pair<RT::PiKernel, std::mutex *> ProgramManager::getOrCreateKernel(

auto BuildResult = getOrBuild<PiKernelT, invalid_object_error>(
Cache, KernelName, AcquireF, GetF, BuildF);
return std::make_pair(BuildResult->Ptr.load(),
&(BuildResult->MBuildResultMutex));
auto ret_val = std::make_tuple(BuildResult->Ptr.load(),
&(BuildResult->MBuildResultMutex), Program);
Cache.saveKernel(key, ret_val);
return ret_val;
}

RT::PiProgram
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/program_manager/program_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class ProgramManager {
const property_list &PropList,
bool JITCompilationIsRequired = false);

std::pair<RT::PiKernel, std::mutex *>
std::tuple<RT::PiKernel, std::mutex *, RT::PiProgram>
getOrCreateKernel(OSModuleHandle M, const context &Context,
const device &Device, const std::string &KernelName,
const program_impl *Prg);
Expand Down
7 changes: 2 additions & 5 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2011,7 +2011,7 @@ cl_int ExecCGCommand::enqueueImp() {
Program = SyclProg->getHandleRef();
if (SyclProg->is_cacheable()) {
RT::PiKernel FoundKernel = nullptr;
std::tie(FoundKernel, KernelMutex) =
std::tie(FoundKernel, KernelMutex, std::ignore) =
detail::ProgramManager::getInstance().getOrCreateKernel(
ExecKernel->MOSModuleHandle,
ExecKernel->MSyclKernel->get_info<info::kernel::context>(),
Expand All @@ -2020,13 +2020,10 @@ cl_int ExecCGCommand::enqueueImp() {
} else
KnownProgram = false;
} else {
std::tie(Kernel, KernelMutex) =
std::tie(Kernel, KernelMutex, Program) =
detail::ProgramManager::getInstance().getOrCreateKernel(
ExecKernel->MOSModuleHandle, Context, MQueue->get_device(),
ExecKernel->MKernelName, nullptr);
MQueue->getPlugin().call<PiApiKind::piKernelGetInfo>(
Kernel, PI_KERNEL_INFO_PROGRAM, sizeof(RT::PiProgram), &Program,
nullptr);
}

pi_result Error = PI_SUCCESS;
Expand Down