diff --git a/sycl/include/CL/sycl/detail/cg.hpp b/sycl/include/CL/sycl/detail/cg.hpp index 685d41bfae2aa..e4d541585edb9 100644 --- a/sycl/include/CL/sycl/detail/cg.hpp +++ b/sycl/include/CL/sycl/detail/cg.hpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include diff --git a/sycl/include/CL/sycl/interop_handle.hpp b/sycl/include/CL/sycl/interop_handle.hpp new file mode 100644 index 0000000000000..3296ab783bebf --- /dev/null +++ b/sycl/include/CL/sycl/interop_handle.hpp @@ -0,0 +1,147 @@ +//==------------ interop_handle.hpp --- SYCL interop handle ----------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +__SYCL_INLINE_NAMESPACE(cl) { +namespace sycl { + +namespace detail { +class AccessorBaseHost; +class ExecCGCommand; +class DispatchHostTask; +class queue_impl; +class device_impl; +class context_impl; +} // namespace detail + +class queue; +class device; +class context; + +class interop_handle { +public: + /// Receives a SYCL accessor that has been defined as a requirement for the + /// command group, and returns the underlying OpenCL memory object that is + /// used by the SYCL runtime. If the accessor passed as parameter is not part + /// of the command group requirements (e.g. it is an unregistered placeholder + /// accessor), the exception `cl::sycl::invalid_object` is thrown + /// asynchronously. + template + typename std::enable_if< + Target != access::target::host_buffer, + typename interop>::type>::type + get_native_mem(const accessor &Acc) const { +#ifndef __SYCL_DEVICE_ONLY__ + const auto *AccBase = static_cast(&Acc); + return getMemImpl( + detail::getSyclObjImpl(*AccBase).get()); +#else + (void)Acc; + // we believe this won't be ever called on device side + return nullptr; +#endif + } + + template + typename std::enable_if< + Target == access::target::host_buffer, + typename interop>::type>::type + get_native_mem(const accessor &) const { + throw invalid_object_error("Getting memory object out of host accessor is " + "not allowed", + PI_INVALID_MEM_OBJECT); + } + + /// Returns an underlying OpenCL queue for the SYCL queue used to submit the + /// command group, or the fallback queue if this command-group is re-trying + /// execution on an OpenCL queue. The OpenCL command queue returned is + /// implementation-defined in cases where the SYCL queue maps to multiple + /// underlying OpenCL objects. It is responsibility of the SYCL runtime to + /// ensure the OpenCL queue returned is in a state that can be used to + /// dispatch work, and that other potential OpenCL command queues associated + /// with the same SYCL command queue are not executing commands while the host + /// task is executing. + template + auto get_native_queue() const noexcept -> + typename interop::type { + return reinterpret_cast::type>( + getNativeQueue()); + } + + /// Returns an underlying OpenCL device associated with the SYCL queue used + /// to submit the command group, or the fallback queue if this command-group + /// is re-trying execution on an OpenCL queue. + template + auto get_native_device() const noexcept -> + typename interop::type { + return reinterpret_cast::type>( + getNativeDevice()); + } + + /// Returns an underlying OpenCL context associated with the SYCL queue used + /// to submit the command group, or the fallback queue if this command-group + /// is re-trying execution on an OpenCL queue. + template + auto get_native_context() const noexcept -> + typename interop::type { + return reinterpret_cast::type>( + getNativeContext()); + } + +private: + using ReqToMem = std::pair; + +public: + // TODO set c-tor private + interop_handle(std::vector MemObjs, + const std::shared_ptr &Queue, + const std::shared_ptr &Device, + const std::shared_ptr &Context) + : MQueue(Queue), MDevice(Device), MContext(Context), + MMemObjs(std::move(MemObjs)) {} + +private: + template + auto getMemImpl(detail::Requirement *Req) const -> + typename interop>::type { + return reinterpret_cast>::type>( + getNativeMem(Req)); + } + + pi_native_handle getNativeMem(detail::Requirement *Req) const; + pi_native_handle getNativeQueue() const; + pi_native_handle getNativeDevice() const; + pi_native_handle getNativeContext() const; + + std::shared_ptr MQueue; + std::shared_ptr MDevice; + std::shared_ptr MContext; + + std::vector MMemObjs; +}; + +} // namespace sycl +} // __SYCL_INLINE_NAMESPACE(cl) diff --git a/sycl/source/CMakeLists.txt b/sycl/source/CMakeLists.txt index 1a511fcaf3436..f50eabd501f78 100644 --- a/sycl/source/CMakeLists.txt +++ b/sycl/source/CMakeLists.txt @@ -143,6 +143,7 @@ set(SYCL_SOURCES "function_pointer.cpp" "half_type.cpp" "handler.cpp" + "interop_handle.cpp" "interop_handler.cpp" "kernel.cpp" "platform.cpp" diff --git a/sycl/source/interop_handle.cpp b/sycl/source/interop_handle.cpp new file mode 100644 index 0000000000000..2f52601e7abdb --- /dev/null +++ b/sycl/source/interop_handle.cpp @@ -0,0 +1,50 @@ +//==------------ interop_handle.cpp --- SYCL interop handle ----------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include + +#include + +__SYCL_INLINE_NAMESPACE(cl) { +namespace sycl { + +pi_native_handle interop_handle::getNativeMem(detail::Requirement *Req) const { + auto Iter = std::find_if(std::begin(MMemObjs), std::end(MMemObjs), + [=](ReqToMem Elem) { return (Elem.first == Req); }); + + if (Iter == std::end(MMemObjs)) { + throw invalid_object_error("Invalid memory object used inside interop", + PI_INVALID_MEM_OBJECT); + } + + auto Plugin = MQueue->getPlugin(); + pi_native_handle Handle; + Plugin.call(Iter->second, + &Handle); + return Handle; +} + +pi_native_handle interop_handle::getNativeDevice() const { + return MDevice->getNative(); +} + +pi_native_handle interop_handle::getNativeContext() const { + return MContext->getNative(); +} + +pi_native_handle interop_handle::getNativeQueue() const { + return MQueue->getNative(); +} + +} // namespace sycl +} // __SYCL_INLINE_NAMESPACE(cl)