Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Drop list support in optimize_for (#18483)
Browse files Browse the repository at this point in the history
* initial commit

* fixed typos

* changed warning to exception

* updated subgraph_op unittests
  • Loading branch information
samskalicky committed Jun 8, 2020
1 parent 2d58ff5 commit 028d01d
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 58 deletions.
12 changes: 6 additions & 6 deletions example/extensions/lib_pass/test_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,30 +48,30 @@
sym = mx.sym.log(d)

def test_model(pass_name):
args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}
# execute in MXNet
print('-------------------------------')
print('Testing regular MXNet execution')
exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
exe = sym.bind(ctx=mx.cpu(), args=args)
out = exe.forward()
print(out)

# Symbol optimize_for
# with propagating shapes/types
print('-------------------------------')
print('Testing pass "%s" with shapes/types' % pass_name)
arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')]
aux = []
mysym2 = sym.optimize_for(pass_name,arg_array,aux)
aux = {}
mysym2 = sym.optimize_for(pass_name,args,aux)
print(mysym2.tojson())
exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
out2 = exe2.forward()
print(out2)

# without propagating shapes/types
print('-------------------------------')
print('Testing pass "%s" without shapes/types' % pass_name)
mysym3 = sym.optimize_for(pass_name, myOpt='yello')
exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
exe3 = mysym3.bind(ctx=mx.cpu(), args=args)
out3 = exe3.forward()
print(out3)

Expand Down
25 changes: 12 additions & 13 deletions example/extensions/lib_subgraph/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,40 +49,39 @@
sym2 = mx.sym.log(d2)

