diff --git a/CMakeLists.txt b/CMakeLists.txt index 42463cd227c6..180955bd662f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -692,6 +692,8 @@ else() endif() +add_library(sample_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/lib_api/mylib.cc) +target_include_directories(sample_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet) set(MXNET_INSTALL_TARGETS mxnet) if(UNIX) # Create dummy file since we want an empty shared library before linking @@ -702,9 +704,13 @@ if(UNIX) add_library(mxnet SHARED ${DUMMY_SOURCE}) target_link_libraries(mxnet PRIVATE ${BEGIN_WHOLE_ARCHIVE} $ ${END_WHOLE_ARCHIVE}) target_link_libraries(mxnet PRIVATE mxnet_static) + target_link_libraries(mxnet_static PUBLIC ${CMAKE_DL_LIBS}) + target_compile_options(sample_lib PUBLIC -shared) set_target_properties(mxnet_static PROPERTIES OUTPUT_NAME mxnet) else() add_library(mxnet SHARED ${SOURCE}) + target_compile_options(sample_lib PUBLIC /LD) + set_target_properties(sample_lib PROPERTIES PREFIX "lib") endif() if(USE_CUDA) diff --git a/Makefile b/Makefile index c8d4e35c80ec..f166b4e9f94b 100644 --- a/Makefile +++ b/Makefile @@ -107,7 +107,7 @@ else CFLAGS += -O3 -DNDEBUG=1 endif CFLAGS += -I$(TPARTYDIR)/mshadow/ -I$(TPARTYDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(TPARTYDIR)/tvm/include -Iinclude $(MSHADOW_CFLAGS) -LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS) +LDFLAGS = -pthread -ldl $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS) ifeq ($(ENABLE_TESTCOVERAGE), 1) CFLAGS += --coverage @@ -453,7 +453,7 @@ endif .PHONY: clean all extra-packages test lint docs clean_all rcpplint rcppexport roxygen\ cython2 cython3 cython cyclean -all: lib/libmxnet.a lib/libmxnet.so $(BIN) extra-packages +all: lib/libmxnet.a lib/libmxnet.so $(BIN) extra-packages sample_lib SRC = $(wildcard src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc) OBJ = $(patsubst %.cc, build/%.o, $(SRC)) @@ -658,6 +658,9 @@ cpplint: pylint: python3 -m pylint --rcfile=$(ROOTDIR)/ci/other/pylintrc --ignore-patterns=".*\.so$$,.*\.dll$$,.*\.dylib$$" python/mxnet tools/caffe_converter/*.py +sample_lib: + $(CXX) -shared -fPIC example/lib_api/mylib.cc -o libsample_lib.so -I include/mxnet + doc: docs docs: diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy index 668d2f7c7dca..d4f0a3d48096 100644 --- a/ci/jenkins/Jenkins_steps.groovy +++ b/ci/jenkins/Jenkins_steps.groovy @@ -23,8 +23,8 @@ utils = load('ci/Jenkinsfile_utils.groovy') // mxnet libraries -mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' -mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' +mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' +mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' // Python wheels mx_pip = 'build/*.whl' @@ -33,11 +33,11 @@ mx_pip = 'build/*.whl' mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' // mxnet cmake libraries, in cmake builds we do not produce a libnvvm static library by default. -mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests' +mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/libsample_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests' mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0' -mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' +mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, libsample_lib.so, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' mx_tensorrt_lib = 'build/libmxnet.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so' -mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' +mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/cpp-package/example/*' // Python unittest for CPU diff --git a/example/lib_api/Makefile b/example/lib_api/Makefile new file mode 100644 index 000000000000..e5893c8065c4 --- /dev/null +++ b/example/lib_api/Makefile @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +all: + g++ -shared -fPIC mylib.cc -o mylib.so -I ../../include/mxnet + +test: + g++ -std=c++11 -O3 -o libtest libtest.cc -ldl -I ../../include/mxnet + +windows: + cl /LD mylib.cc + +win_test: + cl libtest.cc + +clean: + rm -rf mylib.so libtest diff --git a/example/lib_api/libtest.cc b/example/lib_api/libtest.cc new file mode 100644 index 000000000000..8bdf36c05d37 --- /dev/null +++ b/example/lib_api/libtest.cc @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file libtest.cc + * \brief This test checks if the library is implemented correctly + * and does not involve dynamic loading of library into MXNet + * This test is supposed to be run before test.py + */ + +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) +#include +#else +#include +#endif + +#include +#include "lib_api.h" + +#define MXNET_VERSION 10500 + +int main(void) { + // Get a handle to the library. +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + HINSTANCE handle; + handle = LoadLibrary(TEXT("mylib.dll")); +#else + void *handle; + handle = dlopen("mylib.so", RTLD_LAZY); +#endif + + if (!handle) { + std::cerr << "Unable to load library" << std::endl; + return 1; + } + + // get initialize function address from the library + initialize_t init_lib; +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + init_lib = (initialize_t) GetProcAddress(handle, MXLIB_INITIALIZE_STR); +#else + init_lib = (initialize_t) dlsym(handle, MXLIB_INITIALIZE_STR); +#endif + + if (!init_lib) { + std::cerr << "Unable to get function 'intialize' from library" << std::endl; + return 1; + } + + // Call the function. + (init_lib)(MXNET_VERSION); + + // Deallocate memory. +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + FreeLibrary(handle); +#else + dlclose(handle); +#endif + + return 0; +} diff --git a/example/lib_api/mylib.cc b/example/lib_api/mylib.cc new file mode 100644 index 000000000000..e67560a87f3d --- /dev/null +++ b/example/lib_api/mylib.cc @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file mylib.cc + * \brief Sample library file + */ + +#include +#include "lib_api.h" + +int initialize(int version) { + if (version >= 10400) { + std::cout << "MXNet version " << version << " supported" << std::endl; + return 1; + } else { + std::cout << "MXNet version " << version << " not supported" << std::endl; + return 0; + } +} diff --git a/example/lib_api/test.py b/example/lib_api/test.py new file mode 100644 index 000000000000..d73d85c02ced --- /dev/null +++ b/example/lib_api/test.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=arguments-differ + +# This test checks if dynamic loading of library into MXNet is successful + +import mxnet as mx +import os + +if (os.name=='posix'): + mx.library.load('mylib.so') +elif (os.name=='nt'): + mx.library.load('mylib.dll') diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index ff3e689420ce..95d13fe2125c 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -226,6 +226,13 @@ MXNET_DLL const char *MXGetLastError(); // Part 0: Global State setups //------------------------------------- +/*! + * \brief Load library dynamically + * \param path to the library .so file + * \return 0 when success, -1 when failure happens. + */ +MXNET_DLL int MXLoadLib(const char *path); + /*! * \brief Get list of features supported on the runtime * \param libFeature pointer to array of LibFeature diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h new file mode 100644 index 000000000000..ca3b2952eafa --- /dev/null +++ b/include/mxnet/lib_api.h @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file lib_api.h + * \brief APIs to interact with libraries + */ +#ifndef MXNET_LIB_API_H_ +#define MXNET_LIB_API_H_ + +/*! + * \brief Following are the APIs implemented in the external library + * Each API has a #define string that is used to lookup the function in the library + * Followed by the function declaration + */ +#define MXLIB_INITIALIZE_STR "initialize" +typedef int (*initialize_t)(int); + +extern "C" { + /*! + * \brief Checks if the MXNet version is supported by the library. + * If supported, initializes the library. + * \param version MXNet version number passed to library and defined as: + * MXNET_VERSION = (MXNET_MAJOR*10000 + MXNET_MINOR*100 + MXNET_PATCH) + * \return Non-zero value on error i.e. library incompatible with passed MXNet version + */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) int __cdecl initialize(int); +#else + int initialize(int); +#endif +} +#endif // MXNET_LIB_API_H_ diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index ab4bffde28a9..233bb2a1f57e 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -26,6 +26,7 @@ from .base import MXNetError from .util import is_np_shape, set_np_shape, np_shape, use_np_shape from . import base +from . import library from . import contrib from . import ndarray from . import ndarray as nd diff --git a/python/mxnet/base.py b/python/mxnet/base.py index bf8026359d02..17819bde28b2 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -86,7 +86,7 @@ def __repr__(self): class MXNetError(Exception): - """Error that will be throwed by all mxnet functions.""" + """Error that will be thrown by all mxnet functions.""" pass diff --git a/python/mxnet/library.py b/python/mxnet/library.py new file mode 100644 index 000000000000..9ebf2c2bc580 --- /dev/null +++ b/python/mxnet/library.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +"""Library management API of mxnet.""" +from __future__ import absolute_import +import ctypes +import os +from .base import _LIB, check_call, MXNetError + +def load(path): + """Loads library dynamically. + + Parameters + --------- + path : Path to library .so/.dll file + + Returns + --------- + void + """ + #check if path exists + if not os.path.exists(path): + raise MXNetError("load path %s does NOT exist" % path) + #check if path is an absolute path + if not os.path.isabs(path): + raise MXNetError("load path %s is not an absolute path" % path) + #check if path is to a library file + _, file_ext = os.path.splitext(path) + if not file_ext in ['.so', '.dll']: + raise MXNetError("load path %s is NOT a library file" % path) + + byt_obj = path.encode('utf-8') + chararr = ctypes.c_char_p(byt_obj) + check_call(_LIB.MXLoadLib(chararr)) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index a7a80a5ab40c..ffe6d8dcdbdc 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -45,11 +45,13 @@ #include "mxnet/storage.h" #include "mxnet/libinfo.h" #include "mxnet/imperative.h" +#include "mxnet/lib_api.h" #include "./c_api_common.h" #include "../operator/custom/custom-inl.h" #include "../operator/tensor/matrix_op-inl.h" #include "../operator/tvmop/op_module.h" #include "../common/utils.h" +#include "../common/library.h" using namespace mxnet; @@ -90,6 +92,19 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e, // NOTE: return value is added in API_END +// Loads library and initializes it +int MXLoadLib(const char *path) { + API_BEGIN(); + void *lib = load_lib(path); + if (!lib) + LOG(FATAL) << "Unable to load library"; + + initialize_t initialize = get_func(lib, const_cast(MXLIB_INITIALIZE_STR)); + if (!initialize(static_cast(MXNET_VERSION))) + LOG(FATAL) << "Library failed to initialize"; + API_END(); +} + int MXLibInfoFeatures(const struct LibFeature **lib_features, size_t *size) { using namespace features; API_BEGIN(); diff --git a/src/common/library.cc b/src/common/library.cc new file mode 100644 index 000000000000..9e79b5dbe1bc --- /dev/null +++ b/src/common/library.cc @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file library.cc + * \brief Dynamically loading accelerator library + * and accessing its functions + */ + +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) +#include +#else +#include +#endif + +#include +#include "library.h" + +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) +/*! + * \brief Retrieve the system error message for the last-error code + * \param err string that gets the error message + */ +void win_err(char **err) { + uint32_t dw = GetLastError(); + FormatMessage( + FORMAT_MESSAGE_ALLOCATE_BUFFER | + FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, + dw, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + reinterpret_cast(err), + 0, NULL); +} +#endif + + +/*! + * \brief Loads the dynamic shared library file + * \param path library file location + * \return handle a pointer for the loaded library, nullptr if loading unsuccessful + */ +void* load_lib(const char* path) { + void *handle = nullptr; + std::string path_str(path); + // check if library was already loaded + if (loaded_libs.find(path_str) == loaded_libs.end()) { + // if not, load it +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + handle = LoadLibrary(path); + if (!handle) { + char *err_msg = nullptr; + win_err(&err_msg); + LOG(FATAL) << "Error loading library: '" << path << "'\n" << err_msg; + LocalFree(err_msg); + return nullptr; + } +#else + handle = dlopen(path, RTLD_LAZY); + if (!handle) { + LOG(FATAL) << "Error loading library: '" << path << "'\n" << dlerror(); + return nullptr; + } +#endif // _WIN32 or _WIN64 or __WINDOWS__ + // then store the pointer to the library + loaded_libs[path_str] = handle; + } else { + // otherwise just look up the pointer + handle = loaded_libs[path_str]; + } + return handle; +} + +/*! + * \brief Closes the loaded dynamic shared library file + * \param handle library file handle + */ +void close_lib(void* handle) { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + FreeLibrary((HMODULE)handle); +#else + dlclose(handle); +#endif // _WIN32 or _WIN64 or __WINDOWS__ +} + +/*! + * \brief Obtains address of given function in the loaded library + * \param handle pointer for the loaded library + * \param func function pointer that gets output address + * \param name function name to be fetched + */ +void get_sym(void* handle, void** func, char* name) { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + *func = GetProcAddress((HMODULE)handle, name); + if (!(*func)) { + char *err_msg = nullptr; + win_err(&err_msg); + LOG(FATAL) << "Error getting function '" << name << "' from library\n" << err_msg; + LocalFree(err_msg); + } +#else + *func = dlsym(handle, name); + if (!(*func)) { + LOG(FATAL) << "Error getting function '" << name << "' from library\n" << dlerror(); + } +#endif // _WIN32 or _WIN64 or __WINDOWS__ +} diff --git a/src/common/library.h b/src/common/library.h new file mode 100644 index 000000000000..d6eff4184191 --- /dev/null +++ b/src/common/library.h @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file library.h + * \brief Defining library loading functions + */ +#ifndef MXNET_COMMON_LIBRARY_H_ +#define MXNET_COMMON_LIBRARY_H_ + +#include +#include +#include +#include "dmlc/io.h" + +// map of libraries loaded +static std::map loaded_libs; + +void* load_lib(const char* path); +void close_lib(void* handle); +void get_sym(void* handle, void** func, char* name); + +/*! + * \brief a templated function that fetches from the library + * a function pointer of any given datatype and name + * \param T a template parameter for data type of function pointer + * \param lib library handle + * \param func_name function name to search for in the library + * \return func a function pointer + */ +template +T get_func(void *lib, char *func_name) { + T func; + get_sym(lib, reinterpret_cast(&func), func_name); + if (!func) + LOG(FATAL) << "Unable to get function '" << func_name << "' from library"; + return func; +} + +#endif // MXNET_COMMON_LIBRARY_H_ diff --git a/src/initialize.cc b/src/initialize.cc index 7236ced52e93..04952be7072a 100644 --- a/src/initialize.cc +++ b/src/initialize.cc @@ -27,6 +27,7 @@ #include #include "./engine/openmp.h" #include "./operator/custom/custom-inl.h" +#include "./common/library.h" #if MXNET_USE_OPENCV #include #endif // MXNET_USE_OPENCV @@ -80,6 +81,13 @@ class LibraryInitializer { #endif } + ~LibraryInitializer() { + // close opened libraries + for (auto const& lib : loaded_libs) { + close_lib(lib.second); + } + } + static LibraryInitializer* Get(); }; diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index dcd6acce8ca5..e4ec98f9f1bd 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -45,6 +45,7 @@ from test_subgraph_op import * from test_contrib_operator import test_multibox_target_op from test_tvm_op import * +from test_library_loading import * set_default_context(mx.gpu(0)) del test_support_vector_machine_l1_svm # noqa diff --git a/tests/python/unittest/test_library_loading.py b/tests/python/unittest/test_library_loading.py new file mode 100644 index 000000000000..596d124c89d6 --- /dev/null +++ b/tests/python/unittest/test_library_loading.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This test checks if dynamic loading of library into MXNet is successful + +import os +import platform +import unittest +import mxnet as mx +from mxnet.base import MXNetError +from mxnet.test_utils import download + +def check_platform(): + return platform.machine() not in ['x86_64', 'AMD64'] + +@unittest.skipIf(check_platform(), "not all machine types supported") +def test_library_loading(): + if (os.name=='posix'): + lib = 'libsample_lib.so' + if os.path.exists(lib): + fname = lib + elif os.path.exists('build/'+lib): + fname = 'build/'+lib + else: + raise MXNetError("library %s not found " % lib) + elif (os.name=='nt'): + lib = 'libsample_lib.dll' + if os.path.exists('windows_package\\lib\\'+lib): + fname = 'windows_package\\lib\\'+lib + else: + raise MXNetError("library %s not found " % lib) + + fname = os.path.abspath(fname) + mx.library.load(fname)