From 051c184ab1ad6ea86fbf4dcbec920fd2b64ef7ba Mon Sep 17 00:00:00 2001 From: reminisce Date: Thu, 31 May 2018 14:44:38 -0700 Subject: [PATCH] Fix unit test --- tests/python/unittest/test_subgraph_op.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py index bffc55c1fa6c..1eff0a0cfb7b 100644 --- a/tests/python/unittest/test_subgraph_op.py +++ b/tests/python/unittest/test_subgraph_op.py @@ -1,6 +1,6 @@ import ctypes import mxnet as mx -from mxnet.base import SymbolHandle, check_call, _LIB +from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array from mxnet.symbol import Symbol import numpy as np @@ -14,7 +14,11 @@ def test_subgraph_op_whole_graph(): out = SymbolHandle() - check_call(_LIB.MXPartitionGraph(regular_sym.handle, ctypes.byref(out))) + op_names = [] + #op_names = [mx.sym.sin.__name__, mx.sym.Convolution.__name__] + + check_call(_LIB.MXPartitionGraph(regular_sym.handle, mx_uint(len(op_names)), + c_str_array(op_names), ctypes.byref(out))) subgraph_sym = Symbol(out) assert regular_sym.list_inputs() == subgraph_sym.list_inputs()