Skip to content

Commit

Permalink
[HIP][UR] Use primary context in HIP adapter (#10514)
Browse files Browse the repository at this point in the history
The primary context has been the default in the CUDA PI/Adapter for a while. See
intel/llvm#8197.

This PR brings the HIP adapter up to speed.

It also changes the scoped context to take only a `ur_device_handle_t`,
since each device is coupled with its native primary context in HIP.
  • Loading branch information
hdelan authored and szadam committed Oct 13, 2023
1 parent 4aca651 commit 94e2324
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 112 deletions.
57 changes: 7 additions & 50 deletions context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(

std::unique_ptr<ur_context_handle_t_> ContextPtr{nullptr};
try {
hipCtx_t Current = nullptr;

// Create a scoped context.
hipCtx_t NewContext;
UR_CHECK_ERROR(hipCtxGetCurrent(&Current));
RetErr = UR_CHECK_ERROR(
hipCtxCreate(&NewContext, hipDeviceMapHost, phDevices[0]->get()));
ContextPtr = std::unique_ptr<ur_context_handle_t_>(new ur_context_handle_t_{
ur_context_handle_t_::kind::UserDefined, NewContext, *phDevices});
ContextPtr = std::unique_ptr<ur_context_handle_t_>(
new ur_context_handle_t_{*phDevices});

static std::once_flag InitFlag;
std::call_once(
Expand All @@ -43,14 +37,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
},
RetErr);

// For non-primary scoped contexts keep the last active on top of the stack
// as `hipCtxCreate` replaces it implicitly otherwise.
// Primary contexts are kept on top of the stack, so the previous context
// is not queried and therefore not recovered.
if (Current != nullptr) {
UR_CHECK_ERROR(hipCtxSetCurrent(Current));
}

*phContext = ContextPtr.release();
} catch (ur_result_t Err) {
RetErr = Err;
Expand Down Expand Up @@ -97,40 +83,10 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,

UR_APIEXPORT ur_result_t UR_APICALL
urContextRelease(ur_context_handle_t hContext) {
if (hContext->decrementReferenceCount() > 0) {
return UR_RESULT_SUCCESS;
}
hContext->invokeExtendedDeleters();

std::unique_ptr<ur_context_handle_t_> context{hContext};

if (!hContext->isPrimary()) {
hipCtx_t HIPCtxt = hContext->get();
// hipCtxSynchronize is not supported for AMD platform so we can just
// destroy the context, for NVIDIA make sure it's synchronized.
#if defined(__HIP_PLATFORM_NVIDIA__)
hipCtx_t Current = nullptr;
UR_CHECK_ERROR(hipCtxGetCurrent(&Current));
if (HIPCtxt != Current) {
UR_CHECK_ERROR(hipCtxPushCurrent(HIPCtxt));
}
UR_CHECK_ERROR(hipCtxSynchronize());
UR_CHECK_ERROR(hipCtxGetCurrent(&Current));
if (HIPCtxt == Current) {
UR_CHECK_ERROR(hipCtxPopCurrent(&Current));
}
#endif
return UR_CHECK_ERROR(hipCtxDestroy(HIPCtxt));
} else {
// Primary context is not destroyed, but released
hipDevice_t HIPDev = hContext->getDevice()->get();
hipCtx_t Current;
UR_CHECK_ERROR(hipCtxPopCurrent(&Current));
return UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDev));
if (hContext->decrementReferenceCount() == 0) {
delete hContext;
}

hipCtx_t HIPCtxt = hContext->get();
return UR_CHECK_ERROR(hipCtxDestroy(HIPCtxt));
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
Expand All @@ -143,7 +99,8 @@ urContextRetain(ur_context_handle_t hContext) {

UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
ur_context_handle_t hContext, ur_native_handle_t *phNativeContext) {
*phNativeContext = reinterpret_cast<ur_native_handle_t>(hContext->get());
*phNativeContext = reinterpret_cast<ur_native_handle_t>(
hContext->getDevice()->getNativeContext());
return UR_RESULT_SUCCESS;
}

Expand Down
22 changes: 7 additions & 15 deletions context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,11 @@ struct ur_context_handle_t_ {

using native_type = hipCtx_t;

enum class kind { Primary, UserDefined } Kind;
native_type HIPContext;
ur_device_handle_t DeviceId;
std::atomic_uint32_t RefCount;

ur_context_handle_t_(kind K, hipCtx_t Ctxt, ur_device_handle_t DevId)
: Kind{K}, HIPContext{Ctxt}, DeviceId{DevId}, RefCount{1} {
DeviceId->setContext(this);
ur_context_handle_t_(ur_device_handle_t DevId)
: DeviceId{DevId}, RefCount{1} {
urDeviceRetain(DeviceId);
};

Expand All @@ -90,10 +87,6 @@ struct ur_context_handle_t_ {

ur_device_handle_t getDevice() const noexcept { return DeviceId; }

native_type get() const noexcept { return HIPContext; }

bool isPrimary() const noexcept { return Kind == kind::Primary; }

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }

uint32_t decrementReferenceCount() noexcept { return --RefCount; }
Expand All @@ -113,19 +106,18 @@ namespace {
/// API is the one active on the thread.
/// The implementation tries to avoid replacing the hipCtx_t if it can
class ScopedContext {
ur_context_handle_t PlacedContext;
hipCtx_t Original;
bool NeedToRecover;

public:
ScopedContext(ur_context_handle_t Ctxt)
: PlacedContext{Ctxt}, NeedToRecover{false} {
ScopedContext(ur_device_handle_t hDevice) : NeedToRecover{false} {

if (!PlacedContext) {
throw UR_RESULT_ERROR_INVALID_CONTEXT;
if (!hDevice) {
throw UR_RESULT_ERROR_INVALID_DEVICE;
}

hipCtx_t Desired = PlacedContext->get();
// FIXME: revisit once multi-device contexts are supported in the HIP adapter
hipCtx_t Desired = hDevice->getNativeContext();
UR_CHECK_ERROR(hipCtxGetCurrent(&Original));
if (Original != Desired) {
// Sets the desired context as the active one for the thread
Expand Down
2 changes: 1 addition & 1 deletion device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,7 @@ ur_result_t UR_APICALL urDeviceGetGlobalTimestamps(ur_device_handle_t hDevice,
return UR_RESULT_SUCCESS;

ur_event_handle_t_::native_type Event;
ScopedContext Active(hDevice->getContext());
ScopedContext Active(hDevice);

if (pDeviceTimestamp) {
UR_CHECK_ERROR(hipEventCreateWithFlags(&Event, hipEventDefault));
Expand Down
16 changes: 10 additions & 6 deletions device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,25 @@ struct ur_device_handle_t_ {
native_type HIPDevice;
std::atomic_uint32_t RefCount;
ur_platform_handle_t Platform;
ur_context_handle_t Context;
hipCtx_t HIPContext;

public:
ur_device_handle_t_(native_type HipDevice, ur_platform_handle_t Platform)
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform) {}
ur_device_handle_t_(native_type HipDevice, hipCtx_t Context,
ur_platform_handle_t Platform)
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform),
HIPContext(Context) {}

~ur_device_handle_t_() {
UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDevice));
}

native_type get() const noexcept { return HIPDevice; };

uint32_t getReferenceCount() const noexcept { return RefCount; }

ur_platform_handle_t getPlatform() const noexcept { return Platform; };

void setContext(ur_context_handle_t Ctxt) { Context = Ctxt; };

ur_context_handle_t getContext() { return Context; };
hipCtx_t getNativeContext() { return HIPContext; };
};

int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute);
50 changes: 29 additions & 21 deletions enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t CommandQueue,
return UR_RESULT_SUCCESS;
}
try {
ScopedContext Active(CommandQueue->getContext());
ScopedContext Active(CommandQueue->getDevice());

auto Result = forLatestEvents(
EventWaitList, NumEventsInWaitList,
Expand Down Expand Up @@ -97,7 +97,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -143,7 +143,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -252,7 +252,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());

uint32_t StreamToken;
ur_stream_quard Guard;
Expand Down Expand Up @@ -363,7 +363,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier(
ur_result_t Result;

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
uint32_t StreamToken;
ur_stream_quard Guard;
hipStream_t HIPStream = hQueue->getNextComputeStream(
Expand Down Expand Up @@ -513,7 +513,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();

Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
Expand Down Expand Up @@ -561,7 +561,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -609,7 +609,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
ur_result_t Result;
auto Stream = hQueue->getNextTransferStream();

Expand Down Expand Up @@ -656,7 +656,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -751,7 +751,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());

auto Stream = hQueue->getNextTransferStream();
ur_result_t Result;
Expand Down Expand Up @@ -892,7 +892,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();

if (phEventWaitList) {
Expand Down Expand Up @@ -954,7 +954,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();

if (phEventWaitList) {
Expand Down Expand Up @@ -1020,7 +1020,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
if (phEventWaitList) {
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
Expand Down Expand Up @@ -1116,7 +1116,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
HostPtr, numEventsInWaitList,
phEventWaitList, phEvent);
} else {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());

if (IsPinned) {
Result = urEnqueueEventsWait(hQueue, numEventsInWaitList, phEventWaitList,
Expand Down Expand Up @@ -1167,7 +1167,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
hMem->Mem.BufferMem.getMapSize(), pMappedPtr, numEventsInWaitList,
phEventWaitList, phEvent);
} else {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());

if (IsPinned) {
Result = urEnqueueEventsWait(hQueue, numEventsInWaitList, phEventWaitList,
Expand Down Expand Up @@ -1198,7 +1198,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
uint32_t StreamToken;
ur_stream_quard Guard;
hipStream_t HIPStream = hQueue->getNextComputeStream(
Expand Down Expand Up @@ -1256,7 +1256,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down Expand Up @@ -1287,6 +1287,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
ur_queue_handle_t hQueue, const void *pMem, size_t size,
ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
#if HIP_VERSION_MAJOR >= 5
void *HIPDevicePtr = const_cast<void *>(pMem);
unsigned int PointerRangeSize = 0;
UR_CHECK_ERROR(hipPointerGetAttribute(&PointerRangeSize,
Expand All @@ -1301,7 +1302,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand All @@ -1311,8 +1312,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
UR_COMMAND_USM_PREFETCH, hQueue, HIPStream));
EventPtr->start();
}
Result = UR_CHECK_ERROR(hipMemPrefetchAsync(
pMem, size, hQueue->getContext()->getDevice()->get(), HIPStream));
Result = UR_CHECK_ERROR(
hipMemPrefetchAsync(pMem, size, hQueue->getDevice()->get(), HIPStream));
if (phEvent) {
Result = EventPtr->record();
*phEvent = EventPtr.release();
Expand All @@ -1322,11 +1323,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
}

return Result;
#else
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
#endif
}

UR_APIEXPORT ur_result_t UR_APICALL
urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
ur_usm_advice_flags_t, ur_event_handle_t *phEvent) {
#if HIP_VERSION_MAJOR >= 5
void *HIPDevicePtr = const_cast<void *>(pMem);
unsigned int PointerRangeSize = 0;
UR_CHECK_ERROR(hipPointerGetAttribute(&PointerRangeSize,
Expand All @@ -1337,6 +1342,9 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
// TODO implement a mapping to hipMemAdvise once the expected behaviour
// of urEnqueueUSMAdvise is detailed in the USM extension
return urEnqueueEventsWait(hQueue, 0, nullptr, phEvent);
#else
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
#endif
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D(
Expand Down Expand Up @@ -1367,7 +1375,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
ur_result_t Result = UR_RESULT_SUCCESS;

try {
ScopedContext Active(hQueue->getContext());
ScopedContext Active(hQueue->getDevice());
hipStream_t HIPStream = hQueue->getNextTransferStream();
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
phEventWaitList);
Expand Down
Loading

0 comments on commit 94e2324

Please sign in to comment.