Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add enum to activation
Browse files Browse the repository at this point in the history
  • Loading branch information
antinucleon committed Aug 17, 2015
1 parent 83d89f3 commit 137d109
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ endif
BIN = test/api_registry_test test/test_storage
OBJ = narray_op_cpu.o
# add threaded engine after it is done
OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o fully_connected_cpu.o static_graph.o activation_cpu.o elementwise_sum_cpu.o
OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o elementwise_sum_cpu.o
CUOBJ =
SLIB = lib/libmxnet.so
ALIB = lib/libmxnet.a
Expand Down
1 change: 1 addition & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ MXNET_DLL int MXSymbolCompose(SymbolHandle sym,
* The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
* The call will be treated as a kwargs call if key != nullptr or num_args==0, otherwise it is positional.
*
* \param sym symbol handle
* \param num_args numbe of input arguments.
* \param keys the key of keyword args (optional)
* \param arg_ind_ptr the head pointer of the rows in CSR
Expand Down
35 changes: 19 additions & 16 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding: utf-8
# pylint: disable=invalid-name, protected-access
# pylint: disable=invalid-name, protected-access, too-many-locals
"""Symbol support of mxnet"""
from __future__ import absolute_import

Expand Down Expand Up @@ -162,7 +162,8 @@ def infer_shape(self, *args, **kwargs):
The order is in the same order as list_returns()
"""
if len(args) != 0 and len(kwargs) != 0:
raise ValueError('Can only specify known argument shapes either by positional or kwargs way.')
raise ValueError('Can only specify known argument \
shapes either by positional or kwargs way.')
sdata = []
indptr = [0]
if len(args) != 0:
Expand All @@ -188,21 +189,23 @@ def infer_shape(self, *args, **kwargs):
out_shape_ndim = ctypes.POINTER(mx_uint)()
out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))()
complete = ctypes.c_int()
check_call(_LIB.MXSymbolInferShape(
self.handle, len(indptr) - 1,
c_array(ctypes.c_char_p, keys),
c_array(mx_uint, indptr),
c_array(mx_uint, sdata),
ctypes.byref(arg_shape_size),
ctypes.byref(arg_shape_ndim),
ctypes.byref(arg_shape_data),
ctypes.byref(out_shape_size),
ctypes.byref(out_shape_ndim),
ctypes.byref(out_shape_data),
check_call(_LIB.MXSymbolInferShape( \
self.handle, len(indptr) - 1, \
c_array(ctypes.c_char_p, keys), \
c_array(mx_uint, indptr), \
c_array(mx_uint, sdata), \
ctypes.byref(arg_shape_size), \
ctypes.byref(arg_shape_ndim), \
ctypes.byref(arg_shape_data), \
ctypes.byref(out_shape_size), \
ctypes.byref(out_shape_ndim), \
ctypes.byref(out_shape_data), \
ctypes.byref(complete)))
if complete.value != 0:
arg_shapes = [tuple(arg_shape_data[i][:arg_shape_ndim[i]]) for i in range(arg_shape_size.value)]
out_shapes = [tuple(out_shape_data[i][:out_shape_ndim[i]]) for i in range(out_shape_size.value)]
arg_shapes = [tuple(arg_shape_data[i][:arg_shape_ndim[i]]) \
for i in range(arg_shape_size.value)]
out_shapes = [tuple(out_shape_data[i][:out_shape_ndim[i]]) \
for i in range(out_shape_size.value)]
return (arg_shapes, out_shapes)
else:
return (None, None)
Expand All @@ -216,6 +219,6 @@ def debug_str(self):
Debug string of the symbol.
"""
debug_str = ctypes.c_char_p()
check_call(_LIB.MXSymbolPrint(
check_call(_LIB.MXSymbolPrint( \
self.handle, ctypes.byref(debug_str)))
return debug_str.value
5 changes: 3 additions & 2 deletions src/operator/activation-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <cstring>
#include <map>
#include <string>
#include <vector>
#include <utility>
Expand All @@ -28,8 +29,8 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
// use int for enumeration
int type;
DMLC_DECLARE_PARAMETER(ActivationParam) {
// TODO(bing) support enum, str->int mapping
DMLC_DECLARE_FIELD(type).set_default(kReLU);
DMLC_DECLARE_FIELD(type).set_default(kReLU).add_enum("relu", kReLU).\
add_enum("sigmoid", kSigmoid).add_enum("tanh", kTanh);
}
};

Expand Down
1 change: 1 addition & 0 deletions src/operator/elementwise_sum-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <cstring>
#include <map>
#include <string>
#include <vector>
#include <utility>
Expand Down
1 change: 1 addition & 0 deletions src/operator/fully_connected-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <map>
#include <vector>
#include <string>
#include <utility>
Expand Down

0 comments on commit 137d109

Please sign in to comment.