[SYCL][UR][HIP] Hip adapter multi device ctx #11105

Closed
8 commits
Changes from 5 commits
10 changes: 5 additions & 5 deletions sycl/plugins/unified_runtime/ur/adapters/hip/context.cpp
@@ -38,15 +38,13 @@ ur_context_handle_t_::getOwningURPool(umf_memory_pool_t *UMFPool) {
UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
uint32_t DeviceCount, const ur_device_handle_t *phDevices,
const ur_context_properties_t *, ur_context_handle_t *phContext) {
std::ignore = DeviceCount;
assert(DeviceCount == 1);
ur_result_t RetErr = UR_RESULT_SUCCESS;

std::unique_ptr<ur_context_handle_t_> ContextPtr{nullptr};
try {
// Create a scoped context.
ContextPtr = std::unique_ptr<ur_context_handle_t_>(
new ur_context_handle_t_{*phDevices});
new ur_context_handle_t_{phDevices, DeviceCount});

static std::once_flag InitFlag;
std::call_once(
@@ -78,7 +76,7 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,
case UR_CONTEXT_INFO_NUM_DEVICES:
return ReturnValue(1);
case UR_CONTEXT_INFO_DEVICES:
return ReturnValue(hContext->getDevice());
return ReturnValue(hContext->getDevices());
case UR_CONTEXT_INFO_REFERENCE_COUNT:
return ReturnValue(hContext->getReferenceCount());
case UR_CONTEXT_INFO_ATOMIC_MEMORY_ORDER_CAPABILITIES:
@@ -121,8 +119,10 @@ urContextRetain(ur_context_handle_t hContext) {

UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
ur_context_handle_t hContext, ur_native_handle_t *phNativeContext) {
// FIXME this only returns the native context of the first device in the
// SYCL context. This entry point should be deprecated.
*phNativeContext = reinterpret_cast<ur_native_handle_t>(
hContext->getDevice()->getNativeContext());
hContext->getDevices()[0]->getNativeContext());
return UR_RESULT_SUCCESS;
}
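
Editorial sketch: the `urContextGetInfo` hunk above still reports `UR_CONTEXT_INFO_NUM_DEVICES` as `1`, while the context now stores a device list. The following is a minimal, HIP-free illustration of how the count could be derived from that list, and of the retain-on-construct / release-on-destruct pairing the new constructor relies on. `MockDevice`, `MockContext`, and `getNumDevices` are hypothetical names used only for illustration; they are not part of the adapter, and this is not necessarily how the PR resolves it.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for ur_device_handle_t_; only the reference count
// matters for this sketch.
struct MockDevice {
  uint32_t RefCount = 1;
};

// Hypothetical stand-in for ur_context_handle_t_ holding several devices.
class MockContext {
  std::vector<MockDevice *> Devices;

public:
  MockContext(MockDevice *const *Devs, uint32_t NumDevices)
      : Devices(Devs, Devs + NumDevices) {
    for (MockDevice *Dev : Devices)
      ++Dev->RefCount; // plays the role of urDeviceRetain
  }

  ~MockContext() {
    for (MockDevice *Dev : Devices)
      --Dev->RefCount; // plays the role of urDeviceRelease
  }

  // Copying would release the devices twice, so it is disabled.
  MockContext(const MockContext &) = delete;
  MockContext &operator=(const MockContext &) = delete;

  // Returning a const reference avoids copying the vector on every query.
  const std::vector<MockDevice *> &getDevices() const noexcept {
    return Devices;
  }

  // The device count comes from the stored list rather than a constant.
  uint32_t getNumDevices() const noexcept {
    return static_cast<uint32_t>(Devices.size());
  }
};
```

With two devices, `getNumDevices()` yields 2 while the context is alive, and both reference counts return to their initial value once it is destroyed.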

106 changes: 41 additions & 65 deletions sycl/plugins/unified_runtime/ur/adapters/hip/context.hpp
@@ -31,22 +31,24 @@ typedef void (*ur_context_extended_deleter_t)(void *UserData);
/// with a given device and control access to said device from the user side.
/// UR API context are objects that are passed to functions, and not bound
/// to threads.
/// The ur_context_handle_t_ object doesn't implement this behavior. It only
/// holds the HIP context data. The RAII object \ref ScopedContext implements
/// the active context behavior.
///
/// <b> Primary vs UserDefined context </b>
/// Since the ur_context_handle_t can contain multiple devices, and a `hipCtx_t`
/// refers to only a single device, the `hipCtx_t` is more tightly coupled to a
/// ur_device_handle_t than a ur_context_handle_t. In order to remove some
/// ambiguities about the different semantics of `ur_context_handle_t`s and
/// native `hipCtx_t`, we access the native `hipCtx_t` solely through the
/// ur_device_handle_t class, by using the RAII object \ref ScopedDevice, which
/// sets the active device (by setting the active native `hipCtx_t`).
///
/// HIP has two different types of context, the Primary context,
/// which is usable by all threads on a given process for a given device, and
/// the aforementioned custom contexts.
/// The HIP documentation, and performance analysis, suggest using the Primary
/// context whenever possible. The Primary context is also used by the HIP
/// Runtime API. For UR applications to interop with HIP Runtime API, they have
/// to use the primary context - and make that active in the thread. The
/// `ur_context_handle_t_` object can be constructed with a `kind` parameter
/// that allows to construct a Primary or `UserDefined` context, so that
/// the UR object interface is always the same.
/// <b> Primary vs User-defined `hipCtx_t` </b>
///
/// HIP has two different types of `hipCtx_t`, the Primary context, which is
/// usable by all threads on a given process for a given device, and the
/// aforementioned custom `hipCtx_t`s.
/// The HIP documentation, confirmed with performance analysis, suggests using
/// the Primary context whenever possible. The Primary context is also used by
/// the HIP Runtime API. For UR applications to interop with HIP Runtime API,

Review comment (Contributor): Is this actually true for HIP as well or just copied over from CUDA? Is there a HIP runtime API and a HIP driver API?

Reply (Contributor Author): Hmm, yeah, I think this is not relevant to HIP. Will remove.

/// they have to use the primary context - and make that active in the thread.
///
/// <b> Destructor callback </b>
///
@@ -56,6 +58,15 @@ typedef void (*ur_context_extended_deleter_t)(void *UserData);
/// See proposal for details.
/// https://github.com/codeplaysoftware/standards-proposals/blob/master/extended-context-destruction/index.md
///
/// <b> Memory Management for Devices in a Context </b>
///
/// A ur_buffer_ is associated with a ur_context_handle_t_, which may refer to
/// multiple devices. Therefore the ur_buffer_ must handle a native allocation
/// for each device in the context. UR is responsible for automatically
/// handling event dependencies for kernels writing to or reading from the
/// same ur_buffer_ and migrating memory between native allocations for
/// devices in the same ur_context_handle_t_ if necessary.
///
struct ur_context_handle_t_ {

struct deleter_data {
@@ -67,15 +78,23 @@ struct ur_context_handle_t_ {

using native_type = hipCtx_t;

ur_device_handle_t DeviceId;
std::vector<ur_device_handle_t> Devices;
uint32_t NumDevices;

std::atomic_uint32_t RefCount;

ur_context_handle_t_(ur_device_handle_t DevId)
: DeviceId{DevId}, RefCount{1} {
urDeviceRetain(DeviceId);
ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
: Devices{Devs, Devs + NumDevices}, NumDevices{NumDevices}, RefCount{1} {
for (auto &Dev : Devices) {
urDeviceRetain(Dev);
}
};

~ur_context_handle_t_() { urDeviceRelease(DeviceId); }
~ur_context_handle_t_() {
for (auto &Dev : Devices) {
urDeviceRelease(Dev);
}
}

void invokeExtendedDeleters() {
std::lock_guard<std::mutex> Guard(Mutex);
@@ -90,7 +109,9 @@ struct ur_context_handle_t_ {
ExtendedDeleters.emplace_back(deleter_data{Function, UserData});
}

ur_device_handle_t getDevice() const noexcept { return DeviceId; }
std::vector<ur_device_handle_t> getDevices() const noexcept {
return Devices;
}

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }

@@ -161,48 +182,3 @@ struct ur_context_handle_t_ {
std::unordered_map<const void *, size_t> USMMappings;
std::set<ur_usm_pool_handle_t> PoolHandles;
};

namespace {
/// RAII type to guarantee recovering original HIP context
/// Scoped context is used across all UR HIP plugin implementation
/// to activate the UR Context on the current thread, matching the
/// HIP driver semantics where the context used for the HIP Driver
/// API is the one active on the thread.
/// The implementation tries to avoid replacing the hipCtx_t if it cans
class ScopedContext {
hipCtx_t Original;
bool NeedToRecover;

public:
ScopedContext(ur_device_handle_t hDevice) : NeedToRecover{false} {

if (!hDevice) {
throw UR_RESULT_ERROR_INVALID_DEVICE;
}

// FIXME when multi device context are supported in HIP adapter
hipCtx_t Desired = hDevice->getNativeContext();
UR_CHECK_ERROR(hipCtxGetCurrent(&Original));
if (Original != Desired) {
// Sets the desired context as the active one for the thread
UR_CHECK_ERROR(hipCtxSetCurrent(Desired));
if (Original == nullptr) {
// No context is installed on the current thread
// This is the most common case. We can activate the context in the
// thread and leave it there until all the UR context referring to the
// same underlying HIP context are destroyed. This emulates
// the behaviour of the HIP runtime api, and avoids costly context
// switches. No action is required on this side of the if.
} else {
NeedToRecover = true;
}
}
}

~ScopedContext() {
if (NeedToRecover) {
UR_CHECK_ERROR(hipCtxSetCurrent(Original));
}
}
};
} // namespace
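
Editorial sketch: the "Memory Management for Devices in a Context" comment in `context.hpp` says a buffer must hold one native allocation per device and that data is migrated between them when necessary. The following is a simplified, host-only illustration of that bookkeeping. `MockBuffer`, `acquire`, and `migrations` are hypothetical names; host vectors stand in for native HIP allocations, the device index is assumed to be the position in the context's device list, and event dependencies are left out entirely.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical buffer that keeps one allocation per device in a context.
class MockBuffer {
  std::size_t Size;
  std::vector<std::vector<char>> Allocs; // one "native" allocation per device
  std::vector<bool> Valid;               // Valid[i]: device i has latest data
  unsigned Migrations = 0;

public:
  MockBuffer(std::size_t NumDevices, std::size_t Size)
      : Size(Size), Allocs(NumDevices), Valid(NumDevices, false) {}

  // Returns the allocation for DeviceIndex, allocating it lazily and copying
  // in the latest data from another device if this copy is stale.
  std::vector<char> &acquire(std::size_t DeviceIndex, bool WillWrite) {
    std::vector<char> &Alloc = Allocs.at(DeviceIndex);
    if (Alloc.empty())
      Alloc.resize(Size); // lazy per-device allocation

    if (!Valid[DeviceIndex]) {
      for (std::size_t Src = 0; Src < Allocs.size(); ++Src) {
        if (Valid[Src]) {
          Alloc = Allocs[Src]; // stands in for a device-to-device copy
          ++Migrations;
          break;
        }
      }
      Valid[DeviceIndex] = true;
    }

    if (WillWrite) {
      // A write makes every other device's copy stale.
      Valid.assign(Valid.size(), false);
      Valid[DeviceIndex] = true;
    }
    return Alloc;
  }

  unsigned migrations() const noexcept { return Migrations; }
};
```

Tracking validity per device, rather than only remembering the last writer, means repeated reads on the same device do not trigger repeated copies.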
2 changes: 1 addition & 1 deletion sycl/plugins/unified_runtime/ur/adapters/hip/device.cpp
@@ -929,7 +929,7 @@ ur_result_t UR_APICALL urDeviceGetGlobalTimestamps(ur_device_handle_t hDevice,
return UR_RESULT_SUCCESS;

ur_event_handle_t_::native_type Event;
ScopedContext Active(hDevice);
ScopedDevice Active(hDevice);

if (pDeviceTimestamp) {
UR_CHECK_ERROR(hipEventCreateWithFlags(&Event, hipEventDefault));
58 changes: 56 additions & 2 deletions sycl/plugins/unified_runtime/ur/adapters/hip/device.hpp
@@ -23,12 +23,13 @@ struct ur_device_handle_t_ {
std::atomic_uint32_t RefCount;
ur_platform_handle_t Platform;
hipCtx_t HIPContext;
size_t DeviceIndex; // The index of the device in the platform

public:
ur_device_handle_t_(native_type HipDevice, hipCtx_t Context,
ur_platform_handle_t Platform)
ur_platform_handle_t Platform, size_t DeviceIndex)
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform),
HIPContext(Context) {}
HIPContext(Context), DeviceIndex(DeviceIndex) {}

~ur_device_handle_t_() {
UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDevice));
@@ -41,6 +42,59 @@ struct ur_device_handle_t_ {
ur_platform_handle_t getPlatform() const noexcept { return Platform; };

hipCtx_t getNativeContext() { return HIPContext; };

// Returns the index of the device in question relative to the other devices
// in the platform
size_t getIndex() { return DeviceIndex; }
};

int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute);

namespace {
/// RAII type to guarantee recovering original HIP device. In UR the
/// `ScopedDevice` sets the active device by using the native underlying
/// `hipCtx_t`. Since a UR context can contain multiple devices, whereas a
/// `hipCtx_t` refers to a single device, it is semantically clearer to access
/// the `hipCtx_t` through the UR device rather than the UR context.
/// `ScopedDevice` is used throughout the UR HIP plugin implementation
/// to activate the UR device on the current thread, matching the
/// HIP driver semantics where the context used for the HIP Driver
/// API is the one active on the thread.
/// The implementation avoids replacing the `hipCtx_t` when it can.
class ScopedDevice {
hipCtx_t Original;
bool NeedToRecover;

public:
ScopedDevice(ur_device_handle_t hDevice) : NeedToRecover{false} {

if (!hDevice) {
throw UR_RESULT_ERROR_INVALID_DEVICE;
}

// FIXME when multi device context are supported in HIP adapter
hipCtx_t Desired = hDevice->getNativeContext();
UR_CHECK_ERROR(hipCtxGetCurrent(&Original));
if (Original != Desired) {
// Sets the desired context as the active one for the thread
UR_CHECK_ERROR(hipCtxSetCurrent(Desired));
if (Original == nullptr) {
// No context is installed on the current thread
// This is the most common case. We can activate the context in the
// thread and leave it there until all the UR context referring to the
// same underlying HIP context are destroyed. This emulates
// the behaviour of the HIP runtime api, and avoids costly context
// switches. No action is required on this side of the if.
} else {
NeedToRecover = true;
}
}
}

~ScopedDevice() {
if (NeedToRecover) {
UR_CHECK_ERROR(hipCtxSetCurrent(Original));
}
}
};
} // namespace
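
Editorial sketch: the save-and-restore behaviour of `ScopedDevice` can be shown without HIP by replacing the thread's active `hipCtx_t` with a thread-local variable. `MockCtx`, `CurrentCtx`, and `MockScopedDevice` are hypothetical stand-ins; the real class calls `hipCtxGetCurrent` and `hipCtxSetCurrent` and checks their results with `UR_CHECK_ERROR`, which this sketch omits.

```cpp
#include <cassert>

// Hypothetical stand-in for hipCtx_t.
using MockCtx = int *;

// Hypothetical stand-in for the HIP context active on the current thread.
thread_local MockCtx CurrentCtx = nullptr;

// Mirrors the control flow of ScopedDevice: install the desired context and
// restore the previous one only if some other context was active before.
class MockScopedDevice {
  MockCtx Original;
  bool NeedToRecover = false;

public:
  explicit MockScopedDevice(MockCtx Desired) : Original(CurrentCtx) {
    if (Original != Desired) {
      CurrentCtx = Desired;
      // If nothing was active, the new context is simply left installed,
      // which avoids a switch back when the scope ends.
      NeedToRecover = (Original != nullptr);
    }
  }

  ~MockScopedDevice() {
    if (NeedToRecover)
      CurrentCtx = Original;
  }

  MockScopedDevice(const MockScopedDevice &) = delete;
  MockScopedDevice &operator=(const MockScopedDevice &) = delete;
};
```

Nesting two guards shows both branches: the inner guard restores the outer context on exit, while the outermost guard, which started from no active context, leaves its context installed on the thread.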