diff --git a/example/extensions/lib_pass/test_pass.py b/example/extensions/lib_pass/test_pass.py
index 8930c9478152..5d4578391097 100644
--- a/example/extensions/lib_pass/test_pass.py
+++ b/example/extensions/lib_pass/test_pass.py
@@ -48,10 +48,11 @@ sym = mx.sym.log(d)
 
 def test_model(pass_name):
+    args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}
     # execute in MXNet
     print('-------------------------------')
     print('Testing regular MXNet execution')
-    exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
+    exe = sym.bind(ctx=mx.cpu(), args=args)
     out = exe.forward()
     print(out)
 
@@ -59,11 +60,10 @@ def test_model(pass_name):
     # with propogating shapes/types
     print('-------------------------------')
     print('Testing pass "%s" with shapes/types' % pass_name)
-    arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')]
-    aux = []
-    mysym2 = sym.optimize_for(pass_name,arg_array,aux)
+    aux = {}
+    mysym2 = sym.optimize_for(pass_name,args,aux)
     print(mysym2.tojson())
-    exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
+    exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
     out2 = exe2.forward()
     print(out2)
 
@@ -71,7 +71,7 @@ def test_model(pass_name):
     print('-------------------------------')
     print('Testing pass "%s" without shapes/types' % pass_name)
     mysym3 = sym.optimize_for(pass_name, myOpt='yello')
-    exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
+    exe3 = mysym3.bind(ctx=mx.cpu(), args=args)
     out3 = exe3.forward()
     print(out3)
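Note: the pattern this file converges on, a single name-to-NDArray dict shared between `optimize_for` and `bind`, is the same one every other call site in this patch adopts. A minimal standalone sketch of that pattern (assumes the lib_pass example library has been built to `libpass_lib.so` and that its pass is registered under the name `myPass`, as in `test_pass.py`):

```python
import mxnet as mx

# same toy graph as the test: log(exp(a + b))
a = mx.sym.var('a')
b = mx.sym.var('b')
sym = mx.sym.log(mx.sym.exp(a + b))

# one dict now drives both shape/type inference in optimize_for and binding
args = {'a': mx.nd.ones((3, 2)), 'b': mx.nd.ones((3, 2))}
aux = {}  # the pass may insert newly created aux params into this dict

mx.library.load('libpass_lib.so')  # assumed: built from example/extensions/lib_pass
opt_sym = sym.optimize_for('myPass', args, aux)
exe = opt_sym.bind(ctx=mx.cpu(), args=args)
print(exe.forward())
```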
diff --git a/example/extensions/lib_subgraph/test_subgraph.py b/example/extensions/lib_subgraph/test_subgraph.py
index eb7102a1511c..fa56b50515e5 100644
--- a/example/extensions/lib_subgraph/test_subgraph.py
+++ b/example/extensions/lib_subgraph/test_subgraph.py
@@ -49,32 +49,31 @@ sym2 = mx.sym.log(d2)
 
 def test(backend):
+    args = {'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}
     ###############################################
     # Test with subgraph not consuming params
     ###############################################
     #execute in MXNet
     print('-------------------------------')
     print('Testing regular MXNet execution')
-    exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
+    exe = sym.bind(ctx=mx.cpu(), args=args)
     out = exe.forward()
     print(out)
 
     # with propogating shapes/types
     print('-------------------------------')
     print('Testing %s partitioning with shapes/types' % backend)
-    arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')]
-    mysym2 = sym.optimize_for(backend,arg_array)
+    mysym2 = sym.optimize_for(backend,args)
     print(mysym2.tojson())
-    exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
+    exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
     out2 = exe2.forward()
     print(out2)
 
     # with propogating shapes/types, rejecting subgraph
     print('-------------------------------')
     print('Testing %s partitioning with shapes/types - rejecting subgraph' % backend)
-    arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')]
-    mysym2 = sym.optimize_for(backend, arg_array, reject=True)
-    exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
+    mysym2 = sym.optimize_for(backend, args, reject=True)
+    exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
     out2 = exe2.forward()
     print(out2)
 
@@ -82,7 +81,7 @@ def test(backend):
     print('-------------------------------')
     print('Testing %s partitioning without shapes/types' % backend)
     mysym3 = sym.optimize_for(backend, myOpt='yello')
-    exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
+    exe3 = mysym3.bind(ctx=mx.cpu(), args=args)
     out3 = exe3.forward()
     print(out3)
 
@@ -108,20 +107,20 @@ def test(backend):
     ###############################################
     # Test with subgraph directly consuming params
     ###############################################
+    args = {'a':mx.nd.ones((3,2))}
     #execute in MXNet
     print('-------------------------------')
     print('Testing regular MXNet execution')
-    exe5 = sym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
+    exe5 = sym2.bind(ctx=mx.cpu(), args=args)
     out5 = exe5.forward()
     print(out5)
 
     # with propogating shapes/types
     print('-------------------------------')
     print('Testing %s partitioning with shapes/types' % backend)
-    arg_array = [mx.nd.ones((3,2),dtype='float32')]
-    mysym6 = sym2.optimize_for(backend, arg_array, reqArgs=True)
+    mysym6 = sym2.optimize_for(backend, args, reqArgs=True)
     print(mysym6.tojson())
-    exe6 = mysym6.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
+    exe6 = mysym6.bind(ctx=mx.cpu(), args=args)
     out6 = exe6.forward()
     print(out6)
 
@@ -129,7 +128,7 @@ def test(backend):
     print('-------------------------------')
     print('Testing %s partitioning without shapes/types' % backend)
     mysym7 = sym2.optimize_for(backend, reqArgs=True)
-    exe7 = mysym7.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
+    exe7 = mysym7.bind(ctx=mx.cpu(), args=args)
     out7 = exe7.forward()
     print(out7)
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 78b5a7262121..6c3612c7d784 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1033,12 +1033,12 @@ def _build_cache(self, *args):
         if self._backend:
             ctx = args[0].context
             # get list of params in the order of out.list_arguments
-            arg_array = [args[data_names[name]] if name in data_names.keys() else params[name].data()
-                         for name in out.list_arguments()]
-            aux_array = [args[data_names[name]] if name in data_names.keys() else params[name].data()
-                         for name in out.list_auxiliary_states()]
+            arg_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
+                        for name in out.list_arguments()}
+            aux_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
+                        for name in out.list_auxiliary_states()}
             # Partition the graph.
-            out = out.optimize_for(self._backend, arg_array, aux_array, ctx, **self._backend_opts)
+            out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
             #update cached graph with partitioned graph
             self._cached_graph = data, out
         self._cached_op = ndarray.CachedOp(out, flags)
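For Gluon users nothing changes at the call site: `_build_cache` now assembles the name-keyed dicts internally and passes them to `optimize_for`. A hedged sketch of the user-facing flow that exercises this path (the backend name `myPass` is a placeholder for whatever partitioner library has actually been loaded):

```python
import mxnet as mx
from mxnet.gluon import nn

net = nn.Dense(2)
net.initialize()

# hybridizing with a backend routes the first forward pass through
# _build_cache, which builds the name->NDArray arg/aux dicts shown above
# and calls out.optimize_for(backend, arg_dict, aux_dict, ctx, **opts)
net.hybridize(backend='myPass')  # assumed: 'myPass' registered via a loaded library
out = net(mx.nd.ones((3, 2)))
```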
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index d4ff7954c181..5fb1b8c4e700 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1456,17 +1456,15 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
         backend : str
             The name of backend, as registered in `SubgraphBackendRegistry`
 
-        args : list of NDArray or dict of str to NDArray, optional
+        args : dict of str to NDArray, optional
             Input arguments to the symbol, required to infer shapes/types before partitioning
 
-            - If type is a list of `NDArray`, the order is the same as that of `list_arguments()`.
             - If type is a dict of str to `NDArray`, then it maps the name of arguments
               to the corresponding `NDArray`.
 
-        aux : list of NDArray or dict of str to NDArray, optional
+        aux : dict of str to NDArray, optional
             Input auxiliary arguments to the symbol
 
-            - If type is a list of `NDArray`, the order is the same as that of `list_arguments()`.
             - If type is a dict of str to `NDArray`, then it maps the name of arguments
               to the corresponding `NDArray`.
 
@@ -1483,6 +1481,8 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
         """
         out = SymbolHandle()
         assert isinstance(backend, str)
+        assert isinstance(args, dict) or args is None
+        assert isinstance(aux, dict) or aux is None
 
         if args is None or len(args) == 0:
             args_ = []
@@ -1530,30 +1530,22 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
                                                  ctypes.byref(new_aux_size),
                                                  ctypes.byref(new_aux_handle),
                                                  ctypes.byref(new_aux_names)))
-        arg_names = self.list_arguments()
-        if isinstance(args, dict):
+        # add new args/aux
+        if not args is None:
             for i in range(new_args_size.value):
                 args[py_str(new_arg_names[i])] = NDArray(NDArrayHandle(new_args_handle[i]))
-        elif isinstance(args, list):
-            for i in range(new_args_size.value):
-                name = py_str(new_arg_names[i])
-                if name in arg_names:
-                    idx = arg_names.index(name)
-                    args[idx] = NDArray(NDArrayHandle(new_args_handle[i]))
-                else:
-                    args.append(NDArray(NDArrayHandle(new_args_handle[i])))
-        aux_names = self.list_auxiliary_states()
-        if isinstance(aux, dict):
+        elif new_args_size.value > 0:
+            raise RuntimeError('Cannot add new args in optimize_for since args is None\n' +
+                               'Provide a dictionary to the args argument to optimize_for')
+
+        if not aux is None:
             for i in range(new_aux_size.value):
                 aux[py_str(new_aux_names[i])] = NDArray(NDArrayHandle(new_aux_handle[i]))
-        elif isinstance(aux, list):
-            for i in range(new_aux_size.value):
-                name = py_str(new_aux_names[i])
-                if name in aux_names:
-                    idx = aux_names.index(name)
-                    aux[idx] = NDArray(NDArrayHandle(new_aux_handle[i]))
-                else:
-                    aux.append(NDArray(NDArrayHandle(new_aux_handle[i])))
+        elif new_aux_size.value > 0:
+            raise RuntimeError('Cannot add new aux in optimize_for since aux is None\n' +
+                               'Provide a dictionary to the aux argument to optimize_for')
+
+        # return modified symbol
         return Symbol(out)
diff --git a/tests/python/unittest/test_extensions.py b/tests/python/unittest/test_extensions.py
index d60ca0d352a2..2e21f927737b 100644
--- a/tests/python/unittest/test_extensions.py
+++ b/tests/python/unittest/test_extensions.py
@@ -130,8 +130,6 @@ def test_subgraph():
     sym = mx.sym.log(d)
 
     args = {'a':mx.nd.ones((3,2),ctx=mx.cpu()), 'b':mx.nd.ones((3,2),ctx=mx.cpu())}
-    arg_array = [mx.nd.ones((3,2),dtype='float32',ctx=mx.cpu()),
-                 mx.nd.ones((3,2),dtype='float32',ctx=mx.cpu())]
 
     # baseline - regular execution in MXNet
     exe = sym.bind(ctx=mx.cpu(), args=args)
@@ -147,14 +145,14 @@ def test_subgraph():
 
     # with propogating shapes/types, rejecting subgraph
     # this tests creating the subgraph and having the subgraph prop reject it
-    mysym2 = sym.optimize_for("myProp", arg_array, reject=True)
+    mysym2 = sym.optimize_for("myProp", args, reject=True)
     exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
     out2 = exe2.forward()
     # check that result matches one executed by MXNet
     assert_almost_equal(out[0].asnumpy(), out2[0].asnumpy(), rtol=1e-3, atol=1e-3)
 
     # with propogating shapes/types
-    mysym3 = sym.optimize_for("myProp",arg_array)
+    mysym3 = sym.optimize_for("myProp",args)
     exe3 = mysym3.bind(ctx=mx.cpu(), args=args)
     out3 = exe3.forward()
     # check that result matches one executed by MXNet
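Two failure modes fall out of the symbol.py change: a list now trips the new asserts up front, and a backend that creates new arguments raises if no dict was supplied to receive them. A hedged sketch of both (again assuming a registered backend named `myPass`):

```python
import mxnet as mx

a = mx.sym.var('a')
sym = mx.sym.exp(a)

try:
    # a list is rejected immediately by the new isinstance asserts
    sym.optimize_for('myPass', [mx.nd.ones((3, 2))])
except AssertionError:
    print('args/aux must be dicts of str to NDArray (or None)')

try:
    # only raises if the backend actually inserts new args; args=None now
    # fails loudly instead of silently dropping the new arrays
    sym.optimize_for('myPass')
except RuntimeError as err:
    print(err)  # 'Cannot add new args in optimize_for since args is None...'
```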
diff --git a/tests/python/unittest/test_subgraph_op.py b/tests/python/unittest/test_subgraph_op.py
index f5b96d84aefe..d27295e584e0 100644
--- a/tests/python/unittest/test_subgraph_op.py
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -353,18 +353,20 @@ def test_subgraph_exe8(sym, subgraph_backend, op_names):
     # bind
     sym, _, _ = sym
     arg_shapes, _, aux_shapes = sym.infer_shape()
-    arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
-    aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
-    exe1 = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
+    arg_names = sym.list_arguments()
+    aux_names = sym.list_auxiliary_states()
+    arg_dict = {name:mx.nd.random.uniform(shape=shape) for name,shape in zip(arg_names,arg_shapes)}
+    aux_dict = {name:mx.nd.random.uniform(shape=shape) for name,shape in zip(aux_names,aux_shapes)}
+    exe1 = sym.bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null')
     exe1.forward()
 
     # infer shape/type before partition before bind
     check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
-                                                   c_str_array(op_names)))
-    part_sym = sym.optimize_for(subgraph_backend, arg_array, aux_array)
+                                                  c_str_array(op_names)))
+    part_sym = sym.optimize_for(subgraph_backend, arg_dict, aux_dict)
     check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))
-    exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
+    exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null')
     exe2.forward()
 
     # compare outputs
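The name-zipping idiom in test_subgraph_exe8 generalizes to any symbol whose input shapes are fully specified. A small hedged helper capturing it (the function name is illustrative, not part of the patch):

```python
import mxnet as mx

def make_arg_aux_dicts(sym):
    """Build the name->NDArray dicts that optimize_for and bind now expect.

    Assumes sym's input shapes are fully specified, as in the test fixtures,
    so infer_shape() succeeds without extra hints.
    """
    arg_shapes, _, aux_shapes = sym.infer_shape()
    arg_dict = {name: mx.nd.random.uniform(shape=shape)
                for name, shape in zip(sym.list_arguments(), arg_shapes)}
    aux_dict = {name: mx.nd.random.uniform(shape=shape)
                for name, shape in zip(sym.list_auxiliary_states(), aux_shapes)}
    return arg_dict, aux_dict

# usage: shapes declared on the variables make infer_shape() succeed
a = mx.sym.var('a', shape=(3, 2))
arg_dict, aux_dict = make_arg_aux_dicts(mx.sym.exp(a))
```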