Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Jul 20, 2018
1 parent 2f8a794 commit 2b46358
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 4 deletions.
4 changes: 2 additions & 2 deletions python/mxnet/ndarray/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,12 @@ def _func_wrapper(loop_vars):
return stacked_outputs, list(loop_vars)

def ifelse(cond, then_func, else_func, inputs):
"""Run a if-then-else using user-defined condition and computation
"""Run an if-then-else using user-defined condition and computation
This operator simulates a if-like branch which chooses to do one of
the two customized computations according to the specified condition.
`inputs` is a list of NDArrays on which the condition and computations reply on.
`inputs` is a list of NDArrays on which the condition and computations rely on.
`cond` is a user-defined function, used as the if condition.
It consumes `inputs`, and produces a scalar MXNet NDArray,
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/symbol/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,12 +558,12 @@ def _union_inputs(*graphs):
return outputs, final_loop_vars

def ifelse(cond, then_func, else_func, inputs, name="ifelse"):
"""Run a if-then-else using user-defined condition and computation
"""Run an if-then-else using user-defined condition and computation
This operator simulates a if-like branch which chooses to do one of
the two customized computations according to the specified condition.
`inputs` is a list of Symbols on which the condition and computations reply on.
`inputs` is a list of Symbols on which the condition and computations rely on.
`cond` is a user-defined function, used as the if condition.
It consumes `inputs`, and produces a scalar MXNet symbol,
Expand Down
1 change: 1 addition & 0 deletions src/operator/control_flow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,7 @@ static bool IfelseShape(const nnvm::NodeAttrs& attrs,
params.then_input_locs, true);
bool succ_2 = infer_subg(attrs.subgraphs[2], &else_out_shape, \
params.else_input_locs, true);
sync_out_out(&then_out_shape, &else_out_shape, is_udf);
return succ_0 && succ_1 && succ_2;
}

Expand Down
13 changes: 13 additions & 0 deletions src/operator/subgraph_op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,19 @@ bool sync_in_in(const nnvm::Tuple<dim_t> &input_locs,
return true;
}

template <typename T>
bool sync_out_out(std::vector<T> *out_1,
std::vector<T> *out_2,
std::function<bool(const T &)> is_empty) {
CHECK_EQ(out_1->size(), out_2->size());
for (size_t i = 0; i < out_1->size(); ++i) {
T &x = out_1->at(i);
T &y = out_2->at(i);
fill_value(&x, &y, is_empty(x), is_empty(y));
}
return true;
}

/*
* This contains the states for running a loop and provides methods
* of running the subgraph computation for an iteration.
Expand Down

0 comments on commit 2b46358

Please sign in to comment.