Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Expose runtime feature detection in the public C API and in the Pytho…
Browse files Browse the repository at this point in the history
…n API
  • Loading branch information
larroy committed Jan 9, 2019
1 parent 82738cd commit 93529d3
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 7 deletions.
10 changes: 10 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,15 @@ MXNET_DLL const char *MXGetLastError();
//-------------------------------------
// Part 0: Global State setups
//-------------------------------------

/*!
* \brief
* \param feature to check mxfeatures.h
* \param out set to true if the feature is enabled, false otherwise
* \return 0 when success, -1 when failure happens.
*/
MXNET_DLL int MXHasFeature(const mx_uint feature, bool* out);

/*!
* \brief Seed all global random number generators in mxnet.
* \param seed the random number seed.
Expand Down Expand Up @@ -465,6 +474,7 @@ MXNET_DLL int MXGetGPUMemoryInformation64(int dev, uint64_t *free_mem, uint64_t
*/
MXNET_DLL int MXGetVersion(int *out);


//-------------------------------------
// Part 1: NDArray creation and deletion
//-------------------------------------
Expand Down
91 changes: 91 additions & 0 deletions python/mxnet/mxfeatures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=invalid-name, no-member, trailing-comma-tuple, bad-mcs-classmethod-argument

"""runtime detection of compile time features in the native library"""

import ctypes
import enum
from .base import _LIB, check_call, mx_uint

feature_names = [
"CUDA",
"CUDNN",
"NCCL",
"CUDA_RTC",
"TENSORRT",
"CPU_SSE",
"CPU_SSE2",
"CPU_SSE3",
"CPU_SSE4_1",
"CPU_SSE4_2",
"CPU_SSE4A",
"CPU_AVX",
"CPU_AVX2",
"OPENMP",
"SSE",
"F16C",
"JEMALLOC",
"BLAS_OPEN",
"BLAS_ATLAS",
"BLAS_MKL",
"BLAS_APPLE",
"LAPACK",
"MKLDNN",
"OPENCV",
"CAFFE",
"PROFILER",
"DIST_KVSTORE",
"CXX14",
"SIGNAL_HANDLER",
"DEBUG"
]


Feature = enum.Enum('Feature', {name: index for index, name in enumerate(feature_names)})


def has_feature(feature):
"""
Check the library for compile-time feature at runtime
Parameters
----------
feature : int
An integer representing the feature to check
Returns
-------
boolean
True if the feature is enabled, false otherwise
"""
res = ctypes.c_bool()
check_call(_LIB.MXHasFeature(mx_uint(feature), ctypes.byref(res)))
return res.value


def features_enabled():
res = []
for f in Feature:
if has_feature(f.value):
res.append(f)
return res

def features_enabled_str(sep=', '):
return sep.join(map(lambda x: x.name, features_enabled()))
22 changes: 15 additions & 7 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@
* \file c_api.cc
* \brief C API of mxnet
*/
#include <vector>
#include <sstream>
#include <string>
#include <mutex>
#include <memory>
#include <functional>
#include <utility>
#include "dmlc/base.h"
#include "dmlc/logging.h"
#include "dmlc/io.h"
Expand All @@ -36,13 +43,7 @@
#include "mxnet/kvstore.h"
#include "mxnet/rtc.h"
#include "mxnet/storage.h"
#include <vector>
#include <sstream>
#include <string>
#include <mutex>
#include <memory>
#include <functional>
#include <utility>
#include "mxnet/mxfeatures.h"
#include "./c_api_common.h"
#include "../operator/custom/custom-inl.h"
#include "../operator/tensor/matrix_op-inl.h"
Expand Down Expand Up @@ -85,6 +86,13 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e,
}

// NOTE: return value is added in API_END

int MXHasFeature(const mx_uint feature, bool* out) {
API_BEGIN();
*out = features::is_enabled(feature);
API_END();
}

int MXRandomSeed(int seed) {
API_BEGIN();
mxnet::RandomSeed(seed);
Expand Down
34 changes: 34 additions & 0 deletions tests/python/unittest/test_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
from mxnet.mxfeatures import *
from nose.tools import *

def test_runtime_features():
for f in Feature:
res = has_feature(f.value)
ok_(type(res) is bool)
for f in features_enabled():
ok_(type(f) is Feature)
ok_(type(features_enabled_str()) is str)
print("Features enabled: {}".format(features_enabled_str()))


if __name__ == "__main__":
import nose
nose.runmodule()

0 comments on commit 93529d3

Please sign in to comment.