Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,7 @@ static const char *getUrDeviceTarget(const char *URDeviceTarget) {
}

static bool compatibleWithDevice(RTDeviceBinaryImage *BinImage,
const device &Dev) {
detail::device_impl &DeviceImpl = *detail::getSyclObjImpl(Dev);
const device_impl &DeviceImpl) {
auto &Adapter = DeviceImpl.getAdapter();

const ur_device_handle_t &URDeviceHandle = DeviceImpl.getHandleRef();
Expand Down Expand Up @@ -621,7 +620,7 @@ bool ProgramManager::isSpecialDeviceImage(RTDeviceBinaryImage *BinImage) {
}

bool ProgramManager::isSpecialDeviceImageShouldBeUsed(
RTDeviceBinaryImage *BinImage, const device &Dev) {
RTDeviceBinaryImage *BinImage, const device_impl &DeviceImpl) {
// Decide whether a devicelib image should be used.
int Bfloat16DeviceLibVersion = -1;
if (m_Bfloat16DeviceLibImages[0].get() == BinImage)
Expand All @@ -640,7 +639,6 @@ bool ProgramManager::isSpecialDeviceImageShouldBeUsed(
// more devicelib images in this way.
enum { DEVICELIB_FALLBACK = 0, DEVICELIB_NATIVE };
ur_bool_t NativeBF16Supported = false;
detail::device_impl &DeviceImpl = *detail::getSyclObjImpl(Dev);
ur_result_t CallSuccessful =
DeviceImpl.getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
DeviceImpl.getHandleRef(),
Expand All @@ -658,15 +656,15 @@ bool ProgramManager::isSpecialDeviceImageShouldBeUsed(
return false;
}

static bool checkLinkingSupport(const device &Dev,
static bool checkLinkingSupport(const device_impl &DeviceImpl,
const RTDeviceBinaryImage &Img) {
const char *Target = Img.getRawData().DeviceTargetSpec;
// TODO replace with extension checks once implemented in UR.
if (strcmp(Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64) == 0) {
return true;
}
if (strcmp(Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_GEN) == 0) {
return Dev.is_gpu() && Dev.get_backend() == backend::opencl;
return DeviceImpl.is_gpu() && DeviceImpl.getBackend() == backend::opencl;
}
return false;
}
Expand Down Expand Up @@ -701,7 +699,8 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
HandledSymbols.insert(ISProp->Name);
}
ur::DeviceBinaryType Format = MainImg.getFormat();
if (!WorkList.empty() && !checkLinkingSupport(Dev, MainImg))
if (!WorkList.empty() &&
!checkLinkingSupport(*getSyclObjImpl(Dev).get(), MainImg))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why aren't we changing line 691 instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because I didn't want to refactor the doesDevSupportDeviceRequirements function, which internally calls checkDevSupportDeviceRequirements, which in turn uses the template <typename Param> device::get_info<Param>.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'll look into it separately.

throw exception(make_error_code(errc::feature_not_supported),
"Cannot resolve external symbols, linking is unsupported "
"for the backend");
Expand All @@ -715,10 +714,10 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
RTDeviceBinaryImage *Img = It->second;
if (Img->getFormat() != Format ||
!doesDevSupportDeviceRequirements(Dev, *Img) ||
!compatibleWithDevice(Img, Dev))
!compatibleWithDevice(Img, *getSyclObjImpl(Dev).get()))
continue;
if (isSpecialDeviceImage(Img) &&
!isSpecialDeviceImageShouldBeUsed(Img, Dev))
!isSpecialDeviceImageShouldBeUsed(Img, *getSyclObjImpl(Dev).get()))
continue;
DeviceImagesToLink.insert(Img);
Found = true;
Expand Down Expand Up @@ -2415,14 +2414,14 @@ kernel_id ProgramManager::getSYCLKernelID(KernelNameStrRefT KernelName) {
"No kernel found with the specified name");
}

bool ProgramManager::hasCompatibleImage(const device &Dev) {
bool ProgramManager::hasCompatibleImage(const device_impl &DeviceImpl) {
std::lock_guard<std::mutex> Guard(m_KernelIDsMutex);

return std::any_of(
m_BinImg2KernelIDs.cbegin(), m_BinImg2KernelIDs.cend(),
[&](std::pair<RTDeviceBinaryImage *,
std::shared_ptr<std::vector<kernel_id>>>
Elem) { return compatibleWithDevice(Elem.first, Dev); });
Elem) { return compatibleWithDevice(Elem.first, DeviceImpl); });
}

std::vector<kernel_id> ProgramManager::getAllSYCLKernelIDs() {
Expand Down Expand Up @@ -2555,7 +2554,7 @@ device_image_plain ProgramManager::getDeviceImageFromBinaryImage(
RTDeviceBinaryImage *BinImage, const context &Ctx, const device &Dev) {
const bundle_state ImgState = getBinImageState(BinImage);

assert(compatibleWithDevice(BinImage, Dev));
assert(compatibleWithDevice(BinImage, *getSyclObjImpl(Dev).get()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise, line 2554 would be a better change, IMO.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about L2567, where we are creating a vector of devices? I thought it should also be refactored, but in a separate PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I've been doing during the transition is something like std::vector<device>{createSyclObjFromImpl<device>(Dev.shared_from_this())} but I can follow up on that separately.


std::shared_ptr<std::vector<sycl::kernel_id>> KernelIDs;
// Collect kernel names for the image.
Expand Down Expand Up @@ -2640,7 +2639,7 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
KernelImageMap.insert({KernelID, {}});

for (RTDeviceBinaryImage *BinImage : BinImages) {
if (!compatibleWithDevice(BinImage, Dev) ||
if (!compatibleWithDevice(BinImage, *getSyclObjImpl(Dev).get()) ||
!doesDevSupportDeviceRequirements(Dev, *BinImage))
continue;

Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/program_manager/program_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ class ProgramManager {
const char *UniqueId);

// Returns true if any available image is compatible with the device Dev.
bool hasCompatibleImage(const device &Dev);
bool hasCompatibleImage(const device_impl &DeviceImpl);

// The function gets a device_global entry identified by the pointer to the
// device_global object from the device_global map.
Expand Down Expand Up @@ -406,7 +406,7 @@ class ProgramManager {

bool isSpecialDeviceImage(RTDeviceBinaryImage *BinImage);
bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
const device &Dev);
const device_impl &DeviceImpl);

protected:
/// The three maps below are used during kernel resolution. Any kernel is
Expand Down
7 changes: 3 additions & 4 deletions sycl/source/device_selector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@ namespace detail {
// itself, so only matching devices will be scored.
static int getDevicePreference(const device &Device) {
int Score = 0;

const device_impl &DeviceImpl = *getSyclObjImpl(Device).get();
// Strongly prefer devices with available images.
auto &program_manager = sycl::detail::ProgramManager::getInstance();
if (program_manager.hasCompatibleImage(Device))
if (program_manager.hasCompatibleImage(DeviceImpl))
Score += 1000;

// Prefer level_zero backend devices.
if (detail::getSyclObjImpl(Device)->getBackend() ==
backend::ext_oneapi_level_zero)
if (DeviceImpl.getBackend() == backend::ext_oneapi_level_zero)
Score += 50;

return Score;
Expand Down