-
Notifications
You must be signed in to change notification settings - Fork 796
[SYCL] Use refs to device_impl in ProgramManager #18320
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -583,8 +583,7 @@ static const char *getUrDeviceTarget(const char *URDeviceTarget) { | |
| } | ||
|
|
||
| static bool compatibleWithDevice(RTDeviceBinaryImage *BinImage, | ||
| const device &Dev) { | ||
| detail::device_impl &DeviceImpl = *detail::getSyclObjImpl(Dev); | ||
| const device_impl &DeviceImpl) { | ||
| auto &Adapter = DeviceImpl.getAdapter(); | ||
|
|
||
| const ur_device_handle_t &URDeviceHandle = DeviceImpl.getHandleRef(); | ||
|
|
@@ -621,7 +620,7 @@ bool ProgramManager::isSpecialDeviceImage(RTDeviceBinaryImage *BinImage) { | |
| } | ||
|
|
||
| bool ProgramManager::isSpecialDeviceImageShouldBeUsed( | ||
| RTDeviceBinaryImage *BinImage, const device &Dev) { | ||
| RTDeviceBinaryImage *BinImage, const device_impl &DeviceImpl) { | ||
| // Decide whether a devicelib image should be used. | ||
| int Bfloat16DeviceLibVersion = -1; | ||
| if (m_Bfloat16DeviceLibImages[0].get() == BinImage) | ||
|
|
@@ -640,7 +639,6 @@ bool ProgramManager::isSpecialDeviceImageShouldBeUsed( | |
| // more devicelib images in this way. | ||
| enum { DEVICELIB_FALLBACK = 0, DEVICELIB_NATIVE }; | ||
| ur_bool_t NativeBF16Supported = false; | ||
| detail::device_impl &DeviceImpl = *detail::getSyclObjImpl(Dev); | ||
| ur_result_t CallSuccessful = | ||
| DeviceImpl.getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>( | ||
| DeviceImpl.getHandleRef(), | ||
|
|
@@ -658,15 +656,15 @@ bool ProgramManager::isSpecialDeviceImageShouldBeUsed( | |
| return false; | ||
| } | ||
|
|
||
| static bool checkLinkingSupport(const device &Dev, | ||
| static bool checkLinkingSupport(const device_impl &DeviceImpl, | ||
| const RTDeviceBinaryImage &Img) { | ||
| const char *Target = Img.getRawData().DeviceTargetSpec; | ||
| // TODO replace with extension checks once implemented in UR. | ||
| if (strcmp(Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64) == 0) { | ||
| return true; | ||
| } | ||
| if (strcmp(Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_GEN) == 0) { | ||
| return Dev.is_gpu() && Dev.get_backend() == backend::opencl; | ||
| return DeviceImpl.is_gpu() && DeviceImpl.getBackend() == backend::opencl; | ||
| } | ||
| return false; | ||
| } | ||
|
|
@@ -701,7 +699,8 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols( | |
| HandledSymbols.insert(ISProp->Name); | ||
| } | ||
| ur::DeviceBinaryType Format = MainImg.getFormat(); | ||
| if (!WorkList.empty() && !checkLinkingSupport(Dev, MainImg)) | ||
| if (!WorkList.empty() && | ||
| !checkLinkingSupport(*getSyclObjImpl(Dev).get(), MainImg)) | ||
| throw exception(make_error_code(errc::feature_not_supported), | ||
| "Cannot resolve external symbols, linking is unsupported " | ||
| "for the backend"); | ||
|
|
@@ -715,10 +714,10 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols( | |
| RTDeviceBinaryImage *Img = It->second; | ||
| if (Img->getFormat() != Format || | ||
| !doesDevSupportDeviceRequirements(Dev, *Img) || | ||
| !compatibleWithDevice(Img, Dev)) | ||
| !compatibleWithDevice(Img, *getSyclObjImpl(Dev).get())) | ||
| continue; | ||
| if (isSpecialDeviceImage(Img) && | ||
| !isSpecialDeviceImageShouldBeUsed(Img, Dev)) | ||
| !isSpecialDeviceImageShouldBeUsed(Img, *getSyclObjImpl(Dev).get())) | ||
| continue; | ||
| DeviceImagesToLink.insert(Img); | ||
| Found = true; | ||
|
|
@@ -2415,14 +2414,14 @@ kernel_id ProgramManager::getSYCLKernelID(KernelNameStrRefT KernelName) { | |
| "No kernel found with the specified name"); | ||
| } | ||
|
|
||
| bool ProgramManager::hasCompatibleImage(const device &Dev) { | ||
| bool ProgramManager::hasCompatibleImage(const device_impl &DeviceImpl) { | ||
| std::lock_guard<std::mutex> Guard(m_KernelIDsMutex); | ||
|
|
||
| return std::any_of( | ||
| m_BinImg2KernelIDs.cbegin(), m_BinImg2KernelIDs.cend(), | ||
| [&](std::pair<RTDeviceBinaryImage *, | ||
| std::shared_ptr<std::vector<kernel_id>>> | ||
| Elem) { return compatibleWithDevice(Elem.first, Dev); }); | ||
| Elem) { return compatibleWithDevice(Elem.first, DeviceImpl); }); | ||
| } | ||
|
|
||
| std::vector<kernel_id> ProgramManager::getAllSYCLKernelIDs() { | ||
|
|
@@ -2555,7 +2554,7 @@ device_image_plain ProgramManager::getDeviceImageFromBinaryImage( | |
| RTDeviceBinaryImage *BinImage, const context &Ctx, const device &Dev) { | ||
| const bundle_state ImgState = getBinImageState(BinImage); | ||
|
|
||
| assert(compatibleWithDevice(BinImage, Dev)); | ||
| assert(compatibleWithDevice(BinImage, *getSyclObjImpl(Dev).get())); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Likewise, line 2554 would be a better change, IMO.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about L2567 where we are creating vector of devices? I thoughts it also should be refactored, but in a separate PR?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What I've been doing during the transition is something like |
||
|
|
||
| std::shared_ptr<std::vector<sycl::kernel_id>> KernelIDs; | ||
| // Collect kernel names for the image. | ||
|
|
@@ -2640,7 +2639,7 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState( | |
| KernelImageMap.insert({KernelID, {}}); | ||
|
|
||
| for (RTDeviceBinaryImage *BinImage : BinImages) { | ||
| if (!compatibleWithDevice(BinImage, Dev) || | ||
| if (!compatibleWithDevice(BinImage, *getSyclObjImpl(Dev).get()) || | ||
| !doesDevSupportDeviceRequirements(Dev, *BinImage)) | ||
| continue; | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why aren't we changing line 691 instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because I didn't want to refactor the
doesDevSupportDeviceRequirementsfunction that internally calls thecheckDevSupportDeviceRequirementsthat usestempalte <typename Param> device::get_info<Param>There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I'll look into it separately.