Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Refactor around lib loading
Browse files Browse the repository at this point in the history
  • Loading branch information
larroy committed Aug 8, 2019
1 parent 51c4091 commit 34d03fd
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 173 deletions.
4 changes: 2 additions & 2 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@
#include "mxnet/libinfo.h"
#include "mxnet/imperative.h"
#include "mxnet/lib_api.h"
#include "../initialize.h"
#include "./c_api_common.h"
#include "../operator/custom/custom-inl.h"
#include "../operator/tensor/matrix_op-inl.h"
#include "../operator/tvmop/op_module.h"
#include "../common/utils.h"
#include "../common/library.h"

using namespace mxnet;

Expand Down Expand Up @@ -95,7 +95,7 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e,
// Loads library and initializes it
int MXLoadLib(const char *path) {
API_BEGIN();
void *lib = load_lib(path);
void *lib = LibraryInitializer::Get()->lib_load(path);
if (!lib)
LOG(FATAL) << "Unable to load library";

Expand Down
98 changes: 0 additions & 98 deletions src/common/library.cc

This file was deleted.

40 changes: 0 additions & 40 deletions src/common/library.h

This file was deleted.

103 changes: 90 additions & 13 deletions src/initialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,15 @@
#include <mxnet/engine.h>
#include "./engine/openmp.h"
#include "./operator/custom/custom-inl.h"
#include "./common/library.h"
#if MXNET_USE_OPENCV
#include <opencv2/opencv.hpp>
#endif // MXNET_USE_OPENCV
#include "common/utils.h"
#include "engine/openmp.h"

#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
#include <windows.h>
#else
#include <dlfcn.h>
#endif

#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
#include <windows.h>
/*!
* \brief Retrieve the system error message for the last-error code
* \param err string that gets the error message
Expand All @@ -58,9 +53,10 @@ void win_err(char **err) {
reinterpret_cast<char*>(err),
0, NULL);
}
#else
#include <dlfcn.h>
#endif


namespace mxnet {

#if MXNET_USE_SIGNAL_HANDLER && DMLC_LOG_STACK_TRACE
Expand Down Expand Up @@ -106,6 +102,91 @@ LibraryInitializer::~LibraryInitializer() {
close_open_libs();
}

bool LibraryInitializer::lib_is_loaded(const std::string& path) const {
return loaded_libs.count(path) > 0;
}

/*!
* \brief Loads the dynamic shared library file
* \param path library file location
* \return handle a pointer for the loaded library, throws dmlc::error if library can't be loaded
*/
void* LibraryInitializer::lib_load(const char* path) {
void *handle = nullptr;
// check if library was already loaded
if (!lib_is_loaded(path)) {
// if not, load it
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
handle = LoadLibrary(path);
if (!handle) {
char *err_msg = nullptr;
win_err(&err_msg);
LOG(FATAL) << "Error loading library: '" << path << "'\n" << err_msg;
LocalFree(err_msg);
return nullptr;
}
#else
handle = dlopen(path, RTLD_LAZY);
if (!handle) {
LOG(FATAL) << "Error loading library: '" << path << "'\n" << dlerror();
return nullptr;
}
#endif // _WIN32 or _WIN64 or __WINDOWS__
// then store the pointer to the library
loaded_libs[path] = handle;
} else {
loaded_libs.at(path);
}
return handle;
}

/*!
* \brief Closes the loaded dynamic shared library file
* \param handle library file handle
*/
void LibraryInitializer::lib_close(void* handle) {
std::string libpath;
for (const auto& l: loaded_libs) {
if (l.second == handle) {
libpath = l.first;
break;
}
}
CHECK(!libpath.empty());
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
FreeLibrary((HMODULE)handle);
#else
if (dlclose(handle)) {
LOG(WARNING) << "LibraryInitializer::lib_close: couldn't close library at address: " << handle
<< " loaded from: '" << libpath << "': " << dlerror();
}
#endif // _WIN32 or _WIN64 or __WINDOWS__
loaded_libs.erase(libpath);
}

/*!
* \brief Obtains address of given function in the loaded library
* \param handle pointer for the loaded library
* \param func function pointer that gets output address
* \param name function name to be fetched
*/
void LibraryInitializer::get_sym(void* handle, void** func, char* name) {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
*func = GetProcAddress((HMODULE)handle, name);
if (!(*func)) {
char *err_msg = nullptr;
win_err(&err_msg);
LOG(FATAL) << "Error getting function '" << name << "' from library\n" << err_msg;
LocalFree(err_msg);
}
#else
*func = dlsym(handle, name);
if (!(*func)) {
LOG(FATAL) << "Error getting function '" << name << "' from library\n" << dlerror();
}
#endif // _WIN32 or _WIN64 or __WINDOWS__
}

bool LibraryInitializer::was_forked() const {
return common::current_process_id() != original_pid_;
}
Expand Down Expand Up @@ -153,15 +234,11 @@ void LibraryInitializer::install_signal_handlers() {
}

void LibraryInitializer::close_open_libs() {
for (auto const& lib : loaded_libs) {
close_lib(lib.second);
for (const auto& l: loaded_libs) {
lib_close(l.second);
}
}

void LibraryInitializer::dynlib_defer_close(const std::string &path, void *handle) {
loaded_libs.emplace(path, handle);
}

/**
* Perform static initialization
*/
Expand Down
42 changes: 22 additions & 20 deletions src/initialize.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,14 @@
#include <cstdlib>
#include <string>
#include <map>
#include "dmlc/io.h"


#ifndef MXNET_INITIALIZE_H_
#define MXNET_INITIALIZE_H_

namespace mxnet {

/*!
* \brief fetches from the library a function pointer of any given datatype and name
* \param T a template parameter for data type of function pointer
* \param lib library handle
* \param func_name function name to search for in the library
* \return func a function pointer
*/
template<typename T>
T get_func(void *lib, char *func_name) {
T func;
get_sym(lib, reinterpret_cast<void**>(&func), func_name);
if (!func)
LOG(FATAL) << "Unable to get function '" << func_name << "' from library";
return func;
}


void pthread_atfork_prepare();
Expand All @@ -58,7 +45,7 @@ void pthread_atfork_child();
*/
class LibraryInitializer {
public:
typedef static std::map<std::string, void*> loaded_libs_t;
typedef std::map<std::string, void*> loaded_libs_t;
static LibraryInitializer* Get() {
static LibraryInitializer inst;
return &inst;
Expand All @@ -79,8 +66,10 @@ class LibraryInitializer {


// Library loading
void lib_defer_close(const std::string& path, void* handle);
void lib_is_loaded()
bool lib_is_loaded(const std::string& path) const;
void* lib_load(const char* path);
void lib_close(void* handle);
static void get_sym(void* handle, void** func, char* name);

/**
* Original pid of the process which first loaded and initialized the library
Expand All @@ -92,7 +81,6 @@ class LibraryInitializer {
size_t mp_cv_num_threads_;

// Actual code for the atfork handlers as member functions.

void atfork_prepare();
void atfork_parent();
void atfork_child();
Expand All @@ -115,10 +103,24 @@ class LibraryInitializer {

void close_open_libs();


loaded_libs_t loaded_libs;
};

/*!
* \brief fetches from the library a function pointer of any given datatype and name
* \param T a template parameter for data type of function pointer
* \param lib library handle
* \param func_name function name to search for in the library
* \return func a function pointer
*/
template<typename T>
T get_func(void *lib, char *func_name) {
T func;
LibraryInitializer::Get()->get_sym(lib, reinterpret_cast<void**>(&func), func_name);
if (!func)
LOG(FATAL) << "Unable to get function '" << func_name << "' from library";
return func;
}

} // namespace mxnet
#endif // MXNET_INITIALIZE_H_

0 comments on commit 34d03fd

Please sign in to comment.