Skip to content

Commit

Permalink
Addresses comments in runtime feature discovery API (apache#13964)
Browse files Browse the repository at this point in the history
* Prototype for runtime feature detection

* Includes from diamond to quotes

* Add CPU feature and BLAS flavour flags

* Add BLAS flavour and CPU SSE and AVX flags

* MXNET_USE_LAPACK

* Fix C++ linting errors

* Expose runtime feature detection in the public C API and in the Python API

* Refactor Storage -> FeatureSet

* Refine documentation

* Add failure case

* Fix pylint

* Address CR comments

* Address CR comments

* Address CR

* Address CR

* Address CR

* Address CR

* remove old files

* Fix unit test

* Port CMake blas change from apache#13957

* Fix lint

* mxruntime -> libinfo

* Fix comments

* restore libinfo.py

* Rework API for feature detection / libinfo

* Refine documentation

* Fix lint

* Fix lint

* Define make_unique only for C++ std < 14

* Add memory include

* remove old tests

* make_unique fiasco

* Fix lint
  • Loading branch information
larroy authored and vdantu committed Mar 31, 2019
1 parent d1305c4 commit 77cb1c2
Show file tree
Hide file tree
Showing 12 changed files with 203 additions and 174 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,8 @@ else()
add_definitions(-DMXNET_USE_NCCL=0)
endif()

include(cmake/ChooseBlas.cmake)
if(USE_CUDA AND FIRST_CUDA)
include(cmake/ChooseBlas.cmake)
include(3rdparty/mshadow/cmake/Utils.cmake)
include(cmake/FirstClassLangCuda.cmake)
include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
Expand Down
9 changes: 8 additions & 1 deletion include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
#include "nnvm/op.h"
#include "nnvm/tuple.h"
#include "nnvm/symbolic.h"
#include "mxfeatures.h"
#include "libinfo.h"


/*!
Expand Down Expand Up @@ -403,7 +403,14 @@ template<> struct hash<mxnet::Context> {
return res;
}
};

#if __cplusplus < 201402L && !defined(_MSC_VER)
template<typename T, typename... Args>
inline std::unique_ptr<T> make_unique(Args&&... args) {
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}
#endif
} // namespace std

#include "./tensor_blob.h"
//! \endcond
Expand Down
14 changes: 10 additions & 4 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ struct MXCallbackList {
void **contexts;
};

struct LibFeature {
const char* name;
uint32_t index;
bool enabled;
};

enum CustomOpCallbacks {
kCustomOpDelete,
kCustomOpForward,
Expand Down Expand Up @@ -210,12 +216,12 @@ MXNET_DLL const char *MXGetLastError();
//-------------------------------------

/*!
* \brief
* \param feature to check mxfeatures.h
* \param out set to true if the feature is enabled, false otherwise
* \brief Get list of features supported on the runtime
* \param libFeature pointer to array of LibFeature
* \param size of the array
* \return 0 when success, -1 when failure happens.
*/
MXNET_DLL int MXHasFeature(const mx_uint feature, bool* out);
MXNET_DLL int MXLibInfoFeatures(const struct LibFeature **libFeature, size_t *size);

/*!
* \brief Seed all global random number generators in mxnet.
Expand Down
34 changes: 28 additions & 6 deletions include/mxnet/mxfeatures.h → include/mxnet/libinfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,27 @@
*/

/*!
* Copyright (c) 2018 by Contributors
* \file mxfeatures.h
* \brief check MXNet features including compile time support
* Copyright (c) 2018 by Contributors
* \file libinfo.h
* \author larroy
* \brief get features of the MXNet library at runtime
*/

#pragma once

#include <string>
#include <vector>
#include <array>
#include <memory>
#include "dmlc/base.h"
#include "mshadow/base.h"
#include "c_api.h"

/*!
*\brief whether to use opencv support
*/
#ifndef MXNET_USE_OPENCV
#define MXNET_USE_OPENCV 1
#define MXNET_USE_OPENCV 0
#endif

/*!
Expand Down Expand Up @@ -124,7 +130,8 @@ namespace features {
// Check compile flags such as CMakeLists.txt

/// Compile time features
enum : uint32_t {
// ATTENTION: When changing this enum, match the strings in the implementation file!
enum : unsigned {
// NVIDIA, CUDA
CUDA = 0,
CUDNN,
Expand Down Expand Up @@ -179,10 +186,25 @@ enum : uint32_t {
};


struct EnumNames {
static const std::vector<std::string> names;
};

struct LibInfo {
LibInfo();
static LibInfo* getInstance();
const std::array<LibFeature, MAX_FEATURES>& getFeatures() {
return m_lib_features;
}
private:
std::array<LibFeature, MAX_FEATURES> m_lib_features;
static std::unique_ptr<LibInfo> m_inst;
};

/*!
* \return true if the given feature is supported
*/
bool is_enabled(uint32_t feat);
bool is_enabled(unsigned feat);

} // namespace features
} // namespace mxnet
103 changes: 0 additions & 103 deletions python/mxnet/mxfeatures.py

