diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 3c27323ee758..2760aa0f5363 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -8660,11 +8660,18 @@ def test_get_all_registered_operators(): ops = get_all_registered_operators() ok_(isinstance(ops, list)) ok_(len(ops) > 0) + ok_('Activation' in ops) def test_get_operator_arguments(): - operator_arguments = get_operator_arguments(mx.operator.get_all_registered_operators()[0]) + operator_arguments = get_operator_arguments('Activation') ok_(isinstance(operator_arguments, OperatorArguments)) + ok_(operator_arguments.names == ['data', 'act_type']) + print(operator_arguments.types) + ok_(operator_arguments.types + == ['NDArray-or-Symbol', "{'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required"]) + ok_(operator_arguments.narg == 2) + if __name__ == '__main__':