diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 5b23a8c694f6..6bc23e165bb2 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -46,12 +46,12 @@ #include "mxnet/libinfo.h" #include "mxnet/imperative.h" #include "mxnet/lib_api.h" +#include "../initialize.h" #include "./c_api_common.h" #include "../operator/custom/custom-inl.h" #include "../operator/tensor/matrix_op-inl.h" #include "../operator/tvmop/op_module.h" #include "../common/utils.h" -#include "../common/library.h" using namespace mxnet; @@ -95,7 +95,7 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e, // Loads library and initializes it int MXLoadLib(const char *path) { API_BEGIN(); - void *lib = load_lib(path); + void *lib = LibraryInitializer::Get()->lib_load(path); if (!lib) LOG(FATAL) << "Unable to load library"; diff --git a/src/common/library.cc b/src/common/library.cc deleted file mode 100644 index f6ebd078b049..000000000000 --- a/src/common/library.cc +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2015 by Contributors - * \file library.cc - * \brief Dynamically loading accelerator library - * and accessing its functions - */ - - - -/*! - * \brief Loads the dynamic shared library file - * \param path library file location - * \return handle a pointer for the loaded library, nullptr if loading unsuccessful - */ -void* load_lib(const char* path) { - void *handle = nullptr; - std::string path_str(path); - // check if library was already loaded - if (loaded_libs.find(path_str) == loaded_libs.end()) { - // if not, load it -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - handle = LoadLibrary(path); - if (!handle) { - char *err_msg = nullptr; - win_err(&err_msg); - LOG(FATAL) << "Error loading library: '" << path << "'\n" << err_msg; - LocalFree(err_msg); - return nullptr; - } -#else - handle = dlopen(path, RTLD_LAZY); - if (!handle) { - LOG(FATAL) << "Error loading library: '" << path << "'\n" << dlerror(); - return nullptr; - } -#endif // _WIN32 or _WIN64 or __WINDOWS__ - // then store the pointer to the library - loaded_libs[path_str] = handle; - } else { - // otherwise just look up the pointer - handle = loaded_libs[path_str]; - } - return handle; -} - -/*! - * \brief Closes the loaded dynamic shared library file - * \param handle library file handle - */ -void close_lib(void* handle) { -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - FreeLibrary((HMODULE)handle); -#else - dlclose(handle); -#endif // _WIN32 or _WIN64 or __WINDOWS__ -} - -/*! - * \brief Obtains address of given function in the loaded library - * \param handle pointer for the loaded library - * \param func function pointer that gets output address - * \param name function name to be fetched - */ -void get_sym(void* handle, void** func, char* name) { -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - *func = GetProcAddress((HMODULE)handle, name); - if (!(*func)) { - char *err_msg = nullptr; - win_err(&err_msg); - LOG(FATAL) << "Error getting function '" << name << "' from library\n" << err_msg; - LocalFree(err_msg); - } -#else - *func = dlsym(handle, name); - if (!(*func)) { - LOG(FATAL) << "Error getting function '" << name << "' from library\n" << dlerror(); - } -#endif // _WIN32 or _WIN64 or __WINDOWS__ -} diff --git a/src/common/library.h b/src/common/library.h deleted file mode 100644 index bc0914d398c6..000000000000 --- a/src/common/library.h +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2015 by Contributors - * \file library.h - * \brief Defining library loading functions - */ -#ifndef MXNET_COMMON_LIBRARY_H_ -#define MXNET_COMMON_LIBRARY_H_ - -#include -#include -#include -#include "dmlc/io.h" - - -void* load_lib(const char* path); -void close_lib(void* handle); -void get_sym(void* handle, void** func, char* name); - - - -#endif // MXNET_COMMON_LIBRARY_H_ diff --git a/src/initialize.cc b/src/initialize.cc index 1bd07da007ed..a8fa50af222a 100644 --- a/src/initialize.cc +++ b/src/initialize.cc @@ -28,20 +28,15 @@ #include #include "./engine/openmp.h" #include "./operator/custom/custom-inl.h" -#include "./common/library.h" #if MXNET_USE_OPENCV #include #endif // MXNET_USE_OPENCV #include "common/utils.h" #include "engine/openmp.h" -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) -#include -#else -#include -#endif #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) +#include /*! * \brief Retrieve the system error message for the last-error code * \param err string that gets the error message @@ -58,9 +53,10 @@ void win_err(char **err) { reinterpret_cast(err), 0, NULL); } +#else +#include #endif - namespace mxnet { #if MXNET_USE_SIGNAL_HANDLER && DMLC_LOG_STACK_TRACE @@ -106,6 +102,91 @@ LibraryInitializer::~LibraryInitializer() { close_open_libs(); } +bool LibraryInitializer::lib_is_loaded(const std::string& path) const { + return loaded_libs.count(path) > 0; +} + +/*! + * \brief Loads the dynamic shared library file + * \param path library file location + * \return handle a pointer for the loaded library, throws dmlc::error if library can't be loaded + */ +void* LibraryInitializer::lib_load(const char* path) { + void *handle = nullptr; + // check if library was already loaded + if (!lib_is_loaded(path)) { + // if not, load it +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + handle = LoadLibrary(path); + if (!handle) { + char *err_msg = nullptr; + win_err(&err_msg); + LOG(FATAL) << "Error loading library: '" << path << "'\n" << err_msg; + LocalFree(err_msg); + return nullptr; + } +#else + handle = dlopen(path, RTLD_LAZY); + if (!handle) { + LOG(FATAL) << "Error loading library: '" << path << "'\n" << dlerror(); + return nullptr; + } +#endif // _WIN32 or _WIN64 or __WINDOWS__ + // then store the pointer to the library + loaded_libs[path] = handle; + } else { + loaded_libs.at(path); + } + return handle; +} + +/*! + * \brief Closes the loaded dynamic shared library file + * \param handle library file handle + */ +void LibraryInitializer::lib_close(void* handle) { + std::string libpath; + for (const auto& l: loaded_libs) { + if (l.second == handle) { + libpath = l.first; + break; + } + } + CHECK(!libpath.empty()); +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + FreeLibrary((HMODULE)handle); +#else + if (dlclose(handle)) { + LOG(WARNING) << "LibraryInitializer::lib_close: couldn't close library at address: " << handle + << " loaded from: '" << libpath << "': " << dlerror(); + } +#endif // _WIN32 or _WIN64 or __WINDOWS__ + loaded_libs.erase(libpath); +} + +/*! + * \brief Obtains address of given function in the loaded library + * \param handle pointer for the loaded library + * \param func function pointer that gets output address + * \param name function name to be fetched + */ +void LibraryInitializer::get_sym(void* handle, void** func, char* name) { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + *func = GetProcAddress((HMODULE)handle, name); + if (!(*func)) { + char *err_msg = nullptr; + win_err(&err_msg); + LOG(FATAL) << "Error getting function '" << name << "' from library\n" << err_msg; + LocalFree(err_msg); + } +#else + *func = dlsym(handle, name); + if (!(*func)) { + LOG(FATAL) << "Error getting function '" << name << "' from library\n" << dlerror(); + } +#endif // _WIN32 or _WIN64 or __WINDOWS__ +} + bool LibraryInitializer::was_forked() const { return common::current_process_id() != original_pid_; } @@ -153,15 +234,11 @@ void LibraryInitializer::install_signal_handlers() { } void LibraryInitializer::close_open_libs() { - for (auto const& lib : loaded_libs) { - close_lib(lib.second); + for (const auto& l: loaded_libs) { + lib_close(l.second); } } -void LibraryInitializer::dynlib_defer_close(const std::string &path, void *handle) { - loaded_libs.emplace(path, handle); -} - /** * Perform static initialization */ diff --git a/src/initialize.h b/src/initialize.h index 240e5a9f9390..8a6dc3aa5f7f 100644 --- a/src/initialize.h +++ b/src/initialize.h @@ -26,27 +26,14 @@ #include #include #include +#include "dmlc/io.h" + #ifndef MXNET_INITIALIZE_H_ #define MXNET_INITIALIZE_H_ namespace mxnet { -/*! - * \brief fetches from the library a function pointer of any given datatype and name - * \param T a template parameter for data type of function pointer - * \param lib library handle - * \param func_name function name to search for in the library - * \return func a function pointer - */ -template -T get_func(void *lib, char *func_name) { - T func; - get_sym(lib, reinterpret_cast(&func), func_name); - if (!func) - LOG(FATAL) << "Unable to get function '" << func_name << "' from library"; - return func; -} void pthread_atfork_prepare(); @@ -58,7 +45,7 @@ void pthread_atfork_child(); */ class LibraryInitializer { public: - typedef static std::map loaded_libs_t; + typedef std::map loaded_libs_t; static LibraryInitializer* Get() { static LibraryInitializer inst; return &inst; @@ -79,8 +66,10 @@ class LibraryInitializer { // Library loading - void lib_defer_close(const std::string& path, void* handle); - void lib_is_loaded() + bool lib_is_loaded(const std::string& path) const; + void* lib_load(const char* path); + void lib_close(void* handle); + static void get_sym(void* handle, void** func, char* name); /** * Original pid of the process which first loaded and initialized the library @@ -92,7 +81,6 @@ class LibraryInitializer { size_t mp_cv_num_threads_; // Actual code for the atfork handlers as member functions. - void atfork_prepare(); void atfork_parent(); void atfork_child(); @@ -115,10 +103,24 @@ class LibraryInitializer { void close_open_libs(); - loaded_libs_t loaded_libs; }; +/*! + * \brief fetches from the library a function pointer of any given datatype and name + * \param T a template parameter for data type of function pointer + * \param lib library handle + * \param func_name function name to search for in the library + * \return func a function pointer + */ +template +T get_func(void *lib, char *func_name) { + T func; + LibraryInitializer::Get()->get_sym(lib, reinterpret_cast(&func), func_name); + if (!func) + LOG(FATAL) << "Unable to get function '" << func_name << "' from library"; + return func; +} } // namespace mxnet #endif // MXNET_INITIALIZE_H_