Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cublas] Fix race condition in cublas handle deletion #136

Merged
merged 3 commits into from
Nov 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions include/oneapi/mkl/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,8 @@ enum class order : char {
} //namespace mkl
} //namespace oneapi

// Workaround for supporting ::half for hipSYCL
// Workaround for supporting ::half for hipSYCL and DPC++
// TODO: This should be removed after the interface is SYCL2020 conformant
#ifdef __HIPSYCL__
using ::cl::sycl::half;
#endif

#endif //_ONEMKL_TYPES_HPP_
33 changes: 19 additions & 14 deletions src/blas/backends/cublas/cublas_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,33 @@ namespace mkl {
namespace blas {
namespace cublas {

template<typename T>
template <typename T>
struct cublas_handle {
using handle_container_t = std::unordered_map<T, std::atomic<cublasHandle_t> *>;
handle_container_t cublas_handle_mapper_{};
~cublas_handle() noexcept(false){
for (auto &handle_pair : cublas_handle_mapper_) {
cublasStatus_t err;
if (handle_pair.second != nullptr) {
auto handle = handle_pair.second->exchange(nullptr);
if (handle != nullptr) {
CUBLAS_ERROR_FUNC(cublasDestroy, err, handle);
handle = nullptr;
~cublas_handle() noexcept(false) {
for (auto &handle_pair : cublas_handle_mapper_) {
cublasStatus_t err;
if (handle_pair.second != nullptr) {
auto handle = handle_pair.second->exchange(nullptr);
if (handle != nullptr) {
CUBLAS_ERROR_FUNC(cublasDestroy, err, handle);
handle = nullptr;
}
else {
// if the handle is nullptr it means the handle was already
// destroyed by the ContextCallback and we're free to delete the
// atomic object.
delete handle_pair.second;
}

handle_pair.second = nullptr;
}
delete handle_pair.second;
handle_pair.second = nullptr;
}
cublas_handle_mapper_.clear();
}
cublas_handle_mapper_.clear();
}
};


} // namespace cublas
} // namespace blas
} // namespace mkl
Expand Down
27 changes: 14 additions & 13 deletions src/blas/backends/cublas/cublas_scope_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,21 @@ CublasScopedContextHandler::~CublasScopedContextHandler() noexcept(false) {
}

void ContextCallback(void *userData) {
auto *ptr = static_cast<std::atomic<cublasHandle_t> **>(userData);
auto *ptr = static_cast<std::atomic<cublasHandle_t> *>(userData);
if (!ptr) {
return;
}
if (*ptr != nullptr) {
auto handle = (*ptr)->exchange(nullptr);
if (handle != nullptr) {
cublasStatus_t err1;
CUBLAS_ERROR_FUNC(cublasDestroy, err1, handle);
handle = nullptr;
}
delete *ptr;
*ptr = nullptr;
auto handle = ptr->exchange(nullptr);
if (handle != nullptr) {
cublasStatus_t err1;
CUBLAS_ERROR_FUNC(cublasDestroy, err1, handle);
handle = nullptr;
}
else {
// if the handle is nullptr it means the handle was already destroyed by
// the cublas_handle destructor and we're free to delete the atomic
// object.
delete ptr;
}
}

Expand Down Expand Up @@ -113,9 +115,8 @@ cublasHandle_t CublasScopedContextHandler::get_handle(const cl::sycl::queue &que
auto insert_iter = handle_helper.cublas_handle_mapper_.insert(
std::make_pair(piPlacedContext_, new std::atomic<cublasHandle_t>(handle)));

auto ptr = &(insert_iter.first->second);

sycl::detail::pi::contextSetExtendedDeleter(placedContext_, ContextCallback, ptr);
sycl::detail::pi::contextSetExtendedDeleter(placedContext_, ContextCallback,
insert_iter.first->second);

return handle;
}
Expand Down