Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add while_loop
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Jul 5, 2018
1 parent 19ac41d commit 6976b90
Show file tree
Hide file tree
Showing 7 changed files with 1,413 additions and 12 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 6ab4da to 290226
127 changes: 127 additions & 0 deletions python/mxnet/ndarray/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,130 @@ def check_input(inputs, in_type, msg):
if not_data_list and len(outputs) == 1:
outputs = outputs[0]
return (outputs, states)


def while_loop(loop_vars, cond, func, max_iterations):
    """Run a while loop with user-defined computation and loop condition.

    This operator simulates a while loop which iteratively does customized
    computation as long as the condition is satisfied.

    `loop_vars` is a list of NDArrays that the computation uses.

    `cond` is a user-defined function used as the loop condition.
    It consumes `loop_vars`, and produces a scalar MXNet NDArray,
    indicating the termination of the loop.
    The loop ends when `cond` returns false (zero).
    The `cond` is variadic, and its signature should be
    `cond(*loop_vars) => NDArray`.

    `func` is a user-defined function used as the loop body.
    It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars`
    at each step.
    The number of elements, shape, dtype of each element in `step_output`
    should be consistent.
    The `new_loop_vars` should be consistent with `loop_vars` on each step.
    The `func` is variadic, and its signature should be
    `func(*loop_vars) => (List[NDArray] step_output, List[NDArray] new_loop_vars)`.
    Either element of the pair may also be a single NDArray or `None`
    (`None` is treated as an empty list).

    `max_iterations` is a scalar that defines the maximum number of iterations
    allowed.

    This function returns a pair `(outputs, final_loop_vars)`.
    The i-th element of `outputs` represents the i-th `step_output` at all
    steps, stacked along axis 0.
    The i-th element of `final_loop_vars` represents the final state of the
    i-th loop variable.

    Warning: when `cond` is never satisfied, `outputs` is an empty list.
    TODO(Junru): the output shape along axis 0 is not consistent with the
    symbolic version. Should we mention this in our doc?

    Parameters
    ----------
    loop_vars: list of NDArrays.
        The initial values of the loop variables.
    cond: a Python function.
        The loop condition.
    func: a Python function.
        The loop body.
    max_iterations: a python int.
        Maximum number of iterations.

    Returns
    -------
    outputs: a list of NDArrays of length `|step_output|`.
        The stacked per-step outputs.
    final_loop_vars: a list of NDArrays of length `|loop_vars|`.
        The final state of the loop variables.
    TODO(Junru): change the output format

    Raises
    ------
    ValueError
        If `max_iterations` or the result of `cond` cannot be converted to a
        python scalar, if `loop_vars` is empty or is not made of NDArrays,
        if `func` returns something other than NDArrays, or if the number of
        loop variables / step outputs changes between steps.

    Examples
    --------
    TODO(Junru): run this
    >>> cond = lambda i, s: i <= 5
    >>> func = lambda i, s: ([i + s], [i + 1, s + i])
    >>> loop_vars = (mx.nd.array([1], dtype="int64"), mx.nd.array([0], dtype="int64"))
    >>> outputs, states = mx.nd.contrib.while_loop(loop_vars, cond, func, max_iterations=10)
    """
    def _to_python_scalar(inputs, target_type, arg_name):
        """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray,
        other python types, to the given python type.
        """
        if isinstance(inputs, ndarray.NDArray):
            inputs = inputs.asscalar()
        try:
            return target_type(inputs)
        except Exception:
            # `Exception` (not a bare `except`) so that KeyboardInterrupt and
            # SystemExit are not silently turned into a ValueError.
            raise ValueError("Cannot convert %s to python %s"
                             % (arg_name, target_type.__name__))

    def _to_ndarray_tuple(inputs, arg_name):
        """Converts "inputs", possibly a single mxnet NDArray, a list of mxnet
        NDArray, a tuple of mxnet NDArray, into a tuple of NDArray.
        """
        if isinstance(inputs, list):
            inputs = tuple(inputs)
        if isinstance(inputs, ndarray.NDArray):
            inputs = (inputs, )
        if not isinstance(inputs, tuple):
            raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays"
                             % (arg_name, ))
        for item in inputs:
            if not isinstance(item, ndarray.NDArray):
                raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays"
                                 % (arg_name, ))
        return inputs

    def _func_wrapper(loop_vars):
        """Calls `func` and normalizes its `(step_output, new_loop_vars)`
        result into `(tuple of step_outputs, tuple of new_loop_vars)`.
        `None` in either position is treated as an empty list.
        """
        step_output, new_loop_vars = func(*loop_vars)
        if step_output is None:
            step_output = []
        if new_loop_vars is None:
            new_loop_vars = []
        step_output = _to_ndarray_tuple(step_output, "step_output")
        new_loop_vars = _to_ndarray_tuple(new_loop_vars, "new_loop_vars")
        if len(loop_vars) != len(new_loop_vars):
            raise ValueError("The length of loop_vars should be consistent during the loop")
        return step_output, new_loop_vars

    max_iterations = _to_python_scalar(max_iterations, int, "max_iterations")
    loop_vars = _to_ndarray_tuple(loop_vars, "loop_vars")
    # It should work fine if loop_vars is empty I guess,
    # but it is semantically unnecessary to include this case.
    if len(loop_vars) == 0:
        raise ValueError("loop_vars should contain at least one element")

    steps = 0
    outputs = []
    # `cond` is only evaluated while the iteration budget is not exhausted.
    while steps < max_iterations and \
            _to_python_scalar(cond(*loop_vars), bool, "Return value of cond"):
        step_output, loop_vars = _func_wrapper(loop_vars)
        outputs.append(step_output)
        steps += 1
        # every step must produce as many outputs as the first step did
        if len(step_output) != len(outputs[0]):
            raise ValueError("step_output are inconsistent on each step")
    try:
        # transpose "per step" -> "per output", then stack each output over steps
        outputs = [ndarray.op.stack(*item) for item in zip(*outputs)]
    except ValueError:
        raise ValueError("step_outputs are inconsistent on each step")
    return outputs, list(loop_vars)