def test(backend):
args = {'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}
###############################################
# Test with subgraph not consuming params
###############################################
#execute in MXNet
print('-------------------------------')
print('Testing regular MXNet execution')
exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
exe = sym.bind(ctx=mx.cpu(), args=args)
out = exe.forward()
print(out)

# with propagating shapes/types
print('-------------------------------')
print('Testing %s partitioning with shapes/types' % backend)
arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')]
mysym2 = sym.optimize_for(backend,arg_array)
mysym2 = sym.optimize_for(backend,args)
print(mysym2.tojson())
exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
out2 = exe2.forward()
print(out2)

# with propagating shapes/types, rejecting subgraph
print('-------------------------------')
print('Testing %s partitioning with shapes/types - rejecting subgraph' % backend)
arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')]
mysym2 = sym.optimize_for(backend, arg_array, reject=True)
exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
mysym2 = sym.optimize_for(backend, args, reject=True)
exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
out2 = exe2.forward()
print(out2)

# without propagating shapes/types
print('-------------------------------')
print('Testing %s partitioning without shapes/types' % backend)
mysym3 = sym.optimize_for(backend, myOpt='yello')
exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
exe3 = mysym3.bind(ctx=mx.cpu(), args=args)
out3 = exe3.forward()
print(out3)

Expand All @@ -108,28 +107,28 @@ def test(backend):
###############################################
# Test with subgraph directly consuming params
###############################################
args = {'a':mx.nd.ones((3,2))}
#execute in MXNet
print('-------------------------------')
print('Testing regular MXNet execution')
exe5 = sym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
exe5 = sym2.bind(ctx=mx.cpu(), args=args)
out5 = exe5.forward()
print(out5)

# with propagating shapes/types
print('-------------------------------')
print('Testing %s partitioning with shapes/types' % backend)
arg_array = [mx.nd.ones((3,2),dtype='float32')]
mysym6 = sym2.optimize_for(backend, arg_array, reqArgs=True)
mysym6 = sym2.optimize_for(backend, args, reqArgs=True)
print(mysym6.tojson())
exe6 = mysym6.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
exe6 = mysym6.bind(ctx=mx.cpu(), args=args)
out6 = exe6.forward()
print(out6)

# without propagating shapes/types
print('-------------------------------')
print('Testing %s partitioning without shapes/types' % backend)
mysym7 = sym2.optimize_for(backend, reqArgs=True)
exe7 = mysym7.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
exe7 = mysym7.bind(ctx=mx.cpu(), args=args)
out7 = exe7.forward()
print(out7)

Expand Down
10 changes: 5 additions & 5 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,12 +1033,12 @@ def _build_cache(self, *args):
if self._backend:
ctx = args[0].context
# get list of params in the order of out.list_arguments
arg_array = [args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_arguments()]
aux_array = [args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_auxiliary_states()]
arg_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_arguments()}
aux_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_auxiliary_states()}
# Partition the graph.
out = out.optimize_for(self._backend, arg_array, aux_array, ctx, **self._backend_opts)
out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
#update cached graph with partitioned graph
self._cached_graph = data, out
self._cached_op = ndarray.CachedOp(out, flags)
Expand Down
40 changes: 16 additions & 24 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1456,17 +1456,15 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
backend : str
The name of backend, as registered in `SubgraphBackendRegistry`
args : list of NDArray or dict of str to NDArray, optional
args : dict of str to NDArray, optional
Input arguments to the symbol, required to infer shapes/types before partitioning
- If type is a list of `NDArray`, the order is the same as that of `list_arguments()`.
- If type is a dict of str to `NDArray`, then it maps the name of arguments
to the corresponding `NDArray`.
aux : list of NDArray or dict of str to NDArray, optional
aux : dict of str to NDArray, optional
Input auxiliary arguments to the symbol
- If type is a list of `NDArray`, the order is the same as that of `list_arguments()`.
- If type is a dict of str to `NDArray`, then it maps the name of arguments
to the corresponding `NDArray`.
Expand All @@ -1483,6 +1481,8 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
"""
out = SymbolHandle()
assert isinstance(backend, str)
assert isinstance(args, dict) or args is None
assert isinstance(aux, dict) or aux is None

if args is None or len(args) == 0:
args_ = []
Expand Down Expand Up @@ -1530,30 +1530,22 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
ctypes.byref(new_aux_size),
ctypes.byref(new_aux_handle),
ctypes.byref(new_aux_names)))
arg_names = self.list_arguments()
if isinstance(args, dict):
# add new args/aux
if not args is None:
for i in range(new_args_size.value):
args[py_str(new_arg_names[i])] = NDArray(NDArrayHandle(new_args_handle[i]))
elif isinstance(args, list):
for i in range(new_args_size.value):
name = py_str(new_arg_names[i])
if name in arg_names:
idx = arg_names.index(name)
args[idx] = NDArray(NDArrayHandle(new_args_handle[i]))
else:
args.append(NDArray(NDArrayHandle(new_args_handle[i])))
aux_names = self.list_auxiliary_states()
if isinstance(aux, dict):
elif new_args_size.value > 0:
raise RuntimeError('Cannot add new args in optimize_for since args is None\n' +
'Provide a dictionary to the args argument to optimize_for')

if not aux is None:
for i in range(new_aux_size.value):
aux[py_str(new_aux_names[i])] = NDArray(NDArrayHandle(new_aux_handle[i]))
elif isinstance(aux, list):
for i in range(new_aux_size.value):
name = py_str(new_aux_names[i])
if name in aux_names:
idx = aux_names.index(name)
aux[idx] = NDArray(NDArrayHandle(new_aux_handle[i]))
else:
aux.append(NDArray(NDArrayHandle(new_aux_handle[i])))
elif new_aux_size.value > 0:
raise RuntimeError('Cannot add new aux in optimize_for since aux is None\n' +
'Provide a dictionary to the aux argument to optimize_for')

# return modified symbol
return Symbol(out)


Expand Down
6 changes: 2 additions & 4 deletions tests/python/unittest/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,6 @@ def test_subgraph():
sym = mx.sym.log(d)

args = {'a':mx.nd.ones((3,2),ctx=mx.cpu()), 'b':mx.nd.ones((3,2),ctx=mx.cpu())}
arg_array = [mx.nd.ones((3,2),dtype='float32',ctx=mx.cpu()),
mx.nd.ones((3,2),dtype='float32',ctx=mx.cpu())]

# baseline - regular execution in MXNet
exe = sym.bind(ctx=mx.cpu(), args=args)
Expand All @@ -147,14 +145,14 @@ def test_subgraph():

# with propogating shapes/types, rejecting subgraph
# this tests creating the subgraph and having the subgraph prop reject it
mysym2 = sym.optimize_for("myProp", arg_array, reject=True)
mysym2 = sym.optimize_for("myProp", args, reject=True)
exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
out2 = exe2.forward()
# check that result matches one executed by MXNet
assert_almost_equal(out[0].asnumpy(), out2[0].asnumpy(), rtol=1e-3, atol=1e-3)

# with propagating shapes/types
mysym3 = sym.optimize_for("myProp",arg_array)
mysym3 = sym.optimize_for("myProp",args)
exe3 = mysym3.bind(ctx=mx.cpu(), args=args)
out3 = exe3.forward()
# check that result matches one executed by MXNet
Expand Down
14 changes: 8 additions & 6 deletions tests/python/unittest/test_subgraph_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,18 +353,20 @@ def test_subgraph_exe8(sym, subgraph_backend, op_names):
# bind
sym, _, _ = sym
arg_shapes, _, aux_shapes = sym.infer_shape()
arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
exe1 = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
arg_names = sym.list_arguments()
aux_names = sym.list_auxiliary_states()
arg_dict = {name:mx.nd.random.uniform(shape=shape) for name,shape in zip(arg_names,arg_shapes)}
aux_dict = {name:mx.nd.random.uniform(shape=shape) for name,shape in zip(aux_names,aux_shapes)}
exe1 = sym.bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null')
exe1.forward()

# infer shape/type before partition before bind
check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
c_str_array(op_names)))
part_sym = sym.optimize_for(subgraph_backend, arg_array, aux_array)
c_str_array(op_names)))
part_sym = sym.optimize_for(subgraph_backend, arg_dict, aux_dict)
check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null')
exe2.forward()

# compare outputs
Expand Down

0 comments on commit 028d01d

Please sign in to comment.