Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix failed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed May 11, 2017
1 parent 4da22db commit 35bf4d0
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 26 deletions.
30 changes: 6 additions & 24 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,7 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
----------
>>> x = mx.sym.Variable('x')
>>> y = mx.sym.FullyConnected(x, num_hidden=4)
>>> exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req=[])
>>> exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req='null')
>>> exe.forward()
[<NDArray 5x4 @cpu(0)>]
>>> exe.outputs[0].asnumpy()
Expand Down Expand Up @@ -1166,15 +1166,6 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
executor : mxnet.Executor
The generated executor
"""
# listed_arguments = self.list_arguments() # read-only args
# listed_aux_states = self.list_auxiliary_states() # aux states

# attr_dict = None
# if type_dict is None:
# attr_dict = self.attr_dict()
# type_dict = {k: mx_real_t for k in listed_arguments
# if k not in attr_dict or '__dtype__' not in attr_dict[k]}

num_provided_arg_types = 0
provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided type argument names
provided_arg_type_data = ctypes.POINTER(mx_uint)() # provided types
Expand Down Expand Up @@ -1211,9 +1202,13 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
provided_req_type_list_len = 0
provided_grad_req_types = [c_str(grad_req)]
elif isinstance(grad_req, list):
if len(grad_req) == 0:
raise RuntimeError('grad_req in simple_bind cannot be an empty list')
provided_grad_req_types = [c_str(item) for item in grad_req]
provided_req_type_list_len = len(provided_grad_req_types)
elif isinstance(grad_req, dict):
if len(grad_req) == 0:
raise RuntimeError('grad_req in simple_bind cannot be an empty dict')
provided_grad_req_names = []
provided_grad_req_types = []
for k, v in grad_req.items():
Expand All @@ -1223,19 +1218,6 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None,
provided_req_type_list_len = len(provided_grad_req_types)
provided_grad_req_types = c_array(ctypes.c_char_p, provided_grad_req_types)

# if group2ctx is not None:
# if attr_dict is None:
# attr_dict = self.attr_dict()
# arg_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx)
# if name in attr_dict and '__ctx_group__' in attr_dict[name]
# else ctx for name in listed_arguments]
# aux_ctx = [group2ctx.get(attr_dict[name]['__ctx_group__'], ctx)
# if name in attr_dict and '__ctx_group__' in attr_dict[name]
# else ctx for name in listed_aux_states]
# else:
# arg_ctx = [ctx] * len(listed_arguments)
# aux_ctx = [ctx] * len(listed_aux_states)

num_ctx_map_keys = mx_uint(0)
ctx_map_keys = ctypes.POINTER(ctypes.c_char_p)()
ctx_map_dev_types = ctypes.POINTER(ctypes.c_int)()
Expand Down Expand Up @@ -1359,7 +1341,7 @@ def simple_bind_v1(self, ctx,
----------
>>> x = mx.sym.Variable('x')
>>> y = mx.sym.FullyConnected(x, num_hidden=4)
>>> exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req=[])
>>> exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req='null')
>>> exe.forward()
[<NDArray 5x4 @cpu(0)>]
>>> exe.outputs[0].asnumpy()
Expand Down
2 changes: 1 addition & 1 deletion src/c_api/c_api_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
}

// initialize arg_grad_ctx_vec and grad_req_type_vec
std::vector<Context> arg_grad_ctx_vec(in_arg_names.size());
std::vector<Context> arg_grad_ctx_vec(in_arg_names.size(), ctx);
std::vector<OpReqType> grad_req_type_vec(in_arg_names.size(), kNullOp);
if ("none" != grad_req_type) {
for (size_t i = 0; i < in_arg_names.size(); ++i) {
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_reshape():
x = mx.sym.Variable('x')
y = mx.sym.FullyConnected(x, num_hidden=4)

exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req=[])
exe = y.simple_bind(mx.cpu(), x=(5,4), grad_req='null')
exe.arg_arrays[0][:] = 1
exe.arg_arrays[1][:] = mx.nd.ones((4,4))
exe.arg_arrays[2][:] = 0
Expand Down

0 comments on commit 35bf4d0

Please sign in to comment.