diff --git a/context.cpp b/context.cpp index fe392e36cc..24bcbd6ce7 100644 --- a/context.cpp +++ b/context.cpp @@ -22,15 +22,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate( std::unique_ptr ContextPtr{nullptr}; try { - hipCtx_t Current = nullptr; - // Create a scoped context. - hipCtx_t NewContext; - UR_CHECK_ERROR(hipCtxGetCurrent(&Current)); - RetErr = UR_CHECK_ERROR( - hipCtxCreate(&NewContext, hipDeviceMapHost, phDevices[0]->get())); - ContextPtr = std::unique_ptr(new ur_context_handle_t_{ - ur_context_handle_t_::kind::UserDefined, NewContext, *phDevices}); + ContextPtr = std::unique_ptr( + new ur_context_handle_t_{*phDevices}); static std::once_flag InitFlag; std::call_once( @@ -43,14 +37,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate( }, RetErr); - // For non-primary scoped contexts keep the last active on top of the stack - // as `hipCtxCreate` replaces it implicitly otherwise. - // Primary contexts are kept on top of the stack, so the previous context - // is not queried and therefore not recovered. - if (Current != nullptr) { - UR_CHECK_ERROR(hipCtxSetCurrent(Current)); - } - *phContext = ContextPtr.release(); } catch (ur_result_t Err) { RetErr = Err; @@ -97,40 +83,10 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName, UR_APIEXPORT ur_result_t UR_APICALL urContextRelease(ur_context_handle_t hContext) { - if (hContext->decrementReferenceCount() > 0) { - return UR_RESULT_SUCCESS; - } - hContext->invokeExtendedDeleters(); - - std::unique_ptr context{hContext}; - - if (!hContext->isPrimary()) { - hipCtx_t HIPCtxt = hContext->get(); - // hipCtxSynchronize is not supported for AMD platform so we can just - // destroy the context, for NVIDIA make sure it's synchronized. -#if defined(__HIP_PLATFORM_NVIDIA__) - hipCtx_t Current = nullptr; - UR_CHECK_ERROR(hipCtxGetCurrent(&Current)); - if (HIPCtxt != Current) { - UR_CHECK_ERROR(hipCtxPushCurrent(HIPCtxt)); - } - UR_CHECK_ERROR(hipCtxSynchronize()); - UR_CHECK_ERROR(hipCtxGetCurrent(&Current)); - if (HIPCtxt == Current) { - UR_CHECK_ERROR(hipCtxPopCurrent(&Current)); - } -#endif - return UR_CHECK_ERROR(hipCtxDestroy(HIPCtxt)); - } else { - // Primary context is not destroyed, but released - hipDevice_t HIPDev = hContext->getDevice()->get(); - hipCtx_t Current; - UR_CHECK_ERROR(hipCtxPopCurrent(&Current)); - return UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDev)); + if (hContext->decrementReferenceCount() == 0) { + delete hContext; } - - hipCtx_t HIPCtxt = hContext->get(); - return UR_CHECK_ERROR(hipCtxDestroy(HIPCtxt)); + return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL @@ -143,7 +99,8 @@ urContextRetain(ur_context_handle_t hContext) { UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle( ur_context_handle_t hContext, ur_native_handle_t *phNativeContext) { - *phNativeContext = reinterpret_cast(hContext->get()); + *phNativeContext = reinterpret_cast( + hContext->getDevice()->getNativeContext()); return UR_RESULT_SUCCESS; } diff --git a/context.hpp b/context.hpp index aa61e1e84b..f504bb01ce 100644 --- a/context.hpp +++ b/context.hpp @@ -62,14 +62,11 @@ struct ur_context_handle_t_ { using native_type = hipCtx_t; - enum class kind { Primary, UserDefined } Kind; - native_type HIPContext; ur_device_handle_t DeviceId; std::atomic_uint32_t RefCount; - ur_context_handle_t_(kind K, hipCtx_t Ctxt, ur_device_handle_t DevId) - : Kind{K}, HIPContext{Ctxt}, DeviceId{DevId}, RefCount{1} { - DeviceId->setContext(this); + ur_context_handle_t_(ur_device_handle_t DevId) + : DeviceId{DevId}, RefCount{1} { urDeviceRetain(DeviceId); }; @@ -90,10 +87,6 @@ struct ur_context_handle_t_ { ur_device_handle_t getDevice() const noexcept { return DeviceId; } - native_type get() const noexcept { return HIPContext; } - - bool isPrimary() const noexcept { return Kind == kind::Primary; } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } uint32_t decrementReferenceCount() noexcept { return --RefCount; } @@ -113,19 +106,18 @@ namespace { /// API is the one active on the thread. /// The implementation tries to avoid replacing the hipCtx_t if it cans class ScopedContext { - ur_context_handle_t PlacedContext; hipCtx_t Original; bool NeedToRecover; public: - ScopedContext(ur_context_handle_t Ctxt) - : PlacedContext{Ctxt}, NeedToRecover{false} { + ScopedContext(ur_device_handle_t hDevice) : NeedToRecover{false} { - if (!PlacedContext) { - throw UR_RESULT_ERROR_INVALID_CONTEXT; + if (!hDevice) { + throw UR_RESULT_ERROR_INVALID_DEVICE; } - hipCtx_t Desired = PlacedContext->get(); + // FIXME when multi device context are supported in HIP adapter + hipCtx_t Desired = hDevice->getNativeContext(); UR_CHECK_ERROR(hipCtxGetCurrent(&Original)); if (Original != Desired) { // Sets the desired context as the active one for the thread diff --git a/device.cpp b/device.cpp index f66ed8439a..c5acc35254 100644 --- a/device.cpp +++ b/device.cpp @@ -968,7 +968,7 @@ ur_result_t UR_APICALL urDeviceGetGlobalTimestamps(ur_device_handle_t hDevice, return UR_RESULT_SUCCESS; ur_event_handle_t_::native_type Event; - ScopedContext Active(hDevice->getContext()); + ScopedContext Active(hDevice); if (pDeviceTimestamp) { UR_CHECK_ERROR(hipEventCreateWithFlags(&Event, hipEventDefault)); diff --git a/device.hpp b/device.hpp index 9a56652957..1488d0e61c 100644 --- a/device.hpp +++ b/device.hpp @@ -22,11 +22,17 @@ struct ur_device_handle_t_ { native_type HIPDevice; std::atomic_uint32_t RefCount; ur_platform_handle_t Platform; - ur_context_handle_t Context; + hipCtx_t HIPContext; public: - ur_device_handle_t_(native_type HipDevice, ur_platform_handle_t Platform) - : HIPDevice(HipDevice), RefCount{1}, Platform(Platform) {} + ur_device_handle_t_(native_type HipDevice, hipCtx_t Context, + ur_platform_handle_t Platform) + : HIPDevice(HipDevice), RefCount{1}, Platform(Platform), + HIPContext(Context) {} + + ~ur_device_handle_t_() { + UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDevice)); + } native_type get() const noexcept { return HIPDevice; }; @@ -34,9 +40,7 @@ struct ur_device_handle_t_ { ur_platform_handle_t getPlatform() const noexcept { return Platform; }; - void setContext(ur_context_handle_t Ctxt) { Context = Ctxt; }; - - ur_context_handle_t getContext() { return Context; }; + hipCtx_t getNativeContext() { return HIPContext; }; }; int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute); diff --git a/enqueue.cpp b/enqueue.cpp index 81afa92358..1b0b2acc2a 100644 --- a/enqueue.cpp +++ b/enqueue.cpp @@ -41,7 +41,7 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t CommandQueue, return UR_RESULT_SUCCESS; } try { - ScopedContext Active(CommandQueue->getContext()); + ScopedContext Active(CommandQueue->getDevice()); auto Result = forLatestEvents( EventWaitList, NumEventsInWaitList, @@ -97,7 +97,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite( std::unique_ptr RetImplEvent{nullptr}; try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, phEventWaitList); @@ -143,7 +143,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead( std::unique_ptr RetImplEvent{nullptr}; try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, phEventWaitList); @@ -252,7 +252,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( std::unique_ptr RetImplEvent{nullptr}; try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); uint32_t StreamToken; ur_stream_quard Guard; @@ -363,7 +363,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier( ur_result_t Result; try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); uint32_t StreamToken; ur_stream_quard Guard; hipStream_t HIPStream = hQueue->getNextComputeStream( @@ -513,7 +513,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect( std::unique_ptr RetImplEvent{nullptr}; try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, @@ -561,7 +561,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect( std::unique_ptr RetImplEvent{nullptr}; try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, phEventWaitList); @@ -609,7 +609,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy( std::unique_ptr RetImplEvent{nullptr}; try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); ur_result_t Result; auto Stream = hQueue->getNextTransferStream(); @@ -656,7 +656,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect( std::unique_ptr RetImplEvent{nullptr}; try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, phEventWaitList); @@ -751,7 +751,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill( std::unique_ptr RetImplEvent{nullptr}; try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); auto Stream = hQueue->getNextTransferStream(); ur_result_t Result; @@ -892,7 +892,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead( ur_result_t Result = UR_RESULT_SUCCESS; try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); if (phEventWaitList) { @@ -954,7 +954,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite( ur_result_t Result = UR_RESULT_SUCCESS; try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); if (phEventWaitList) { @@ -1020,7 +1020,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy( ur_result_t Result = UR_RESULT_SUCCESS; try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); if (phEventWaitList) { Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, @@ -1116,7 +1116,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap( HostPtr, numEventsInWaitList, phEventWaitList, phEvent); } else { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); if (IsPinned) { Result = urEnqueueEventsWait(hQueue, numEventsInWaitList, phEventWaitList, @@ -1167,7 +1167,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap( hMem->Mem.BufferMem.getMapSize(), pMappedPtr, numEventsInWaitList, phEventWaitList, phEvent); } else { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); if (IsPinned) { Result = urEnqueueEventsWait(hQueue, numEventsInWaitList, phEventWaitList, @@ -1198,7 +1198,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( std::unique_ptr EventPtr{nullptr}; try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); uint32_t StreamToken; ur_stream_quard Guard; hipStream_t HIPStream = hQueue->getNextComputeStream( @@ -1256,7 +1256,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy( std::unique_ptr EventPtr{nullptr}; try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, phEventWaitList); @@ -1287,6 +1287,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( ur_queue_handle_t hQueue, const void *pMem, size_t size, ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { +#if HIP_VERSION_MAJOR >= 5 void *HIPDevicePtr = const_cast(pMem); unsigned int PointerRangeSize = 0; UR_CHECK_ERROR(hipPointerGetAttribute(&PointerRangeSize, @@ -1301,7 +1302,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( std::unique_ptr EventPtr{nullptr}; try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, phEventWaitList); @@ -1311,8 +1312,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( UR_COMMAND_USM_PREFETCH, hQueue, HIPStream)); EventPtr->start(); } - Result = UR_CHECK_ERROR(hipMemPrefetchAsync( - pMem, size, hQueue->getContext()->getDevice()->get(), HIPStream)); + Result = UR_CHECK_ERROR( + hipMemPrefetchAsync(pMem, size, hQueue->getDevice()->get(), HIPStream)); if (phEvent) { Result = EventPtr->record(); *phEvent = EventPtr.release(); @@ -1322,11 +1323,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( } return Result; +#else + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; +#endif } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size, ur_usm_advice_flags_t, ur_event_handle_t *phEvent) { +#if HIP_VERSION_MAJOR >= 5 void *HIPDevicePtr = const_cast(pMem); unsigned int PointerRangeSize = 0; UR_CHECK_ERROR(hipPointerGetAttribute(&PointerRangeSize, @@ -1337,6 +1342,9 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size, // TODO implement a mapping to hipMemAdvise once the expected behaviour // of urEnqueueUSMAdvise is detailed in the USM extension return urEnqueueEventsWait(hQueue, 0, nullptr, phEvent); +#else + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; +#endif } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D( @@ -1367,7 +1375,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D( ur_result_t Result = UR_RESULT_SUCCESS; try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getDevice()); hipStream_t HIPStream = hQueue->getNextTransferStream(); Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, phEventWaitList); diff --git a/event.cpp b/event.cpp index 93faf2def0..7ae684a860 100644 --- a/event.cpp +++ b/event.cpp @@ -178,7 +178,7 @@ urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) { try { auto Context = phEventWaitList[0]->getContext(); - ScopedContext Active(Context); + ScopedContext Active(Context->getDevice()); auto WaitFunc = [Context](ur_event_handle_t Event) -> ur_result_t { UR_ASSERT(Event, UR_RESULT_ERROR_INVALID_EVENT); @@ -277,7 +277,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) { std::unique_ptr event_ptr{hEvent}; ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT; try { - ScopedContext Active(hEvent->getContext()); + ScopedContext Active(hEvent->getContext()->getDevice()); Result = hEvent->release(); } catch (...) { Result = UR_RESULT_ERROR_OUT_OF_RESOURCES; diff --git a/kernel.cpp b/kernel.cpp index 709657ab0c..8da2d969c2 100644 --- a/kernel.cpp +++ b/kernel.cpp @@ -17,7 +17,7 @@ urKernelCreate(ur_program_handle_t hProgram, const char *pKernelName, std::unique_ptr RetKernel{nullptr}; try { - ScopedContext Active(hProgram->getContext()); + ScopedContext Active(hProgram->getContext()->getDevice()); hipFunction_t HIPFunc; Result = UR_CHECK_ERROR( diff --git a/memory.cpp b/memory.cpp index 3401b5beff..5fdf1a98da 100644 --- a/memory.cpp +++ b/memory.cpp @@ -22,7 +22,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) { return UR_RESULT_SUCCESS; } - ScopedContext Active(uniqueMemObj->getContext()); + ScopedContext Active(uniqueMemObj->getContext()->getDevice()); if (hMem->MemType == ur_mem_handle_t_::Type::Buffer) { switch (uniqueMemObj->Mem.BufferMem.MemAllocMode) { @@ -93,7 +93,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate( ur_mem_handle_t RetMemObj = nullptr; try { - ScopedContext Active(hContext); + ScopedContext Active(hContext->getDevice()); void *Ptr; auto pHost = pProperties ? pProperties->pHost : nullptr; ur_mem_handle_t_::MemImpl::BufferMem::AllocMode AllocMode = @@ -210,7 +210,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition( std::unique_ptr RetMemObj{nullptr}; try { - ScopedContext Active(Context); + ScopedContext Active(Context->getDevice()); RetMemObj = std::unique_ptr{new ur_mem_handle_t_{ Context, hBuffer, flags, AllocMode, Ptr, HostPtr, pRegion->size}}; @@ -239,7 +239,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory, UrReturnHelper ReturnValue(propSize, pMemInfo, pPropSizeRet); - ScopedContext Active(hMemory->getContext()); + ScopedContext Active(hMemory->getContext()->getDevice()); switch (MemInfoType) { case UR_MEM_INFO_SIZE: { @@ -417,7 +417,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemImageCreate( size_t ImageSizeBytes = PixelSizeBytes * pImageDesc->width * pImageDesc->height * pImageDesc->depth; - ScopedContext Active(hContext); + ScopedContext Active(hContext->getDevice()); hipArray *ImageArray; Result = UR_CHECK_ERROR(hipArray3DCreate( reinterpret_cast(&ImageArray), &ArrayDesc)); diff --git a/platform.cpp b/platform.cpp index 11f8fc55d4..f3244da988 100644 --- a/platform.cpp +++ b/platform.cpp @@ -82,8 +82,10 @@ urPlatformGet(uint32_t NumEntries, ur_platform_handle_t *phPlatforms, for (int i = 0; i < NumDevices; ++i) { hipDevice_t Device; Err = UR_CHECK_ERROR(hipDeviceGet(&Device, i)); + hipCtx_t Context; + Err = UR_CHECK_ERROR(hipDevicePrimaryCtxRetain(&Context, Device)); PlatformIds[i].Devices.emplace_back( - new ur_device_handle_t_{Device, &PlatformIds[i]}); + new ur_device_handle_t_{Device, Context, &PlatformIds[i]}); } } catch (const std::bad_alloc &) { // Signal out-of-memory situation diff --git a/program.cpp b/program.cpp index a66c444c4d..46ee394fec 100644 --- a/program.cpp +++ b/program.cpp @@ -103,7 +103,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuild(ur_context_handle_t, ur_result_t Result = UR_RESULT_SUCCESS; try { - ScopedContext Active(hProgram->getContext()); + ScopedContext Active(hProgram->getContext()->getDevice()); hProgram->buildProgram(pOptions); @@ -209,7 +209,7 @@ urProgramRelease(ur_program_handle_t hProgram) { ur_result_t Result = UR_RESULT_ERROR_INVALID_PROGRAM; try { - ScopedContext Active(hProgram->getContext()); + ScopedContext Active(hProgram->getContext()->getDevice()); auto HIPModule = hProgram->get(); Result = UR_CHECK_ERROR(hipModuleUnload(HIPModule)); } catch (...) { diff --git a/queue.cpp b/queue.cpp index 19447bcf8a..c047db922a 100644 --- a/queue.cpp +++ b/queue.cpp @@ -193,7 +193,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueRelease(ur_queue_handle_t hQueue) { try { std::unique_ptr QueueImpl(hQueue); - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getContext()->getDevice()); hQueue->forEachStream([](hipStream_t S) { UR_CHECK_ERROR(hipStreamSynchronize(S)); @@ -214,7 +214,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueFinish(ur_queue_handle_t hQueue) { try { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getContext()->getDevice()); hQueue->syncStreams([&Result](hipStream_t S) { Result = UR_CHECK_ERROR(hipStreamSynchronize(S)); @@ -245,7 +245,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueFlush(ur_queue_handle_t) { UR_APIEXPORT ur_result_t UR_APICALL urQueueGetNativeHandle(ur_queue_handle_t hQueue, ur_queue_native_desc_t *, ur_native_handle_t *phNativeQueue) { - ScopedContext Active(hQueue->getContext()); + ScopedContext Active(hQueue->getContext()->getDevice()); *phNativeQueue = reinterpret_cast(hQueue->getNextComputeStream()); return UR_RESULT_SUCCESS; diff --git a/usm.cpp b/usm.cpp index 73b9633906..03a4ff18d7 100644 --- a/usm.cpp +++ b/usm.cpp @@ -24,7 +24,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc( ur_result_t Result = UR_RESULT_SUCCESS; try { - ScopedContext Active(hContext); + ScopedContext Active(hContext->getDevice()); Result = UR_CHECK_ERROR(hipHostMalloc(ppMem, size)); } catch (ur_result_t Error) { Result = Error; @@ -49,7 +49,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc( ur_result_t Result = UR_RESULT_SUCCESS; try { - ScopedContext Active(hContext); + ScopedContext Active(hContext->getDevice()); Result = UR_CHECK_ERROR(hipMalloc(ppMem, size)); } catch (ur_result_t Error) { Result = Error; @@ -74,7 +74,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc( ur_result_t Result = UR_RESULT_SUCCESS; try { - ScopedContext Active(hContext); + ScopedContext Active(hContext->getDevice()); Result = UR_CHECK_ERROR(hipMallocManaged(ppMem, size, hipMemAttachGlobal)); } catch (ur_result_t Error) { Result = Error; @@ -93,7 +93,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext, void *pMem) { ur_result_t Result = UR_RESULT_SUCCESS; try { - ScopedContext Active(hContext); + ScopedContext Active(hContext->getDevice()); unsigned int Type; hipPointerAttribute_t hipPointerAttributeType; Result = @@ -123,7 +123,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem, UrReturnHelper ReturnValue(propValueSize, pPropValue, pPropValueSizeRet); try { - ScopedContext Active(hContext); + ScopedContext Active(hContext->getDevice()); switch (propName) { case UR_USM_ALLOC_INFO_TYPE: { unsigned int Value;