This file was deleted.

48 changes: 48 additions & 0 deletions python/mxnet/runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=not-an-iterable

"""runtime querying of compile time features in the native library"""

import ctypes
from .base import _LIB, check_call

class LibFeature(ctypes.Structure):
"""
Compile time feature description
"""
_fields_ = [
("name", ctypes.c_char_p),
("index", ctypes.c_uint32),
("enabled", ctypes.c_bool)
]

def libinfo_features():
"""
Check the library for compile-time features. The list of features are maintained in libinfo.h and libinfo.cc
Returns
-------
A list of class LibFeature indicating which features are available and enabled
"""
lib_features = ctypes.POINTER(LibFeature)()
lib_features_size = ctypes.c_size_t()
check_call(_LIB.MXLibInfoFeatures(ctypes.byref(lib_features), ctypes.byref(lib_features_size)))
feature_list = [lib_features[i] for i in range(lib_features_size.value)]
return feature_list
9 changes: 6 additions & 3 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
#include "mxnet/kvstore.h"
#include "mxnet/rtc.h"
#include "mxnet/storage.h"
#include "mxnet/mxfeatures.h"
#include "mxnet/libinfo.h"
#include "./c_api_common.h"
#include "../operator/custom/custom-inl.h"
#include "../operator/tensor/matrix_op-inl.h"
Expand Down Expand Up @@ -87,9 +87,12 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e,

// NOTE: return value is added in API_END

int MXHasFeature(const mx_uint feature, bool* out) {
int MXLibInfoFeatures(const struct LibFeature **lib_features, size_t *size) {
using namespace features;
API_BEGIN();
*out = features::is_enabled(feature);
LibInfo* lib_info = LibInfo::getInstance();
*lib_features = lib_info->getFeatures().data();
*size = lib_info->getFeatures().size();
API_END();
}

Expand Down
10 changes: 3 additions & 7 deletions src/c_api/c_api_profile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,6 @@ struct APICallTimingData {
#endif // PROFILE_API_INCLUDE_AS_EVENT
};

template<typename T, typename... Args>
inline std::unique_ptr<T> make_unique(Args&&... args) {
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}

/*!
* \brief Per-thread profiling data
*/
Expand All @@ -78,7 +73,7 @@ class ProfilingThreadData {
auto iter = tasks_.find(name);
if (iter == tasks_.end()) {
iter = tasks_.emplace(std::make_pair(
name, make_unique<profiler::ProfileTask>(name, domain))).first;
name, std::make_unique<profiler::ProfileTask>(name, domain))).first;
}
return iter->second.get();
}
Expand All @@ -93,7 +88,8 @@ class ProfilingThreadData {
// Per-thread so no lock necessary
auto iter = events_.find(name);
if (iter == events_.end()) {
iter = events_.emplace(std::make_pair(name, make_unique<profiler::ProfileEvent>(name))).first;
iter = events_.emplace(std::make_pair(name,
std::make_unique<profiler::ProfileEvent>(name))).first;
}
return iter->second.get();
}
Expand Down
Loading

0 comments on commit 77cb1c2

Please sign in to comment.