Skip to content

Commit

Permalink
Add zelReloadDrivers(flags) API
Browse files Browse the repository at this point in the history
Provides a means to re-initialize all of the drivers' library handles
and DDI tables. The value of flags must match what was provided to
zeInit(flags).

Signed-off-by: Lisanna Dettwyler <[email protected]>
  • Loading branch information
lisanna-dettwyler committed Sep 17, 2024
1 parent 519eed2 commit dc926b2
Show file tree
Hide file tree
Showing 5 changed files with 303 additions and 2 deletions.
16 changes: 16 additions & 0 deletions source/lib/ze_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,22 @@ zelLoaderGetVersions(
#endif
}

ze_result_t ZE_APICALL
zelReloadDrivers(
ze_init_flags_t flags)
{
#ifdef DYNAMIC_LOAD_LOADER
if(nullptr == ze_lib::context->loader)
return ZE_RESULT_ERROR;
typedef ze_result_t (ZE_APICALL *zelReloadDriver_t)(ze_driver_handle_t hDriver);
auto reloadDrivers = reinterpret_cast<zelReloadDriver_t>(
GET_FUNCTION_PTR(ze_lib::context->loader, "zelReloadDriversInternal") );
return reloadDrivers(flags);
#else
return zelReloadDriversInternal(flags);
#endif
}


