Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add zelReloadDrivers(flags) API #191

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/loader_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ There are currently 3 versioned components assigned the following name strings:
- `"validation layer"`
- `"loader"`

### zelReloadDrivers

Close, reload, and re-initialize through zeInit all driver libraries currently loaded.

- __flags__ init flags that will be passed to each driver's implementation of zeInit, it should match what was previously provided at the first zeInit.


### zelLoaderTranslateHandle

Expand Down
4 changes: 4 additions & 0 deletions include/loader/ze_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ zelLoaderGetVersions(
size_t *num_elems, //Pointer to num versions to get.
zel_component_version_t *versions); //Pointer to array of versions. If set to NULL, num_elems is returned

ZE_APIEXPORT ze_result_t ZE_APICALL
zelReloadDrivers(
ze_init_flags_t flags); //Init flags, should match flags used in zeInit

typedef enum _zel_handle_type_t {
ZEL_HANDLE_DRIVER,
ZEL_HANDLE_DEVICE,
Expand Down
16 changes: 16 additions & 0 deletions source/lib/ze_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,22 @@ zelLoaderGetVersions(
#endif
}

ze_result_t ZE_APICALL
zelReloadDrivers(
ze_init_flags_t flags)
{
#ifdef DYNAMIC_LOAD_LOADER
if(nullptr == ze_lib::context->loader)
return ZE_RESULT_ERROR;
typedef ze_result_t (ZE_APICALL *zelReloadDriver_t)(ze_driver_handle_t hDriver);
auto reloadDrivers = reinterpret_cast<zelReloadDriver_t>(
GET_FUNCTION_PTR(ze_lib::context->loader, "zelReloadDriversInternal") );
return reloadDrivers(flags);
#else
return zelReloadDriversInternal(flags);
#endif
}


ze_result_t ZE_APICALL
zelLoaderTranslateHandle(
Expand Down
252 changes: 252 additions & 0 deletions source/loader/ze_loader_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,258 @@ zelLoaderGetVersionsInternal(
return ZE_RESULT_SUCCESS;
}

ZE_DLLEXPORT ze_result_t ZE_APICALL
zelReloadDriversInternal(
ze_init_flags_t flags)
{
for( auto& drv : loader::context->zeDrivers ) {
if(drv.initStatus != ZE_RESULT_SUCCESS)
continue;

if (drv.handle) {
auto free_result = FREE_DRIVER_LIBRARY( drv.handle );
auto failure = FREE_DRIVER_LIBRARY_FAILURE_CHECK(free_result);
if (failure)
return ZE_RESULT_ERROR_UNINITIALIZED;
}
Comment on lines +84 to +89
Copy link
Contributor

@rwmcguir rwmcguir Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So overall I would ask if there is any advantage to separating this step out into it's own API of zelUnloadDriversInternal() at least initially and possibly even into user exposed API in the future. I think fine for doing a proof of concept... Down the road this would give the option of changing which drivers were re-initialized for whatever reason.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2nd thought here, would be we could simplify the load step.... by removing unload... But in contrast we could then improve the unload step by adding state machine checking if any was necessary... i.e. are there any conditions in which unload was not allowed, or if a timeout/wait was necessary to kill in flight commands before catastrophic teardown.. (I realize this doesn't really need to be enforced here, as we could put that responsibility on the user and effective document 'use at your own risk')

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Down the road this would give the option of changing which drivers were re-initialized for whatever reason.

I think in order to have that we would need it to operate only on a single driver, rather than separating load/unload.

are there any conditions in which unload was not allowed

I agree it's on the user to not call this if they expect in-flight things to finish, so I don't think we need to count that as one such condition. If we end up adopting this, we should document that all prior resources / handles are invalidated and that new ones will need to be obtained.


drv.handle = LOAD_DRIVER_LIBRARY( drv.name.c_str() );
if (NULL == drv.handle)
return ZE_RESULT_ERROR_UNINITIALIZED;

auto zeGetGlobalProcAddrTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetGlobalProcAddrTable") );
if (!zeGetGlobalProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetGlobalProcAddrTableResult = zeGetGlobalProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Global);
if (zeGetGlobalProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetGlobalProcAddrTableResult;

auto zeGetRTASBuilderExpProcAddrTable = reinterpret_cast<ze_pfnGetRTASBuilderExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetRTASBuilderExpProcAddrTable") );
if (!zeGetRTASBuilderExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetRTASBuilderExpProcAddrTableResult = zeGetRTASBuilderExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.RTASBuilderExp);
if (zeGetRTASBuilderExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetRTASBuilderExpProcAddrTableResult;

auto zeGetRTASParallelOperationExpProcAddrTable = reinterpret_cast<ze_pfnGetRTASParallelOperationExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetRTASParallelOperationExpProcAddrTable") );
if (!zeGetRTASParallelOperationExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetRTASParallelOperationExpProcAddrTableResult = zeGetRTASParallelOperationExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.RTASParallelOperationExp);
if (zeGetRTASParallelOperationExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetRTASParallelOperationExpProcAddrTableResult;

auto zeGetDriverProcAddrTable = reinterpret_cast<ze_pfnGetDriverProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetDriverProcAddrTable") );
if (!zeGetDriverProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetDriverProcAddrTableResult = zeGetDriverProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Driver);
if (zeGetDriverProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetDriverProcAddrTableResult;

auto zeGetDriverExpProcAddrTable = reinterpret_cast<ze_pfnGetDriverExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetDriverExpProcAddrTable") );
if (!zeGetDriverExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetDriverExpProcAddrTableResult = zeGetDriverExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.DriverExp);
if (zeGetDriverExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetDriverExpProcAddrTableResult;

auto zeGetDeviceProcAddrTable = reinterpret_cast<ze_pfnGetDeviceProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetDeviceProcAddrTable") );
if (!zeGetDeviceProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetDeviceProcAddrTableResult = zeGetDeviceProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Device);
if (zeGetDeviceProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetDeviceProcAddrTableResult;

auto zeGetDeviceExpProcAddrTable = reinterpret_cast<ze_pfnGetDeviceExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetDeviceExpProcAddrTable") );
if (!zeGetDeviceExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetDeviceExpProcAddrTableResult = zeGetDeviceExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.DeviceExp);
if (zeGetDeviceExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetDeviceExpProcAddrTableResult;

auto zeGetContextProcAddrTable = reinterpret_cast<ze_pfnGetContextProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetContextProcAddrTable") );
if (!zeGetContextProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetContextProcAddrTableResult = zeGetContextProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Context);
if (zeGetContextProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetContextProcAddrTableResult;

auto zeGetCommandQueueProcAddrTable = reinterpret_cast<ze_pfnGetCommandQueueProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetCommandQueueProcAddrTable") );
if (!zeGetCommandQueueProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetCommandQueueProcAddrTableResult = zeGetCommandQueueProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.CommandQueue);
if (zeGetCommandQueueProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetCommandQueueProcAddrTableResult;

auto zeGetCommandListProcAddrTable = reinterpret_cast<ze_pfnGetCommandListProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetCommandListProcAddrTable") );
if (!zeGetCommandListProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetCommandListProcAddrTableResult = zeGetCommandListProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.CommandList);
if (zeGetCommandListProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetCommandListProcAddrTableResult;

auto zeGetCommandListExpProcAddrTable = reinterpret_cast<ze_pfnGetCommandListExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetCommandListExpProcAddrTable") );
if (!zeGetCommandListExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetCommandListExpProcAddrTableResult = zeGetCommandListExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.CommandListExp);
if (zeGetCommandListExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetCommandListExpProcAddrTableResult;

auto zeGetEventProcAddrTable = reinterpret_cast<ze_pfnGetEventProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetEventProcAddrTable") );
if (!zeGetEventProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetEventProcAddrTableResult = zeGetEventProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Event);
if (zeGetEventProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetEventProcAddrTableResult;

auto zeGetEventExpProcAddrTable = reinterpret_cast<ze_pfnGetEventExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetEventExpProcAddrTable") );
if (!zeGetEventExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetEventExpProcAddrTableResult = zeGetEventExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.EventExp);
if (zeGetEventExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetEventExpProcAddrTableResult;

auto zeGetEventPoolProcAddrTable = reinterpret_cast<ze_pfnGetEventPoolProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetEventPoolProcAddrTable") );
if (!zeGetEventPoolProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetEventPoolProcAddrTableResult = zeGetEventPoolProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.EventPool);
if (zeGetEventPoolProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetEventPoolProcAddrTableResult;

auto zeGetFenceProcAddrTable = reinterpret_cast<ze_pfnGetFenceProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetFenceProcAddrTable") );
if (!zeGetFenceProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetFenceProcAddrTableResult = zeGetFenceProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Fence);
if (zeGetFenceProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetFenceProcAddrTableResult;

auto zeGetImageProcAddrTable = reinterpret_cast<ze_pfnGetImageProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetImageProcAddrTable") );
if (!zeGetImageProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetImageProcAddrTableResult = zeGetImageProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Image);
if (zeGetImageProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetImageProcAddrTableResult;

auto zeGetImageExpProcAddrTable = reinterpret_cast<ze_pfnGetImageExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetImageExpProcAddrTable") );
if (!zeGetImageExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetImageExpProcAddrTableResult = zeGetImageExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.ImageExp);
if (zeGetImageExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetImageExpProcAddrTableResult;

auto zeGetKernelProcAddrTable = reinterpret_cast<ze_pfnGetKernelProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetKernelProcAddrTable") );
if (!zeGetKernelProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetKernelProcAddrTableResult = zeGetKernelProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Kernel);
if (zeGetKernelProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetKernelProcAddrTableResult;

auto zeGetKernelExpProcAddrTable = reinterpret_cast<ze_pfnGetKernelExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetKernelExpProcAddrTable") );
if (!zeGetKernelExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetKernelExpProcAddrTableResult = zeGetKernelExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.KernelExp);
if (zeGetKernelExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetKernelExpProcAddrTableResult;

auto zeGetMemProcAddrTable = reinterpret_cast<ze_pfnGetMemProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetMemProcAddrTable") );
if (!zeGetMemProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetMemProcAddrTableResult = zeGetMemProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Mem);
if (zeGetMemProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetMemProcAddrTableResult;

auto zeGetMemExpProcAddrTable = reinterpret_cast<ze_pfnGetMemExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetMemExpProcAddrTable") );
if (!zeGetMemExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetMemExpProcAddrTableResult = zeGetMemExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.MemExp);
if (zeGetMemExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetMemExpProcAddrTableResult;

auto zeGetModuleProcAddrTable = reinterpret_cast<ze_pfnGetModuleProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetModuleProcAddrTable") );
if (!zeGetModuleProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetModuleProcAddrTableResult = zeGetModuleProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Module);
if (zeGetModuleProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetModuleProcAddrTableResult;

auto zeGetModuleBuildLogProcAddrTable = reinterpret_cast<ze_pfnGetModuleBuildLogProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetModuleBuildLogProcAddrTable") );
if (!zeGetModuleBuildLogProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetModuleBuildLogProcAddrTableResult = zeGetModuleBuildLogProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.ModuleBuildLog);
if (zeGetModuleBuildLogProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetModuleBuildLogProcAddrTableResult;

auto zeGetPhysicalMemProcAddrTable = reinterpret_cast<ze_pfnGetPhysicalMemProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetPhysicalMemProcAddrTable") );
if (!zeGetPhysicalMemProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetPhysicalMemProcAddrTableResult = zeGetPhysicalMemProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.PhysicalMem);
if (zeGetPhysicalMemProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetPhysicalMemProcAddrTableResult;

auto zeGetSamplerProcAddrTable = reinterpret_cast<ze_pfnGetSamplerProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetSamplerProcAddrTable") );
if (!zeGetSamplerProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetSamplerProcAddrTableResult = zeGetSamplerProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Sampler);
if (zeGetSamplerProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetSamplerProcAddrTableResult;

auto zeGetVirtualMemProcAddrTable = reinterpret_cast<ze_pfnGetVirtualMemProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetVirtualMemProcAddrTable") );
if (!zeGetVirtualMemProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetVirtualMemProcAddrTableResult = zeGetVirtualMemProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.VirtualMem);
if (zeGetVirtualMemProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetVirtualMemProcAddrTableResult;

auto zeGetFabricEdgeExpProcAddrTable = reinterpret_cast<ze_pfnGetFabricEdgeExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetFabricEdgeExpProcAddrTable") );
if (!zeGetFabricEdgeExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetFabricEdgeExpProcAddrTableResult = zeGetFabricEdgeExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.FabricEdgeExp);
if (zeGetFabricEdgeExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetFabricEdgeExpProcAddrTableResult;

auto zeGetFabricVertexExpProcAddrTable = reinterpret_cast<ze_pfnGetFabricVertexExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetFabricVertexExpProcAddrTable") );
if (!zeGetFabricVertexExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetFabricVertexExpProcAddrTableResult = zeGetFabricVertexExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.FabricVertexExp);
if (zeGetFabricVertexExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetFabricVertexExpProcAddrTableResult;

auto initResult = drv.dditable.ze.Global.pfnInit(flags);
// Bail out if any drivers that previously succeeded fail
if (initResult != ZE_RESULT_SUCCESS)
return initResult;
}

return ZE_RESULT_SUCCESS;
}


ZE_DLLEXPORT ze_result_t ZE_APICALL
zelLoaderTranslateHandleInternal(
Expand Down
5 changes: 5 additions & 0 deletions source/loader/ze_loader_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ zelLoaderGetVersionsInternal(
zel_component_version_t *versions); //Pointer to array of versions. If set to NULL, num_elems is returned


ZE_DLLEXPORT ze_result_t ZE_APICALL
zelReloadDriversInternal(
ze_init_flags_t flags);


ZE_DLLEXPORT ze_result_t ZE_APICALL
zelLoaderTranslateHandleInternal(
zel_handle_type_t handleType, //Handle type
Expand Down
6 changes: 4 additions & 2 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,7 @@ if(MSVC)
target_compile_options(tests PRIVATE "/MD$<$<CONFIG:Debug>:d>")
endif()

add_test(NAME tests COMMAND tests)
set_property(TEST tests PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1")
add_test(NAME tests_api_version COMMAND tests --gtest_filter=LoaderAPI.GivenLevelZeroLoaderPresentWhenCallingzeGetLoaderVersionsAPIThenValidVersionIsReturned)
set_property(TEST tests_api_version PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1")
add_test(NAME tests_api_reload COMMAND tests --gtest_filter=LoaderAPI.GivenInitWhenCallingzelReloadDriversThenDriversStillWork)
set_property(TEST tests_api_reload PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1")
26 changes: 26 additions & 0 deletions test/loader_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,30 @@ TEST(
}
}

TEST(
LoaderAPI,
GivenInitWhenCallingzelReloadDriversThenDriversStillWork
) {
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(0));

uint32_t count = 0;
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGet(&count, nullptr));
EXPECT_GT(count, 0);

std::vector<ze_driver_handle_t> hDrivers(count);
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGet(&count, hDrivers.data()));

for (auto &driver : hDrivers) {
ze_driver_properties_t driverProperties;
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGetProperties(driver, &driverProperties));
}

EXPECT_EQ(ZE_RESULT_SUCCESS, zelReloadDrivers(0));

for (auto &driver : hDrivers) {
ze_driver_properties_t driverProperties;
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGetProperties(driver, &driverProperties));
}
}

} // namespace
Loading