diff --git a/benchmark/opperf/utils/op_registry_utils.py b/benchmark/opperf/utils/op_registry_utils.py index 5da0e322eccc..6509be37f39d 100644 --- a/benchmark/opperf/utils/op_registry_utils.py +++ b/benchmark/opperf/utils/op_registry_utils.py @@ -20,6 +20,7 @@ from operator import itemgetter from mxnet import runtime from mxnet.base import _LIB, check_call, py_str, OpHandle, c_str, mx_uint +import mxnet as mx from benchmark.opperf.rules.default_params import DEFAULTS_INPUTS, MX_OP_MODULE @@ -75,89 +76,19 @@ def _select_ops(operator_names, filters=("_contrib", "_"), merge_op_forward_back return mx_operators -def _get_all_registered_ops(): - """Get all registered MXNet operator names. - - - Returns - ------- - ["operator_name"] - """ - plist = ctypes.POINTER(ctypes.c_char_p)() - size = ctypes.c_uint() - - check_call(_LIB.MXListAllOpNames(ctypes.byref(size), - ctypes.byref(plist))) - - mx_registered_operator_names = [py_str(plist[i]) for i in range(size.value)] - return mx_registered_operator_names - - -def _get_op_handles(op_name): - """Get handle for an operator with given name - op_name. - - Parameters - ---------- - op_name: str - Name of operator to get handle for. - """ - op_handle = OpHandle() - check_call(_LIB.NNGetOpHandle(c_str(op_name), ctypes.byref(op_handle))) - return op_handle - - -def _get_op_arguments(op_handle): - """Given operator name and handle, fetch operator arguments - number of arguments, - argument names, argument types. - - Parameters - ---------- - op_handle: OpHandle - Handle for the operator - - Returns - ------- - (narg, arg_names, arg_types) - """ - real_name = ctypes.c_char_p() - desc = ctypes.c_char_p() - num_args = mx_uint() - arg_names = ctypes.POINTER(ctypes.c_char_p)() - arg_types = ctypes.POINTER(ctypes.c_char_p)() - arg_descs = ctypes.POINTER(ctypes.c_char_p)() - key_var_num_args = ctypes.c_char_p() - ret_type = ctypes.c_char_p() - - check_call(_LIB.MXSymbolGetAtomicSymbolInfo( - op_handle, ctypes.byref(real_name), ctypes.byref(desc), - ctypes.byref(num_args), - ctypes.byref(arg_names), - ctypes.byref(arg_types), - ctypes.byref(arg_descs), - ctypes.byref(key_var_num_args), - ctypes.byref(ret_type))) - - narg = int(num_args.value) - arg_names = [py_str(arg_names[i]) for i in range(narg)] - arg_types = [py_str(arg_types[i]) for i in range(narg)] - - return narg, arg_names, arg_types - - def _set_op_arguments(mx_operators): """Fetch and set operator arguments - nargs, arg_names, arg_types """ for op_name in mx_operators: - op_handle = _get_op_handles(op_name) - narg, arg_names, arg_types = _get_op_arguments(op_handle) - mx_operators[op_name]["params"] = {"narg": narg, - "arg_names": arg_names, - "arg_types": arg_types} + operator_arguments = mx.operator.get_operator_arguments(op_name) + mx_operators[op_name]["params"] = {"narg": operator_arguments.narg, + "arg_names": operator_arguments.names, + "arg_types": operator_arguments.types} def _get_all_mxnet_operators(): # Step 1 - Get all registered op names and filter it - operator_names = _get_all_registered_ops() + operator_names = mx.operator.get_all_registered_operators() mx_operators = _select_ops(operator_names) # Step 2 - Get all parameters for the operators diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py index 33e9b89a032c..8d14a805ef56 100644 --- a/python/mxnet/operator.py +++ b/python/mxnet/operator.py @@ -22,13 +22,15 @@ import traceback import warnings +import collections from array import array from threading import Lock +import ctypes from ctypes import CFUNCTYPE, POINTER, Structure, pointer from ctypes import c_void_p, c_int, c_char, c_char_p, cast, c_bool -from .base import _LIB, check_call, MXCallbackList, c_array, c_array_buf, mx_int +from .base import _LIB, check_call, MXCallbackList, c_array, c_array_buf, mx_int, OpHandle from .base import c_str, mx_uint, mx_float, ctypes2numpy_shared, NDArrayHandle, py_str from . import symbol, context from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP @@ -1099,3 +1101,60 @@ def delete_entry(_): return do_register register("custom_op")(CustomOpProp) + + +def get_all_registered_operators(): + """Get all registered MXNet operator names. + + Returns + ------- + operator_names : list of string + """ + plist = ctypes.POINTER(ctypes.c_char_p)() + size = ctypes.c_uint() + + check_call(_LIB.MXListAllOpNames(ctypes.byref(size), + ctypes.byref(plist))) + + mx_registered_operator_names = [py_str(plist[i]) for i in range(size.value)] + return mx_registered_operator_names + +OperatorArguments = collections.namedtuple('OperatorArguments', ['narg', 'names', 'types']) + +def get_operator_arguments(op_name): + """Given operator name, fetch operator arguments - number of arguments, + argument names, argument types. + + Parameters + ---------- + op_name: str + Handle for the operator + + Returns + ------- + operator_arguments : OperatorArguments, namedtuple with number of arguments, names and types + """ + op_handle = OpHandle() + check_call(_LIB.NNGetOpHandle(c_str(op_name), ctypes.byref(op_handle))) + real_name = ctypes.c_char_p() + desc = ctypes.c_char_p() + num_args = mx_uint() + arg_names = ctypes.POINTER(ctypes.c_char_p)() + arg_types = ctypes.POINTER(ctypes.c_char_p)() + arg_descs = ctypes.POINTER(ctypes.c_char_p)() + key_var_num_args = ctypes.c_char_p() + ret_type = ctypes.c_char_p() + + check_call(_LIB.MXSymbolGetAtomicSymbolInfo( + op_handle, ctypes.byref(real_name), ctypes.byref(desc), + ctypes.byref(num_args), + ctypes.byref(arg_names), + ctypes.byref(arg_types), + ctypes.byref(arg_descs), + ctypes.byref(key_var_num_args), + ctypes.byref(ret_type))) + + narg = int(num_args.value) + arg_names = [py_str(arg_names[i]) for i in range(narg)] + arg_types = [py_str(arg_types[i]) for i in range(narg)] + return OperatorArguments(narg, arg_names, arg_types) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index d50304712196..ca9ecc45621b 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -27,10 +27,11 @@ from distutils.version import LooseVersion from numpy.testing import assert_allclose, assert_array_equal from mxnet.test_utils import * +from mxnet.operator import * from mxnet.base import py_str, MXNetError, _as_list from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied, assertRaises from common import run_in_spawned_process -from nose.tools import assert_raises +from nose.tools import assert_raises, ok_ import unittest import os @@ -8655,6 +8656,22 @@ def test_add_n(): assert_almost_equal(rslt.asnumpy(), add_n_rslt.asnumpy(), atol=1e-5) +def test_get_all_registered_operators(): + ops = get_all_registered_operators() + ok_(isinstance(ops, list)) + ok_(len(ops) > 0) + ok_('Activation' in ops) + + +def test_get_operator_arguments(): + operator_arguments = get_operator_arguments('Activation') + ok_(isinstance(operator_arguments, OperatorArguments)) + ok_(operator_arguments.names == ['data', 'act_type']) + ok_(operator_arguments.types + == ['NDArray-or-Symbol', "{'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required"]) + ok_(operator_arguments.narg == 2) + + if __name__ == '__main__': import nose nose.runmodule()