Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HIP][UR] Use primary context in HIP adapter #10514

Merged
merged 4 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 7 additions & 50 deletions sycl/plugins/unified_runtime/ur/adapters/hip/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(

std::unique_ptr<ur_context_handle_t_> ContextPtr{nullptr};
try {
hipCtx_t Current = nullptr;

// Create a scoped context.
hipCtx_t NewContext;
UR_CHECK_ERROR(hipCtxGetCurrent(&Current));
RetErr = UR_CHECK_ERROR(
hipCtxCreate(&NewContext, hipDeviceMapHost, phDevices[0]->get()));
ContextPtr = std::unique_ptr<ur_context_handle_t_>(new ur_context_handle_t_{
ur_context_handle_t_::kind::UserDefined, NewContext, *phDevices});
ContextPtr = std::unique_ptr<ur_context_handle_t_>(
new ur_context_handle_t_{*phDevices});

static std::once_flag InitFlag;
std::call_once(
Expand All @@ -43,14 +37,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
},
RetErr);

// For non-primary scoped contexts keep the last active on top of the stack
// as `hipCtxCreate` replaces it implicitly otherwise.
// Primary contexts are kept on top of the stack, so the previous context
// is not queried and therefore not recovered.
if (Current != nullptr) {
UR_CHECK_ERROR(hipCtxSetCurrent(Current));
}

*phContext = ContextPtr.release();
} catch (ur_result_t Err) {
RetErr = Err;
Expand Down Expand Up @@ -97,40 +83,10 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,

UR_APIEXPORT ur_result_t UR_APICALL
urContextRelease(ur_context_handle_t hContext) {
if (hContext->decrementReferenceCount() > 0) {
return UR_RESULT_SUCCESS;
}
hContext->invokeExtendedDeleters();

std::unique_ptr<ur_context_handle_t_> context{hContext};

if (!hContext->isPrimary()) {
hipCtx_t HIPCtxt = hContext->get();
// hipCtxSynchronize is not supported for AMD platform so we can just
// destroy the context, for NVIDIA make sure it's synchronized.
#if defined(__HIP_PLATFORM_NVIDIA__)
hipCtx_t Current = nullptr;
UR_CHECK_ERROR(hipCtxGetCurrent(&Current));
if (HIPCtxt != Current) {
UR_CHECK_ERROR(hipCtxPushCurrent(HIPCtxt));
}
UR_CHECK_ERROR(hipCtxSynchronize());
UR_CHECK_ERROR(hipCtxGetCurrent(&Current));
if (HIPCtxt == Current) {
UR_CHECK_ERROR(hipCtxPopCurrent(&Current));
}
#endif
return UR_CHECK_ERROR(hipCtxDestroy(HIPCtxt));
} else {
// Primary context is not destroyed, but released
hipDevice_t HIPDev = hContext->getDevice()->get();
hipCtx_t Current;
UR_CHECK_ERROR(hipCtxPopCurrent(&Current));
return UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDev));
if (hContext->decrementReferenceCount() == 0) {
delete hContext;
}

hipCtx_t HIPCtxt = hContext->get();
return UR_CHECK_ERROR(hipCtxDestroy(HIPCtxt));
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
Expand All @@ -143,7 +99,8 @@ urContextRetain(ur_context_handle_t hContext) {

UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
ur_context_handle_t hContext, ur_native_handle_t *phNativeContext) {
*phNativeContext = reinterpret_cast<ur_native_handle_t>(hContext->get());
*phNativeContext = reinterpret_cast<ur_native_handle_t>(
hContext->getDevice()->getNativeContext());
return UR_RESULT_SUCCESS;
}

Expand Down
22 changes: 7 additions & 15 deletions sycl/plugins/unified_runtime/ur/adapters/hip/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,11 @@ struct ur_context_handle_t_ {

using native_type = hipCtx_t;

enum class kind { Primary, UserDefined } Kind;
native_type HIPContext;
ur_device_handle_t DeviceId;
std::atomic_uint32_t RefCount;

ur_context_handle_t_(kind K, hipCtx_t Ctxt, ur_device_handle_t DevId)
: Kind{K}, HIPContext{Ctxt}, DeviceId{DevId}, RefCount{1} {
DeviceId->setContext(this);
ur_context_handle_t_(ur_device_handle_t DevId)
: DeviceId{DevId}, RefCount{1} {
urDeviceRetain(DeviceId);
};

Expand All @@ -90,10 +87,6 @@ struct ur_context_handle_t_ {

ur_device_handle_t getDevice() const noexcept { return DeviceId; }

native_type get() const noexcept { return HIPContext; }

bool isPrimary() const noexcept { return Kind == kind::Primary; }

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }

uint32_t decrementReferenceCount() noexcept { return --RefCount; }
Expand All @@ -113,19 +106,18 @@ namespace {
/// API is the one active on the thread.
/// The implementation tries to avoid replacing the hipCtx_t if it can.
class ScopedContext {
ur_context_handle_t PlacedContext;
hipCtx_t Original;
bool NeedToRecover;

public:
ScopedContext(ur_context_handle_t Ctxt)
: PlacedContext{Ctxt}, NeedToRecover{false} {
ScopedContext(ur_device_handle_t hDevice) : NeedToRecover{false} {

if (!PlacedContext) {
throw UR_RESULT_ERROR_INVALID_CONTEXT;
if (!hDevice) {
throw UR_RESULT_ERROR_INVALID_DEVICE;
}

hipCtx_t Desired = PlacedContext->get();
// FIXME: revisit when multi-device contexts are supported in the HIP adapter
hipCtx_t Desired = hDevice->getNativeContext();
UR_CHECK_ERROR(hipCtxGetCurrent(&Original));
if (Original != Desired) {
// Sets the desired context as the active one for the thread
Expand Down
2 changes: 1 addition & 1 deletion sycl/plugins/unified_runtime/ur/adapters/hip/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,7 @@ ur_result_t UR_APICALL urDeviceGetGlobalTimestamps(ur_device_handle_t hDevice,
return UR_RESULT_SUCCESS;

ur_event_handle_t_::native_type Event;
ScopedContext Active(hDevice->getContext());
ScopedContext Active(hDevice);

if (pDeviceTimestamp) {
UR_CHECK_ERROR(hipEventCreateWithFlags(&Event, hipEventDefault));
Expand Down
16 changes: 10 additions & 6 deletions sycl/plugins/unified_runtime/ur/adapters/hip/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,25 @@ struct ur_device_handle_t_ {
native_type HIPDevice;
std::atomic_uint32_t RefCount;
ur_platform_handle_t Platform;
ur_context_handle_t Context;
hipCtx_t HIPContext;

public:
ur_device_handle_t_(native_type HipDevice, ur_platform_handle_t Platform)
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform) {}
ur_device_handle_t_(native_type HipDevice, hipCtx_t Context,
ur_platform_handle_t Platform)
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform),
HIPContext(Context) {}

~ur_device_handle_t_() {
UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDevice));
}

native_type get() const noexcept { return HIPDevice; };

uint32_t getReferenceCount() const noexcept { return RefCount; }

ur_platform_handle_t getPlatform() const noexcept { return Platform; };

void setContext(ur_context_handle_t Ctxt) { Context = Ctxt; };

ur_context_handle_t getContext() { return Context; };
hipCtx_t getNativeContext() { return HIPContext; };
};

int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute);
50 changes: 29 additions & 21 deletions sycl/plugins/unified_runtime/ur/adapters/hip/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t CommandQueue,
return UR_RESULT_SUCCESS;
}
try {
ScopedContext Active(CommandQueue->getContext());
ScopedContext Active(CommandQueue->getDevice());

auto Result = forLatestEvents(
EventWaitList, NumEventsInWaitList,
Expand Down Expand Up @@ -97,7 +97,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -143,7 +143,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -252,7 +252,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());

uint32_t StreamToken;
ur_stream_quard Guard;
Expand Down Expand Up @@ -363,7 +363,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier(
ur_result_t Result;

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
uint32_t StreamToken;
ur_stream_quard Guard;
hipStream_t HIPStream = hQueue->getNextComputeStream(
Expand Down Expand Up @@ -513,7 +513,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();

Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
Expand Down Expand Up @@ -561,7 +561,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -609,7 +609,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
ur_result_t Result;
auto Stream = hQueue->getNextTransferStream();

Expand Down Expand Up @@ -656,7 +656,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -751,7 +751,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());

auto Stream = hQueue->getNextTransferStream();
ur_result_t Result;
Expand Down Expand Up @@ -892,7 +892,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();

if (phEventWaitList) {
Expand Down Expand Up @@ -954,7 +954,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();

if (phEventWaitList) {
Expand Down Expand Up @@ -1020,7 +1020,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
if (phEventWaitList) {
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
Expand Down Expand Up @@ -1116,7 +1116,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
HostPtr, numEventsInWaitList,
phEventWaitList, phEvent);
} else {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());

if (IsPinned) {
Result = urEnqueueEventsWait(hQueue, numEventsInWaitList, phEventWaitList,
Expand Down Expand Up @@ -1167,7 +1167,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
hMem->Mem.BufferMem.getMapSize(), pMappedPtr, numEventsInWaitList,
phEventWaitList, phEvent);
} else {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());

if (IsPinned) {
Result = urEnqueueEventsWait(hQueue, numEventsInWaitList, phEventWaitList,
Expand Down Expand Up @@ -1198,7 +1198,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
uint32_t StreamToken;
ur_stream_quard Guard;
hipStream_t HIPStream = hQueue->getNextComputeStream(
Expand Down Expand Up @@ -1256,7 +1256,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -1287,6 +1287,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
ur_queue_handle_t hQueue, const void *pMem, size_t size,
ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
#if HIP_VERSION_MAJOR >= 5
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also adding this version check (`#if HIP_VERSION_MAJOR >= 5`), since `hipPointerGetAttribute` is not supported in HIP versions earlier than 5.

void *HIPDevicePtr = const_cast<void *>(pMem);
unsigned int PointerRangeSize = 0;
UR_CHECK_ERROR(hipPointerGetAttribute(&PointerRangeSize,
Expand All @@ -1301,7 +1302,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand All @@ -1311,8 +1312,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
UR_COMMAND_USM_PREFETCH, hQueue, HIPStream));
EventPtr->start();
}
Result = UR_CHECK_ERROR(hipMemPrefetchAsync(
pMem, size, hQueue->getContext()->getDevice()->get(), HIPStream));
Result = UR_CHECK_ERROR(
hipMemPrefetchAsync(pMem, size, hQueue->getDevice()->get(), HIPStream));
if (phEvent) {
Result = EventPtr->record();
*phEvent = EventPtr.release();
Expand All @@ -1322,11 +1323,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
}

return Result;
#else
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
#endif
}

UR_APIEXPORT ur_result_t UR_APICALL
urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
ur_usm_advice_flags_t, ur_event_handle_t *phEvent) {
#if HIP_VERSION_MAJOR >= 5
void *HIPDevicePtr = const_cast<void *>(pMem);
unsigned int PointerRangeSize = 0;
UR_CHECK_ERROR(hipPointerGetAttribute(&PointerRangeSize,
Expand All @@ -1337,6 +1342,9 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
// TODO implement a mapping to hipMemAdvise once the expected behaviour
// of urEnqueueUSMAdvise is detailed in the USM extension
return urEnqueueEventsWait(hQueue, 0, nullptr, phEvent);
#else
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
#endif
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D(
Expand Down Expand Up @@ -1367,7 +1375,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down
Loading