201 changes: 201 additions & 0 deletions python/mxnet/symbol/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,204 @@ def check_data(inputs, in_type, msg):
states = states[0]

return (outs, states)

def while_loop(cond, func, loop_vars, max_iterations, name="while_loop"):
    """Run a while loop with user-defined computation and loop condition.

    This operator simulates a while loop which iteratively does customized
    computation as long as the condition is satisfied.

    `loop_vars` is a list of Symbols that the computation uses.

    `cond` is a user-defined function used as the loop condition.
    It consumes `loop_vars`, and produces a scalar MXNet symbol,
    indicating the termination of the loop.
    The loop ends when `cond` returns false (zero).
    The `cond` is variadic, and its signature should be
    `cond(*loop_vars) => Symbol`.

    `func` is a user-defined function used as the loop body.
    It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars`
    at each step.
    The number of elements, shape, dtype of each element in `step_output`
    should be consistent.
    The `new_loop_vars` should be consistent with `loop_vars` on each step.
    The `func` is variadic, and its signature should be
    `func(*loop_vars) => (List[Symbol] step_output, List[Symbol] new_loop_vars)`.
    Either element of the pair may also be a single Symbol or `None`
    (`None` is treated as an empty list).

    `max_iterations` is a scalar that defines the maximum number of iterations
    allowed.

    This function returns a pair `(outputs, final_loop_vars)`.
    The i-th element of `outputs` represents the i-th `step_output` at all
    steps, stacked along axis 0.
    The i-th element of `final_loop_vars` represents the final state of the
    i-th loop variable.
    TODO(Junru): writing style: use Symbol or symbol?

    Parameters
    ----------
    cond: a Python function.
        The loop condition.
    func: a Python function.
        The loop body.
    loop_vars: list of Symbol.
        The initial values of the loop variables.
    max_iterations: a python int.
        Maximum number of iterations.
    name: str.
        Prefix used to name the `cond` and `func` subgraphs
        (`name + "_cond"` and `name + "_func"`).

    Returns
    -------
    outputs: a list of Symbol of length `|step_output|`.
        The stacked per-step outputs.
    final_loop_vars: a list of Symbol of length `|loop_vars|`.
        The final state of the loop variables.
    TODO(Junru): change the output format

    Raises
    ------
    ValueError
        If `max_iterations` cannot be converted to a python int, if
        `loop_vars` is empty or is not made of Symbols, if `cond` does not
        return a Symbol, if `func` returns something other than Symbols or
        changes the number of loop variables, or if some loop variable is
        not an input of the graph built from `func`.

    Examples
    --------
    TODO(Junru): run this
    >>> cond = lambda i, s: i <= 5
    >>> func = lambda i, s: ([i + s], [i + 1, s + i])
    >>> loop_vars = (mx.sym.var('i'), mx.sym.var('s'))
    >>> outputs, states = mx.sym.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
    """
    def _to_python_scalar(inputs, target_type, arg_name):
        """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray,
        other python types, to the given python type.
        """
        if hasattr(inputs, "asscalar"):
            inputs = inputs.asscalar()
        try:
            return target_type(inputs)
        except Exception:
            # `Exception` (not a bare `except`) so that KeyboardInterrupt and
            # SystemExit are not silently turned into a ValueError.
            raise ValueError("Cannot convert %s to python %s"
                             % (arg_name, target_type.__name__))

    def _to_symbol_tuple(inputs, arg_name):
        """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet
        Symbol, a tuple of mxnet Symbol, into a tuple of Symbol.
        """
        if isinstance(inputs, list):
            inputs = tuple(inputs)
        if isinstance(inputs, Symbol):
            inputs = (inputs, )
        if not isinstance(inputs, tuple):
            raise ValueError("%s must be a Symbol, or a tuple or list of Symbol"
                             % (arg_name, ))
        for item in inputs:
            if not isinstance(item, Symbol):
                raise ValueError("%s must be a Symbol, or a tuple or list of Symbol"
                                 % (arg_name, ))
        return inputs

    def _cond_wrapper(loop_vars):
        """Calls `cond` and returns `([], [result])` so that it has the same
        `(outputs, final_state)` shape as `_func_wrapper`.
        """
        result = cond(*loop_vars)
        if not isinstance(result, Symbol):
            raise ValueError("Return of cond must be a Symbol")
        return [], [result]

    def _func_wrapper(loop_vars):
        """Calls `func` and normalizes its `(step_output, new_loop_vars)`
        result into `(list of step_outputs, list of new_loop_vars)`.
        `None` in either position is treated as an empty list.
        """
        step_output, new_loop_vars = func(*loop_vars)
        if step_output is None:
            step_output = []
        if new_loop_vars is None:
            new_loop_vars = []
        step_output = _to_symbol_tuple(step_output, "step_output")
        new_loop_vars = _to_symbol_tuple(new_loop_vars, "new_loop_vars")
        if len(loop_vars) != len(new_loop_vars):
            raise ValueError("The number of loop_vars should be consistent during the loop")
        return list(step_output), list(new_loop_vars)

    def _create_subgraph(graph_vars, graph_func, subgraph_name):
        """Builds the grouped symbol produced by `graph_func` inside an
        attribute scope carrying `__subgraph_name__=subgraph_name`.

        Returns `(graph, num_out_data, num_outputs)`.
        """
        with AttrScope(__subgraph_name__=subgraph_name):
            # create new variables with the same name,
            # then feed them to the given func
            new_graph_vars = [symbol.var(sym.name) for sym in graph_vars]
            outputs, final_state = graph_func(new_graph_vars)
            # first `num_out_data` elements belong to `outputs`
            # other elements belong to `final_state`
            num_out_data = len(outputs)
            num_outputs = len(outputs) + len(final_state)
            # nnvm graph does not allow inputs and outputs overlap,
            # so an input that is returned untouched goes through `identity`
            id_new_graph_vars = {id(x) for x in new_graph_vars}

            def make_identity(sym):
                if id(sym) in id_new_graph_vars:
                    return symbol.op.identity(sym)
                return sym

            # group all outputs of graph_func
            graph = symbol.Group([make_identity(sym) for sym in outputs + final_state])
        return graph, num_out_data, num_outputs

    def _union_inputs(*graphs):
        """Given a list of graphs, each whose inputs are either from loop_vars
        or other variables:
        1) calculate a list `inputs`, the union of their inputs.
        2) for each graph, determine at which indices of `inputs` its inputs reside.
        3) for each loop variable, find its index among the inputs of the graph
           (-1 when the graph does not take it as an input).
        """
        inputs = []           # List[Symbol], result of 1)
        locs = []             # List[Tuple(List[Int], List[Int])], results of 2) and 3)
        input_id_to_loc = {}  # Dict[int, int], id(sym) -> loc, where inputs[loc] is sym
        # These two maps only depend on loop_vars, so build them once.
        # some loop_vars are inputs to `graph`, some are not
        name_to_loop_vars = {sym.name: sym for sym in loop_vars}
        # mapping from a loop var's name to its position in loop_vars
        name_to_var_locs = {sym.name: i for i, sym in enumerate(loop_vars)}
        for graph in graphs:
            # all inputs to the `graph`
            name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
            # other inputs to `graph` created by cut_graph
            name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
            # collect arguments for each subgraph
            input_locs = []                   # results from the second step
            var_locs = [-1] * len(loop_vars)  # results from the third step
            for input_name in graph.list_inputs():
                assert input_name in name_to_input_syms  # internal invariant
                # name -> sym
                if input_name in name_to_loop_vars:
                    sym = name_to_loop_vars[input_name]
                elif input_name in name_to_cut_g_syms:
                    sym = name_to_cut_g_syms[input_name]
                else:
                    sym = copy.deepcopy(name_to_input_syms[input_name])
                # do 2), and 1) is implicitly done
                if id(sym) in input_id_to_loc:
                    loc = input_id_to_loc[id(sym)]
                else:
                    loc = len(input_id_to_loc)
                    inputs.append(sym)
                    input_id_to_loc[id(sym)] = loc
                input_locs.append(loc)
                # do 3)
                if input_name in name_to_var_locs:
                    var_locs[name_to_var_locs[input_name]] = len(input_locs) - 1
            locs.append((input_locs, var_locs))
        return inputs, locs

    max_iterations = _to_python_scalar(max_iterations, int, "max_iterations")
    loop_vars = _to_symbol_tuple(loop_vars, "loop_vars")
    # It should work fine if loop_vars is empty I guess,
    # but it is semantically unnecessary to include this case.
    if len(loop_vars) == 0:
        raise ValueError("loop_vars should contain at least one element")
    # create graph for `cond`
    cond_g, num_out_data, num_outputs = \
        _create_subgraph(loop_vars, _cond_wrapper, name + "_cond")
    # internal invariants: _cond_wrapper always returns ([], [result])
    assert num_out_data == 0
    assert num_outputs == 1
    # create graph for `func`
    func_g, num_out_data, num_outputs = \
        _create_subgraph(loop_vars, _func_wrapper, name + "_func")
    # find symbols used in either cond_g or func_g
    input_syms, ((cond_input_locs, _), (func_input_locs, func_var_locs)) = \
        _union_inputs(cond_g, func_g)
    for i_th, loc in enumerate(func_var_locs, 1):
        # This depends on user input, so it must not be an `assert`
        # (asserts are stripped under `python -O`, which would let -1 through).
        if loc == -1:
            raise ValueError("The %d-th loop_var is not an input of the graph built from func"
                             % i_th)
    result = symbol._internal._while_loop(
        # [cond, func_g, *input_syms]
        cond_g,
        func_g,
        *input_syms,
        max_iterations=max_iterations,
        cond_input_locs=cond_input_locs,
        func_input_locs=func_input_locs,
        func_var_locs=func_var_locs,
        num_out_data=num_out_data,
        num_outputs=num_outputs
    )
    outputs = [result[i] for i in range(num_out_data)]
    final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)]
    return outputs, final_loop_vars
Loading

0 comments on commit 6976b90

Please sign in to comment.