Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Expose get_all_registered_operators and get_operator_arguments in the…
Browse files Browse the repository at this point in the history
… Python API.
  • Loading branch information
larroy committed Jun 25, 2019
1 parent 7fe478a commit 6d528f7
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 77 deletions.
81 changes: 6 additions & 75 deletions benchmark/opperf/utils/op_registry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ctypes
import sys
from mxnet.base import _LIB, check_call, py_str, OpHandle, c_str, mx_uint
import mxnet as mx

from benchmark.opperf.rules.default_params import DEFAULTS_INPUTS

Expand Down Expand Up @@ -77,89 +78,19 @@ def _select_ops(operator_names, filters=("_contrib", "_"), merge_op_forward_back
return mx_operators


def _get_all_registered_ops():
"""Get all registered MXNet operator names.
Returns
-------
["operator_name"]
"""
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()

check_call(_LIB.MXListAllOpNames(ctypes.byref(size),
ctypes.byref(plist)))

mx_registered_operator_names = [py_str(plist[i]) for i in range(size.value)]
return mx_registered_operator_names


def _get_op_handles(op_name):
"""Get handle for an operator with given name - op_name.
Parameters
----------
op_name: str
Name of operator to get handle for.
"""
op_handle = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(op_name), ctypes.byref(op_handle)))
return op_handle


def _get_op_arguments(op_handle):
"""Given operator name and handle, fetch operator arguments - number of arguments,
argument names, argument types.
Parameters
----------
op_handle: OpHandle
Handle for the operator
Returns
-------
(narg, arg_names, arg_types)
"""
real_name = ctypes.c_char_p()
desc = ctypes.c_char_p()
num_args = mx_uint()
arg_names = ctypes.POINTER(ctypes.c_char_p)()
arg_types = ctypes.POINTER(ctypes.c_char_p)()
arg_descs = ctypes.POINTER(ctypes.c_char_p)()
key_var_num_args = ctypes.c_char_p()
ret_type = ctypes.c_char_p()

check_call(_LIB.MXSymbolGetAtomicSymbolInfo(
op_handle, ctypes.byref(real_name), ctypes.byref(desc),
ctypes.byref(num_args),
ctypes.byref(arg_names),
ctypes.byref(arg_types),
ctypes.byref(arg_descs),
ctypes.byref(key_var_num_args),
ctypes.byref(ret_type)))

narg = int(num_args.value)
arg_names = [py_str(arg_names[i]) for i in range(narg)]
arg_types = [py_str(arg_types[i]) for i in range(narg)]

return narg, arg_names, arg_types


def _set_op_arguments(mx_operators):
"""Fetch and set operator arguments - nargs, arg_names, arg_types
"""
for op_name in mx_operators:
op_handle = _get_op_handles(op_name)
narg, arg_names, arg_types = _get_op_arguments(op_handle)
mx_operators[op_name]["params"] = {"narg": narg,
"arg_names": arg_names,
"arg_types": arg_types}
operator_arguments = mx.operator.get_operator_arguments(op_name)
mx_operators[op_name]["params"] = {"narg": operator_arguments.narg,
"arg_names": operator_arguments.names,
"arg_types": operator_arguments.types}


def _get_all_mxnet_operators():
# Step 1 - Get all registered op names and filter it
operator_names = _get_all_registered_ops()
operator_names = mx.operator.get_all_registered_operators()
mx_operators = _select_ops(operator_names)

# Step 2 - Get all parameters for the operators
Expand Down
61 changes: 60 additions & 1 deletion python/mxnet/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@

import traceback
import warnings
import collections

from array import array
from threading import Lock
import ctypes
from ctypes import CFUNCTYPE, POINTER, Structure, pointer
from ctypes import c_void_p, c_int, c_char, c_char_p, cast, c_bool

from .base import _LIB, check_call, MXCallbackList, c_array, c_array_buf, mx_int
from .base import _LIB, check_call, MXCallbackList, c_array, c_array_buf, mx_int, OpHandle
from .base import c_str, mx_uint, mx_float, ctypes2numpy_shared, NDArrayHandle, py_str
from . import symbol, context
from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
Expand Down Expand Up @@ -1099,3 +1101,60 @@ def delete_entry(_):
return do_register

register("custom_op")(CustomOpProp)


def get_all_registered_operators():
"""Get all registered MXNet operator names.
Returns
-------
operator_names : list of string
"""
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()

check_call(_LIB.MXListAllOpNames(ctypes.byref(size),
ctypes.byref(plist)))

mx_registered_operator_names = [py_str(plist[i]) for i in range(size.value)]
return mx_registered_operator_names

OperatorArguments = collections.namedtuple('OperatorArguments', ['narg', 'names', 'types'])

def get_operator_arguments(op_name):
"""Given operator name, fetch operator arguments - number of arguments,
argument names, argument types.
Parameters
----------
op_name: str
Handle for the operator
Returns
-------
operator_arguments : OperatorArguments, namedtuple with number of arguments, names and types
"""
op_handle = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(op_name), ctypes.byref(op_handle)))
real_name = ctypes.c_char_p()
desc = ctypes.c_char_p()
num_args = mx_uint()
arg_names = ctypes.POINTER(ctypes.c_char_p)()
arg_types = ctypes.POINTER(ctypes.c_char_p)()
arg_descs = ctypes.POINTER(ctypes.c_char_p)()
key_var_num_args = ctypes.c_char_p()
ret_type = ctypes.c_char_p()

check_call(_LIB.MXSymbolGetAtomicSymbolInfo(
op_handle, ctypes.byref(real_name), ctypes.byref(desc),
ctypes.byref(num_args),
ctypes.byref(arg_names),
ctypes.byref(arg_types),
ctypes.byref(arg_descs),
ctypes.byref(key_var_num_args),
ctypes.byref(ret_type)))

narg = int(num_args.value)
arg_names = [py_str(arg_names[i]) for i in range(narg)]
arg_types = [py_str(arg_types[i]) for i in range(narg)]
return OperatorArguments(narg, arg_names, arg_types)
14 changes: 13 additions & 1 deletion tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@
from distutils.version import LooseVersion
from numpy.testing import assert_allclose, assert_array_equal
from mxnet.test_utils import *
from mxnet.operator import *
from mxnet.base import py_str, MXNetError, _as_list
from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied, assertRaises
from common import run_in_spawned_process
from nose.tools import assert_raises
from nose.tools import assert_raises, ok_
import unittest
import os

Expand Down Expand Up @@ -8655,6 +8656,17 @@ def test_add_n():
assert_almost_equal(rslt.asnumpy(), add_n_rslt.asnumpy(), atol=1e-5)


def test_get_all_registered_operators():
ops = get_all_registered_operators()
ok_(isinstance(ops, list))
ok_(len(ops) > 0)


def test_get_operator_arguments():
operator_arguments = get_operator_arguments(mx.operator.get_all_registered_operators()[0])
ok_(isinstance(operator_arguments, OperatorArguments))


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 6d528f7

Please sign in to comment.