diff --git a/Makefile b/Makefile index 02a74a067031..1c0649f29e9d 100644 --- a/Makefile +++ b/Makefile @@ -99,7 +99,7 @@ else CFLAGS += -O3 -DNDEBUG=1 endif CFLAGS += -I$(TPARTYDIR)/mshadow/ -I$(TPARTYDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(TPARTYDIR)/tvm/include -Iinclude $(MSHADOW_CFLAGS) -LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS) +LDFLAGS = -pthread -ldl $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS) ifeq ($(ENABLE_TESTCOVERAGE), 1) CFLAGS += --coverage diff --git a/include/mxnet/base.h b/include/mxnet/base.h index b239cb1f7302..947c1fecdca2 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -37,6 +37,9 @@ #include "libinfo.h" #include "tuple.h" +#include "library.h" +#include "mxnet_acc.h" +#include /*! * \brief define compatible keywords in g++ @@ -98,15 +101,30 @@ typedef mshadow::default_real_t real_t; /*! \brief operator structure from NNVM */ using Op = nnvm::Op; +struct AccContext { + int kAccType; + std::string accName; + + /*! \brief default constructor */ + AccContext() : kAccType(0), accName("Error") {} + /*! \brief constructor */ + AccContext(int type, std::string name) : kAccType(type), accName(name) {} +}; + /*! \brief Context information about the execution environment */ struct Context { /*! \brief Type of device */ - enum DeviceType { - kCPU = cpu::kDevMask, - kGPU = gpu::kDevMask, - kCPUPinned = 3, - kCPUShared = 5, - }; + static const int kCPU = 1; + static const int kGPU = 2; + static const int kCPUPinned = 3; + static const int kCPUShared = 5; + static const int kAccBase = 10; + + typedef int DeviceType; + + static std::map acc_map; + static std::map acc_names; + /*! \brief the device type we run the op on */ DeviceType dev_type; /*! \brief device id we are going to run it on */ @@ -126,6 +144,7 @@ struct Context { */ inline int real_dev_id() const { if (dev_type == kCPUPinned || dev_type == kGPU) return dev_id; + else if (dev_type >= kAccBase) return dev_id; return 0; } /*! @@ -169,7 +188,7 @@ struct Context { return true; } /*! \brief the maximal device type */ - static const int32_t kMaxDevType = 6; + static const int32_t kMaxDevType = 20; /*! \brief the maximal device index */ static const int32_t kMaxDevID = 16; /*! @@ -222,6 +241,12 @@ struct Context { * \return Context */ inline static Context FromString(const std::string& str); + /*! + * Load accelerator from given path and get its name + * \param path of .so file and name of the accelerator + * \return No return value + */ + inline static int LoadAcc(const std::string& path, char *name); }; #if MXNET_USE_CUDA @@ -487,6 +512,9 @@ inline Context Context::FromString(const std::string& str) { ret = CPUPinned(id); } else if (type == "cpu_shared") { ret = CPUShared(id); + } else if (Context::acc_names.find(type) != Context::acc_names.end()) { + DeviceType dev_type = Context::acc_names[type]; + ret = Create(dev_type, id); } else { LOG(FATAL) << "Invalid context string " << str; } @@ -496,6 +524,34 @@ inline Context Context::FromString(const std::string& str) { return ret; } +inline int Context::LoadAcc(const std::string& path, char *name) { + // load library + void *lib = load_lib(path.c_str()); + if (!lib) + LOG(FATAL) << "Unable to load library"; + + // get name function from library + void (*getAccName)(char*); + get_func(lib, (void**)(&getAccName), const_cast("getAccName")); + if (!getAccName) + LOG(FATAL) << "Unable to get accelerator name from library"; + + // call name function + char accname[100]; + getAccName(accname); + std::string name_str(accname); + snprintf(name, 100, "%s", name_str.c_str()); + + // create entry for accelerator + int id = Context::kAccBase + Context::acc_map.size(); + AccContext ctx(id, name_str); + + // add accelerator context to map + Context::acc_map[id] = ctx; + Context::acc_names[name_str] = id; + return id; +} + inline std::ostream& operator<<(std::ostream &out, const Context &ctx) { if (ctx.dev_type == Context::kCPU) { out << "cpu("; @@ -505,6 +561,8 @@ inline std::ostream& operator<<(std::ostream &out, const Context &ctx) { out << "cpu_pinned("; } else if (ctx.dev_type == Context::kCPUShared) { out << "cpu_shared("; + } else if (Context::acc_map.find(ctx.dev_type) != Context::acc_map.end()) { + out << Context::acc_map[ctx.dev_type].accName << "("; } else { out << "unknown("; } diff --git a/include/mxnet/library.h b/include/mxnet/library.h new file mode 100644 index 000000000000..a0cf94acf7b8 --- /dev/null +++ b/include/mxnet/library.h @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file library.h + * \brief Defining accelerator loading functions + */ +#ifndef MXNET_LIBRARY_H_ +#define MXNET_LIBRARY_H_ + +#include +#include + +void* load_lib(const char* path); +void get_func(void* handle, void** func, char* name); + +#endif // MXNET_LIBRARY_H_ diff --git a/python/mxnet/context.py b/python/mxnet/context.py index f284e00127b4..eac78fdea99d 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -21,6 +21,7 @@ import threading import warnings import ctypes +import os from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass from .base import _LIB from .base import check_call @@ -71,6 +72,7 @@ class Context(with_metaclass(_MXClassPropertyMetaClass, object)): _default_ctx = threading.local() devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned', 5: 'cpu_shared'} devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3, 'cpu_shared': 5} + acc_map = {} def __init__(self, device_type, device_id=0): if isinstance(device_type, Context): self.device_typeid = device_type.device_typeid @@ -167,6 +169,32 @@ def empty_cache(self): Context._default_ctx.value = Context('cpu', 0) +def load_acc(path): + #check if path exists + if not os.path.exists(path): + print('load_acc path "%s" does NOT exist' % path) + return None + #check if path is to a file + if not os.path.isfile(path): + print('load_acc path "%s" is NOT a library file' % path) + return None + + byt_obj = path.encode('utf-8') + chararr = ctypes.c_char_p(byt_obj) + dev_id = ctypes.c_int() + name = ctypes.create_string_buffer(100) + + check_call(_LIB.MXLoadAccLib(chararr,ctypes.byref(dev_id),name)) + + dev_id = dev_id.value + name = name.value + Context.devtype2str[dev_id] = name + Context.devstr2type[name] = dev_id + Context.acc_map[dev_id] = (name,path) + + return Context(name, 0) + + def cpu(device_id=0): """Returns a CPU context. diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 35bd3eeb477a..1921d5f8fcf3 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -86,9 +86,16 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e, *arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (e->arguments.size() * 2); API_END(); } - // NOTE: return value is added in API_END +int MXLoadAccLib(const char *path, int *id, char *name) { + API_BEGIN(); + std::string tmp(path); + int dev_id = mxnet::Context::LoadAcc(tmp, name); + *id = dev_id; + API_END(); +} + int MXLibInfoFeatures(const struct LibFeature **lib_features, size_t *size) { using namespace features; API_BEGIN(); diff --git a/src/common/base.cc b/src/common/base.cc new file mode 100644 index 000000000000..bcdce1cc71ab --- /dev/null +++ b/src/common/base.cc @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file base.cc + * \brief Defining map variables for the header file + */ +#include +#include + +#include "../../include/mxnet/base.h" + +std::map mxnet::Context::acc_map; +std::map mxnet::Context::acc_names; diff --git a/src/common/library.cc b/src/common/library.cc new file mode 100644 index 000000000000..a3fe47293059 --- /dev/null +++ b/src/common/library.cc @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file library.cc + * \brief Dynamically loading accelerator library + * and accessing its functions + */ + +#include "../../include/mxnet/library.h" + +void* load_lib(const char* path) { + void *handle; + handle = dlopen(path, RTLD_LAZY); + + if (!handle) { + std::cerr << "Error loading accelerator library: '" << path + << "'\n" << dlerror() << std::endl; + return 0; + } + return handle; +} + +void get_func(void* handle, void** func, char* name) { + *reinterpret_cast(func) = dlsym(handle, name); + if (!func) { + std::cerr << "Error getting function '" << name + << "' from accelerator library\n" << dlerror() << std::endl; + } +}