Skip to content

Commit

Permalink
Merge pull request apache#11 from antinucleon/dato
Browse files · Browse the repository at this point in the history
[Fix 3546] & [Fix ELU Op]
Loading branch information
antinucleon committed Jan 19, 2016
2 parents 71201c7 + 4288a3d commit 51b7ee6
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion make/config.mk
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,4 @@ TORCH_PATH = $(HOME)/torch
# whether to use sframe integration. This requires build sframe
# [email protected]:dato-code/SFrame.git
# SFRAME_PATH = $(HOME)/SFrame
# MXNET_PLUGINS += plugin/sframe/SFrame.mk
# MXNET_PLUGINS += plugin/sframe/plugin.mk
File renamed without changes.
5 changes: 3 additions & 2 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names,
for idx in range(len(param_arrays)):
param_on_devs = param_arrays[idx]
kvstore.init(idx, arg_params[param_names[idx]])

if update_on_kvstore:
kvstore.pull(idx, param_on_devs, priority=-idx)

Expand Down Expand Up @@ -202,7 +201,6 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,

if update_on_kvstore:
kvstore.set_optimizer(optimizer)

# Now start training
for epoch in range(begin_epoch, end_epoch):
# Training phase
Expand Down Expand Up @@ -416,6 +414,9 @@ def __init__(self, symbol, ctx=None,
ctx = [cpu()]
elif isinstance(ctx, Context):
ctx = [ctx]
# disable multi-cpu data parallelism because blas will use all cpu resource
if ctx[0].device_type == "cpu" and len(ctx) > 1:
ctx = [cpu()]
self.ctx = ctx
# training parameters
self.num_epoch = num_epoch
Expand Down
2 changes: 1 addition & 1 deletion src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ struct elu {
struct elu_grad {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType x, DType a) {
return DType(x > 0.0f ? 1.0f : a * expf(x));
return DType(x > 0.0f ? 1.0f : a + x);
}
};

Expand Down

0 comments on commit 51b7ee6

Please sign in to comment.