Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
adding basic functionality to load accelerator library
Browse files Browse the repository at this point in the history
  • Loading branch information
mseth10 committed Jul 12, 2019
1 parent 94fda6e commit 671df6f
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ else
CFLAGS += -O3 -DNDEBUG=1
endif
CFLAGS += -I$(TPARTYDIR)/mshadow/ -I$(TPARTYDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(TPARTYDIR)/tvm/include -Iinclude $(MSHADOW_CFLAGS)
LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)
LDFLAGS = -pthread -ldl $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)

ifeq ($(ENABLE_TESTCOVERAGE), 1)
CFLAGS += --coverage
Expand Down
72 changes: 65 additions & 7 deletions include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
#include "libinfo.h"
#include "tuple.h"

#include "library.h"
#include "mxnet_acc.h"
#include <map>

/*!
* \brief define compatible keywords in g++
Expand Down Expand Up @@ -98,15 +101,30 @@ typedef mshadow::default_real_t real_t;
/*! \brief operator structure from NNVM */
using Op = nnvm::Op;

struct AccContext {
int kAccType;
std::string accName;

/*! \brief default constructor */
AccContext() : kAccType(0), accName("Error") {}
/*! \brief constructor */
AccContext(int type, std::string name) : kAccType(type), accName(name) {}
};

