Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Initial commit for Ifelse
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Jul 17, 2018
1 parent aa9722d commit 322cb05
Show file tree
Hide file tree
Showing 8 changed files with 925 additions and 110 deletions.
1 change: 1 addition & 0 deletions docs/api/python/ndarray/contrib.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
quantize
foreach
while_loop
ifelse
```

## API Reference
Expand Down
1 change: 1 addition & 0 deletions docs/api/python/symbol/contrib.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib`
quantize
foreach
while_loop
ifelse
```

## API Reference
Expand Down
97 changes: 95 additions & 2 deletions python/mxnet/ndarray/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
except ImportError:
pass

__all__ = ["rand_zipfian", "foreach", "while_loop"]
__all__ = ["rand_zipfian", "foreach", "while_loop", "ifelse"]

# pylint: disable=line-too-long
def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
Expand Down Expand Up @@ -192,7 +192,6 @@ def check_input(inputs, in_type, msg):
outputs = outputs[0]
return (outputs, states)


def while_loop(cond, func, loop_vars, max_iterations=None):
"""Run a while loop with user-defined computation and loop condition.
Expand Down Expand Up @@ -358,3 +357,97 @@ def _func_wrapper(loop_vars):
[" Step %d, shape is %s" % (i, str(x.shape)) for i, x in enumerate(items)]
))
return stacked_outputs, list(loop_vars)

def ifelse(cond, then_func, else_func, inputs):
"""Run a if-then-else using user-defined condition and computation
This operator simulates a if-like branch which chooses to do one of
the two customized computations according to the specified condition.
`inputs` is a list of NDArrays on which the condition and computations reply on.
`cond` is a user-defined function, used as the if condition.
It consumes `inputs`, and produces a scalar MXNet NDArray,
indicating which branch of computation should be used.
The `cond` is variadic, and its signature should be
`cond(*loop_vars) => NDArray`.
`then_func` is a user-defined function, used as computation of the then branch.
It consumes `inputs`, and produces `outputs`.
The `then_func` is variadic, and its signature should be
`then_func(*loop_vars) => List[NDArray]`.
`else_func` is a user-defined function, used as computation of the else branch.
It also consumes `inputs`, and produces `outputs`.
The `else_func` is variadic, and its signature should be
`else_func(*loop_vars) => List[NDArray]`.
The `outputs` produces by `then_func` and `else_func` should have the same number
of elements, all of which should be in the same shape, of the same dtype and stype.
This function returns a list of NDArrays, representing the computation result.
Parameters
----------
cond: a Python function.
The branch condition.
then_func: a Python function.
The computation to be executed if `cond` is true.
else_func: a Python function.
The computation to be executed if `cond` is false.
inputs: list of NDArrays.
The variables fed to `cond`, `then_func` and `else_func`.
Returns
-------
outputs: a list of NDArrays, representing the result of computation.
Examples
--------
>>> cond = lambda a, b: a * b < 5
>>> then_func = lambda a, b: (a + 5) * (b + 5)
>>> else_func = lambda a, b: (a - 5) * (b - 5)
>>> inputs = (mx.nd.array([1]), mx.nd.array([2]))
>>> outputs = mx.nd.contrib.ifelse(cond, then_func, else_func, inputs)
>>> outputs[0]
[42.]
<NDArray 1 @cpu(0)>
"""
def _to_python_scalar(inputs, type_, name):
"""Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types,
to the given type
"""
if hasattr(inputs, "asscalar"):
inputs = inputs.asscalar()
try:
inputs = type_(inputs)
except:
raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__))
return inputs

def _to_ndarray_tuple(inputs, name):
"""Converts "inputs", possibly a single mxnet NDArray, a list of mxnet NDArray,
a tuple of mxnet NDArray, into a tuple of NDArray
"""
if isinstance(inputs, list):
inputs = tuple(inputs)
if isinstance(inputs, ndarray.NDArray):
inputs = (inputs, )
if not isinstance(inputs, tuple):
raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, ))
for item in inputs:
if not isinstance(item, ndarray.NDArray):
raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, ))
return inputs

inputs = _to_ndarray_tuple(inputs, "inputs")
if len(inputs) == 0:
raise ValueError("inputs should contain at least one element")
branch = _to_python_scalar(cond(*inputs), bool, "Return value of cond")
if branch:
outputs = then_func(*inputs)
outputs = _to_ndarray_tuple(outputs, "outputs of then_func")
else:
outputs = else_func(*inputs)
outputs = _to_ndarray_tuple(outputs, "outputs of else_func")
return list(outputs)
153 changes: 152 additions & 1 deletion python/mxnet/symbol/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ..base import SymbolHandle, _as_list
from ..attribute import AttrScope

__all__ = ["rand_zipfian", "foreach", "while_loop"]
__all__ = ["rand_zipfian", "foreach", "while_loop", "ifelse"]

def rand_zipfian(true_classes, num_sampled, range_max):
"""Draw random samples from an approximately log-uniform or Zipfian distribution.
Expand Down Expand Up @@ -551,3 +551,154 @@ def _union_inputs(*graphs):
outputs = [result[i] for i in range(num_out_data)]
final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)]
return outputs, final_loop_vars

def ifelse(cond, then_func, else_func, inputs, name="ifelse"):
"""Run a if-then-else using user-defined condition and computation
This operator simulates a if-like branch which chooses to do one of
the two customized computations according to the specified condition.
`inputs` is a list of Symbols on which the condition and computations reply on.
`cond` is a user-defined function, used as the if condition.
It consumes `inputs`, and produces a scalar MXNet symbol,
indicating which branch of computation should be used.
The `cond` is variadic, and its signature should be
`cond(*loop_vars) => Symbol`.
`then_func` is a user-defined function, used as computation of the then branch.
It consumes `inputs`, and produces `outputs`.
The `then_func` is variadic, and its signature should be
`then_func(*loop_vars) => List[Symbol]`.
`else_func` is a user-defined function, used as computation of the else branch.
It also consumes `inputs`, and produces `outputs`.
The `else_func` is variadic, and its signature should be
`else_func(*loop_vars) => List[Symbol]`.
The `outputs` produces by `then_func` and `else_func` should have the same number
of elements, all of which should be in the same shape, of the same dtype and stype.
This function returns a list of symbols, representing the computation result.
Parameters
----------
cond: a Python function.
The branch condition.
then_func: a Python function.
The computation to be executed if `cond` is true.
else_func: a Python function.
The computation to be executed if `cond` is false.
inputs: list of Symbols.
The variables fed to `cond`, `then_func` and `else_func`.
Returns
-------
outputs: a list of Symbols, representing the result of computation.
Examples
--------
>>> cond = lambda a, b: a * b < 5
>>> then_func = lambda a, b: (a + 5) * (b + 5)
>>> else_func = lambda a, b: (a - 5) * (b - 5)
>>> inputs = (mx.sym.var('a'), mx.sym.var('b'))
>>> outputs = mx.sym.contrib.ifelse(cond, then_func, else_func, inputs)
"""
def _to_symbol_tuple(inputs, name):
"""Converts "inputs", possibly a single mxnet Symbol, a list of mxnet Symbol,
a tuple of mxnet Symbol, into a tuple of Symbol
"""
if isinstance(inputs, list):
inputs = tuple(inputs)
if isinstance(inputs, Symbol):
inputs = (inputs, )
if not isinstance(inputs, tuple):
raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, ))
for item in inputs:
if not isinstance(item, Symbol):
raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, ))
return inputs

def _create_subgraph(graph_vars, graph_func, subgraph_name):
with AttrScope(__subgraph_name__=subgraph_name):
# create new variables with the same name,
# them feed them to the given func
new_graph_vars = [symbol.var(sym.name) for sym in graph_vars]
outputs = graph_func(*new_graph_vars)
outputs = _to_symbol_tuple(outputs, "outputs")
num_outputs = len(outputs)
# nnvm cut-graph does not allow inputs and outputs overlap
# so we calculate the name of inputs, and copy outputs once it overlaps with inputs
all_input_names = symbol.Group(outputs).list_inputs()
make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x
# group all outputs of graph_func
graph = symbol.Group(list(map(make_identity, outputs)))
return graph, num_outputs

def _union_inputs(*graphs):
# Given a list of graphs, each whose inputs are either from input_vars or other variables.
# 1) calculate a list `inputs`, the union of their inputs.
# 2) for each graph, determine in which indices their inputs reside in `inputs`
# 3) for each variable in the input of `graph`, find which index it is
inputs = [] # List[Symbol], result of 1)
locs = [] # List[Tuple(List[Int], List[Int])], a list of tuples,
# where tuples are results of 2) and 3)
input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it
# to a `loc`, where inputs[loc] = sym
for graph in graphs:
# input_syms: all inputs to the `graph`
name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
# some input_vars are inputs to `graph`, some are not
name_to_input_vars = {sym.name: sym for sym in inputs}
# other inputs to `graph` created by cut_graph
name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
# collect arguments for each subgraph
input_locs = [] # results from the second step
for name in graph.list_inputs():
assert name in name_to_input_syms # it should obviously hold
# name -> sym
if name in name_to_input_vars:
sym = name_to_input_vars[name]
elif name in name_to_cut_g_syms:
sym = name_to_cut_g_syms[name]
else:
sym = copy.deepcopy(name_to_input_syms[name])
# do 2), and 1) is implicitly done
if id(sym) in input_id_to_loc:
loc = input_id_to_loc[id(sym)]
else:
loc = len(input_id_to_loc)
inputs.append(sym)
input_id_to_loc[id(sym)] = loc
input_locs.append(loc)
locs.append(input_locs)
return inputs, locs
inputs = _to_symbol_tuple(inputs, "inputs")
if len(inputs) == 0:
raise ValueError("loop_vars should contain at least one element")
# create graph for `cond'
cond_g, num_outputs = _create_subgraph(inputs, cond, name + "_cond")
if num_outputs != 1:
raise ValueError("cond should always produce a single output")
# create graph for `then`
then_g, then_num_outputs = _create_subgraph(inputs, then_func, name + "_then")
# create graph for `else`
else_g, else_num_outputs = _create_subgraph(inputs, else_func, name + "_else")
if then_num_outputs != else_num_outputs:
raise ValueError("Number of outputs differs between then-branch and else-branch")
# find symbols used in either cond_g or func_g
input_syms, (cond_input_locs, then_input_locs, else_input_locs) = \
_union_inputs(cond_g, then_g, else_g)
result = symbol._internal._ifelse(
# [cond, then_g, else_g, *input_syms]
cond_g,
then_g,
else_g,
*input_syms,
cond_input_locs=cond_input_locs,
then_input_locs=then_input_locs,
else_input_locs=else_input_locs,
num_outputs=then_num_outputs
)
result = _to_symbol_tuple(result, "result")
return list(result)
Loading

0 comments on commit 322cb05

Please sign in to comment.