ze_result_t ZE_APICALL
zelLoaderTranslateHandle(
Expand Down
252 changes: 252 additions & 0 deletions source/loader/ze_loader_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,258 @@ zelLoaderGetVersionsInternal(
return ZE_RESULT_SUCCESS;
}

ZE_DLLEXPORT ze_result_t ZE_APICALL
zelReloadDriversInternal(
ze_init_flags_t flags)
{
for( auto& drv : loader::context->zeDrivers ) {
if(drv.initStatus != ZE_RESULT_SUCCESS)
continue;

if (drv.handle) {
auto free_result = FREE_DRIVER_LIBRARY( drv.handle );
auto failure = FREE_DRIVER_LIBRARY_FAILURE_CHECK(free_result);
if (failure)
return ZE_RESULT_ERROR_UNINITIALIZED;
}

drv.handle = LOAD_DRIVER_LIBRARY( drv.name.c_str() );
if (NULL == drv.handle)
return ZE_RESULT_ERROR_UNINITIALIZED;

auto zeGetGlobalProcAddrTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetGlobalProcAddrTable") );
if (!zeGetGlobalProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetGlobalProcAddrTableResult = zeGetGlobalProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Global);
if (zeGetGlobalProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetGlobalProcAddrTableResult;

auto zeGetRTASBuilderExpProcAddrTable = reinterpret_cast<ze_pfnGetRTASBuilderExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetRTASBuilderExpProcAddrTable") );
if (!zeGetRTASBuilderExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetRTASBuilderExpProcAddrTableResult = zeGetRTASBuilderExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.RTASBuilderExp);
if (zeGetRTASBuilderExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetRTASBuilderExpProcAddrTableResult;

auto zeGetRTASParallelOperationExpProcAddrTable = reinterpret_cast<ze_pfnGetRTASParallelOperationExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetRTASParallelOperationExpProcAddrTable") );
if (!zeGetRTASParallelOperationExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetRTASParallelOperationExpProcAddrTableResult = zeGetRTASParallelOperationExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.RTASParallelOperationExp);
if (zeGetRTASParallelOperationExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetRTASParallelOperationExpProcAddrTableResult;

auto zeGetDriverProcAddrTable = reinterpret_cast<ze_pfnGetDriverProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetDriverProcAddrTable") );
if (!zeGetDriverProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetDriverProcAddrTableResult = zeGetDriverProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Driver);
if (zeGetDriverProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetDriverProcAddrTableResult;

auto zeGetDriverExpProcAddrTable = reinterpret_cast<ze_pfnGetDriverExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetDriverExpProcAddrTable") );
if (!zeGetDriverExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetDriverExpProcAddrTableResult = zeGetDriverExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.DriverExp);
if (zeGetDriverExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetDriverExpProcAddrTableResult;

auto zeGetDeviceProcAddrTable = reinterpret_cast<ze_pfnGetDeviceProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetDeviceProcAddrTable") );
if (!zeGetDeviceProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetDeviceProcAddrTableResult = zeGetDeviceProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Device);
if (zeGetDeviceProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetDeviceProcAddrTableResult;

auto zeGetDeviceExpProcAddrTable = reinterpret_cast<ze_pfnGetDeviceExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetDeviceExpProcAddrTable") );
if (!zeGetDeviceExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetDeviceExpProcAddrTableResult = zeGetDeviceExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.DeviceExp);
if (zeGetDeviceExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetDeviceExpProcAddrTableResult;

auto zeGetContextProcAddrTable = reinterpret_cast<ze_pfnGetContextProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetContextProcAddrTable") );
if (!zeGetContextProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetContextProcAddrTableResult = zeGetContextProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Context);
if (zeGetContextProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetContextProcAddrTableResult;

auto zeGetCommandQueueProcAddrTable = reinterpret_cast<ze_pfnGetCommandQueueProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetCommandQueueProcAddrTable") );
if (!zeGetCommandQueueProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetCommandQueueProcAddrTableResult = zeGetCommandQueueProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.CommandQueue);
if (zeGetCommandQueueProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetCommandQueueProcAddrTableResult;

auto zeGetCommandListProcAddrTable = reinterpret_cast<ze_pfnGetCommandListProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetCommandListProcAddrTable") );
if (!zeGetCommandListProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetCommandListProcAddrTableResult = zeGetCommandListProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.CommandList);
if (zeGetCommandListProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetCommandListProcAddrTableResult;

auto zeGetCommandListExpProcAddrTable = reinterpret_cast<ze_pfnGetCommandListExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetCommandListExpProcAddrTable") );
if (!zeGetCommandListExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetCommandListExpProcAddrTableResult = zeGetCommandListExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.CommandListExp);
if (zeGetCommandListExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetCommandListExpProcAddrTableResult;

auto zeGetEventProcAddrTable = reinterpret_cast<ze_pfnGetEventProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetEventProcAddrTable") );
if (!zeGetEventProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetEventProcAddrTableResult = zeGetEventProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Event);
if (zeGetEventProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetEventProcAddrTableResult;

auto zeGetEventExpProcAddrTable = reinterpret_cast<ze_pfnGetEventExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetEventExpProcAddrTable") );
if (!zeGetEventExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetEventExpProcAddrTableResult = zeGetEventExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.EventExp);
if (zeGetEventExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetEventExpProcAddrTableResult;

auto zeGetEventPoolProcAddrTable = reinterpret_cast<ze_pfnGetEventPoolProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetEventPoolProcAddrTable") );
if (!zeGetEventPoolProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetEventPoolProcAddrTableResult = zeGetEventPoolProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.EventPool);
if (zeGetEventPoolProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetEventPoolProcAddrTableResult;

auto zeGetFenceProcAddrTable = reinterpret_cast<ze_pfnGetFenceProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetFenceProcAddrTable") );
if (!zeGetFenceProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetFenceProcAddrTableResult = zeGetFenceProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Fence);
if (zeGetFenceProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetFenceProcAddrTableResult;

auto zeGetImageProcAddrTable = reinterpret_cast<ze_pfnGetImageProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetImageProcAddrTable") );
if (!zeGetImageProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetImageProcAddrTableResult = zeGetImageProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Image);
if (zeGetImageProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetImageProcAddrTableResult;

auto zeGetImageExpProcAddrTable = reinterpret_cast<ze_pfnGetImageExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetImageExpProcAddrTable") );
if (!zeGetImageExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetImageExpProcAddrTableResult = zeGetImageExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.ImageExp);
if (zeGetImageExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetImageExpProcAddrTableResult;

auto zeGetKernelProcAddrTable = reinterpret_cast<ze_pfnGetKernelProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetKernelProcAddrTable") );
if (!zeGetKernelProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetKernelProcAddrTableResult = zeGetKernelProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Kernel);
if (zeGetKernelProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetKernelProcAddrTableResult;

auto zeGetKernelExpProcAddrTable = reinterpret_cast<ze_pfnGetKernelExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetKernelExpProcAddrTable") );
if (!zeGetKernelExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetKernelExpProcAddrTableResult = zeGetKernelExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.KernelExp);
if (zeGetKernelExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetKernelExpProcAddrTableResult;

auto zeGetMemProcAddrTable = reinterpret_cast<ze_pfnGetMemProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetMemProcAddrTable") );
if (!zeGetMemProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetMemProcAddrTableResult = zeGetMemProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Mem);
if (zeGetMemProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetMemProcAddrTableResult;

auto zeGetMemExpProcAddrTable = reinterpret_cast<ze_pfnGetMemExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetMemExpProcAddrTable") );
if (!zeGetMemExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetMemExpProcAddrTableResult = zeGetMemExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.MemExp);
if (zeGetMemExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetMemExpProcAddrTableResult;

auto zeGetModuleProcAddrTable = reinterpret_cast<ze_pfnGetModuleProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetModuleProcAddrTable") );
if (!zeGetModuleProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetModuleProcAddrTableResult = zeGetModuleProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Module);
if (zeGetModuleProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetModuleProcAddrTableResult;

auto zeGetModuleBuildLogProcAddrTable = reinterpret_cast<ze_pfnGetModuleBuildLogProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetModuleBuildLogProcAddrTable") );
if (!zeGetModuleBuildLogProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetModuleBuildLogProcAddrTableResult = zeGetModuleBuildLogProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.ModuleBuildLog);
if (zeGetModuleBuildLogProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetModuleBuildLogProcAddrTableResult;

auto zeGetPhysicalMemProcAddrTable = reinterpret_cast<ze_pfnGetPhysicalMemProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetPhysicalMemProcAddrTable") );
if (!zeGetPhysicalMemProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetPhysicalMemProcAddrTableResult = zeGetPhysicalMemProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.PhysicalMem);
if (zeGetPhysicalMemProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetPhysicalMemProcAddrTableResult;

auto zeGetSamplerProcAddrTable = reinterpret_cast<ze_pfnGetSamplerProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetSamplerProcAddrTable") );
if (!zeGetSamplerProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetSamplerProcAddrTableResult = zeGetSamplerProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Sampler);
if (zeGetSamplerProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetSamplerProcAddrTableResult;

auto zeGetVirtualMemProcAddrTable = reinterpret_cast<ze_pfnGetVirtualMemProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetVirtualMemProcAddrTable") );
if (!zeGetVirtualMemProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetVirtualMemProcAddrTableResult = zeGetVirtualMemProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.VirtualMem);
if (zeGetVirtualMemProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetVirtualMemProcAddrTableResult;

auto zeGetFabricEdgeExpProcAddrTable = reinterpret_cast<ze_pfnGetFabricEdgeExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetFabricEdgeExpProcAddrTable") );
if (!zeGetFabricEdgeExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetFabricEdgeExpProcAddrTableResult = zeGetFabricEdgeExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.FabricEdgeExp);
if (zeGetFabricEdgeExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetFabricEdgeExpProcAddrTableResult;

auto zeGetFabricVertexExpProcAddrTable = reinterpret_cast<ze_pfnGetFabricVertexExpProcAddrTable_t>(
GET_FUNCTION_PTR( drv.handle, "zeGetFabricVertexExpProcAddrTable") );
if (!zeGetFabricVertexExpProcAddrTable)
return ZE_RESULT_ERROR_UNINITIALIZED;
auto zeGetFabricVertexExpProcAddrTableResult = zeGetFabricVertexExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.FabricVertexExp);
if (zeGetFabricVertexExpProcAddrTableResult != ZE_RESULT_SUCCESS)
return zeGetFabricVertexExpProcAddrTableResult;

auto initResult = drv.dditable.ze.Global.pfnInit(flags);
// Bail out if any drivers that previously succeeded fail
if (initResult != ZE_RESULT_SUCCESS)
return initResult;
}

return ZE_RESULT_SUCCESS;
}


ZE_DLLEXPORT ze_result_t ZE_APICALL
zelLoaderTranslateHandleInternal(
Expand Down
5 changes: 5 additions & 0 deletions source/loader/ze_loader_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ zelLoaderGetVersionsInternal(
zel_component_version_t *versions); //Pointer to array of versions. If set to NULL, num_elems is returned


ZE_DLLEXPORT ze_result_t ZE_APICALL
zelReloadDriversInternal(
ze_init_flags_t flags);


ZE_DLLEXPORT ze_result_t ZE_APICALL
zelLoaderTranslateHandleInternal(
zel_handle_type_t handleType, //Handle type
Expand Down
6 changes: 4 additions & 2 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,7 @@ if(MSVC)
target_compile_options(tests PRIVATE "/MD$<$<CONFIG:Debug>:d>")
endif()

add_test(NAME tests COMMAND tests)
set_property(TEST tests PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1")
add_test(NAME tests_api_version COMMAND tests --gtest-filter=LoaderAPI.GivenLevelZeroLoaderPresentWhenCallingzeGetLoaderVersionsAPIThenValidVersionIsReturned)
set_property(TEST tests_api_version PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1")
add_test(NAME tests_api_reload COMMAND tests --gtest-filter=LoaderAPI.GivenInitWhenCallingzelReloadDriversThenDriversStillWork)
set_property(TEST tests_api_reload PROPERTY ENVIRONMENT "ZE_ENABLE_NULL_DRIVER=1")
26 changes: 26 additions & 0 deletions test/loader_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,30 @@ TEST(
}
}

TEST(
LoaderAPI,
GivenInitWhenCallingzelReloadDriversThenDriversStillWork
) {
EXPECT_EQ(ZE_RESULT_SUCCESS, zeInit(0));

uint32_t count = 0;
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGet(&count, nullptr));
EXPECT_GT(count, 0);

std::vector<ze_driver_handle_t> hDrivers(count);
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGet(&count, hDrivers.data()));

for (auto &driver : hDrivers) {
ze_driver_properties_t driverProperties;
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGetProperties(driver, &driverProperties));
}

EXPECT_EQ(ZE_RESULT_SUCCESS, zelReloadDrivers(0));

for (auto &driver : hDrivers) {
ze_driver_properties_t driverProperties;
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGetProperties(driver, &driverProperties));
}
}

} // namespace

0 comments on commit dc926b2

Please sign in to comment.