/*! \brief Context information about the execution environment */
struct Context {
/*! \brief Type of device */
enum DeviceType {
kCPU = cpu::kDevMask,
kGPU = gpu::kDevMask,
kCPUPinned = 3,
kCPUShared = 5,
};
static const int kCPU = 1;
static const int kGPU = 2;
static const int kCPUPinned = 3;
static const int kCPUShared = 5;
static const int kAccBase = 10;

typedef int DeviceType;

static std::map<int, AccContext> acc_map;
static std::map<std::string, int> acc_names;

/*! \brief the device type we run the op on */
DeviceType dev_type;
/*! \brief device id we are going to run it on */
Expand All @@ -126,6 +144,7 @@ struct Context {
*/
inline int real_dev_id() const {
if (dev_type == kCPUPinned || dev_type == kGPU) return dev_id;
else if (dev_type >= kAccBase) return dev_id;
return 0;
}
/*!
Expand Down Expand Up @@ -169,7 +188,7 @@ struct Context {
return true;
}
/*! \brief the maximal device type */
static const int32_t kMaxDevType = 6;
static const int32_t kMaxDevType = 20;
/*! \brief the maximal device index */
static const int32_t kMaxDevID = 16;
/*!
Expand Down Expand Up @@ -222,6 +241,12 @@ struct Context {
* \return Context
*/
inline static Context FromString(const std::string& str);
/*!
* Load accelerator from given path and get its name
* \param path of .so file and name of the accelerator
* \return No return value
*/
inline static int LoadAcc(const std::string& path, char *name);
};

#if MXNET_USE_CUDA
Expand Down Expand Up @@ -487,6 +512,9 @@ inline Context Context::FromString(const std::string& str) {
ret = CPUPinned(id);
} else if (type == "cpu_shared") {
ret = CPUShared(id);
} else if (Context::acc_names.find(type) != Context::acc_names.end()) {
DeviceType dev_type = Context::acc_names[type];
ret = Create(dev_type, id);
} else {
LOG(FATAL) << "Invalid context string " << str;
}
Expand All @@ -496,6 +524,34 @@ inline Context Context::FromString(const std::string& str) {
return ret;
}

inline int Context::LoadAcc(const std::string& path, char *name) {
// load library
void *lib = load_lib(path.c_str());
if (!lib)
LOG(FATAL) << "Unable to load library";

// get name function from library
void (*getAccName)(char*);
get_func(lib, (void**)(&getAccName), const_cast<char*>("getAccName"));
if (!getAccName)
LOG(FATAL) << "Unable to get accelerator name from library";

// call name function
char accname[100];
getAccName(accname);
std::string name_str(accname);
snprintf(name, 100, "%s", name_str.c_str());

// create entry for accelerator
int id = Context::kAccBase + Context::acc_map.size();
AccContext ctx(id, name_str);

// add accelerator context to map
Context::acc_map[id] = ctx;
Context::acc_names[name_str] = id;
return id;
}

inline std::ostream& operator<<(std::ostream &out, const Context &ctx) {
if (ctx.dev_type == Context::kCPU) {
out << "cpu(";
Expand All @@ -505,6 +561,8 @@ inline std::ostream& operator<<(std::ostream &out, const Context &ctx) {
out << "cpu_pinned(";
} else if (ctx.dev_type == Context::kCPUShared) {
out << "cpu_shared(";
} else if (Context::acc_map.find(ctx.dev_type) != Context::acc_map.end()) {
out << Context::acc_map[ctx.dev_type].accName << "(";
} else {
out << "unknown(";
}
Expand Down
34 changes: 34 additions & 0 deletions include/mxnet/library.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2015 by Contributors
* \file library.h
* \brief Defining accelerator loading functions
*/
#ifndef MXNET_LIBRARY_H_
#define MXNET_LIBRARY_H_

#include <dlfcn.h>
#include <iostream>

void* load_lib(const char* path);
void get_func(void* handle, void** func, char* name);

#endif // MXNET_LIBRARY_H_
28 changes: 28 additions & 0 deletions python/mxnet/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import threading
import warnings
import ctypes
import os
from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass
from .base import _LIB
from .base import check_call
Expand Down Expand Up @@ -71,6 +72,7 @@ class Context(with_metaclass(_MXClassPropertyMetaClass, object)):
_default_ctx = threading.local()
devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned', 5: 'cpu_shared'}
devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3, 'cpu_shared': 5}
acc_map = {}
def __init__(self, device_type, device_id=0):
if isinstance(device_type, Context):
self.device_typeid = device_type.device_typeid
Expand Down Expand Up @@ -167,6 +169,32 @@ def empty_cache(self):
Context._default_ctx.value = Context('cpu', 0)


def load_acc(path):
#check if path exists
if not os.path.exists(path):
print('load_acc path "%s" does NOT exist' % path)
return None
#check if path is to a file
if not os.path.isfile(path):
print('load_acc path "%s" is NOT a library file' % path)
return None

byt_obj = path.encode('utf-8')
chararr = ctypes.c_char_p(byt_obj)
dev_id = ctypes.c_int()
name = ctypes.create_string_buffer(100)

check_call(_LIB.MXLoadAccLib(chararr,ctypes.byref(dev_id),name))

dev_id = dev_id.value
name = name.value
Context.devtype2str[dev_id] = name
Context.devstr2type[name] = dev_id
Context.acc_map[dev_id] = (name,path)

return Context(name, 0)


def cpu(device_id=0):
"""Returns a CPU context.
Expand Down
9 changes: 8 additions & 1 deletion src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,16 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e,
*arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (e->arguments.size() * 2);
API_END();
}

// NOTE: return value is added in API_END

int MXLoadAccLib(const char *path, int *id, char *name) {
API_BEGIN();
std::string tmp(path);
int dev_id = mxnet::Context::LoadAcc(tmp, name);
*id = dev_id;
API_END();
}

int MXLibInfoFeatures(const struct LibFeature **lib_features, size_t *size) {
using namespace features;
API_BEGIN();
Expand Down
31 changes: 31 additions & 0 deletions src/common/base.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2015 by Contributors
* \file base.cc
* \brief Defining map variables for the header file
*/
#include <string>
#include <map>

#include "../../include/mxnet/base.h"

std::map<int, mxnet::AccContext> mxnet::Context::acc_map;
std::map<std::string, int> mxnet::Context::acc_names;
47 changes: 47 additions & 0 deletions src/common/library.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2015 by Contributors
* \file library.cc
* \brief Dynamically loading accelerator library
* and accessing its functions
*/

#include "../../include/mxnet/library.h"

void* load_lib(const char* path) {
void *handle;
handle = dlopen(path, RTLD_LAZY);

if (!handle) {
std::cerr << "Error loading accelerator library: '" << path
<< "'\n" << dlerror() << std::endl;
return 0;
}
return handle;
}

void get_func(void* handle, void** func, char* name) {
*reinterpret_cast<void**>(func) = dlsym(handle, name);
if (!func) {
std::cerr << "Error getting function '" << name
<< "' from accelerator library\n" << dlerror() << std::endl;
}
}

0 comments on commit 671df6f

Please sign in to comment.