Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #3 from dmlc/master
Browse files Browse the repository at this point in the history
merge dmlc/master
  • Loading branch information
mli committed Sep 21, 2015
2 parents 8149248 + f69c194 commit 74bd455
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
18 changes: 9 additions & 9 deletions example/cifar10/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
sys.path.insert(0, "../../python/")
sys.path.append("../../tests/python/common")
# import library
import logging
import mxnet as mx
import get_data
import time
Expand Down Expand Up @@ -59,10 +60,6 @@
[39] train-error:0.00125879 val-error:0.0833
[40] train-error:0.000699329 val-error:0.0842
"""
def CalAcc(out, label):
    """Return classification accuracy: the fraction of rows of *out*
    whose argmax equals the corresponding entry of *label*."""
    predictions = np.argmax(out, axis=1)
    correct = (predictions == label).sum()
    return correct * 1.0 / out.shape[0]


np.random.seed(1812)

Expand Down Expand Up @@ -178,11 +175,14 @@ def RandomInit(narray):
preprocess_threads=1)

def test_cifar():
    """Train the CIFAR-10 network on a GPU and evaluate it on the test set.

    NOTE(review): the original span was diff residue containing BOTH the
    pre-merge body (mx.model.MXNetModel + CalAcc metric) and the post-merge
    body (mx.model.FeedForward); this keeps only the post-merge version.
    Relies on module-level names `mx`, `logging`, `loss`, `epoch`,
    `train_dataiter`, and `test_dataiter` defined elsewhere in this file.
    """
    # Emit DEBUG-level training progress to the console so epoch/metric
    # logging from model.fit is visible.
    logging.basicConfig(level=logging.DEBUG)
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    logging.getLogger('').addHandler(console)
    # get model from symbol
    model = mx.model.FeedForward(ctx=mx.gpu(), symbol=loss, num_round=epoch,
                                 learning_rate=0.05, momentum=0.9, wd=0.00001)
    model.fit(X=train_dataiter, eval_data=test_dataiter)


if __name__ == "__main__":
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def _split_input_slice(input_shape, num_split):
slices = []
shapes = []
for k in range(num_split):
begin = min(k * step, batch_size)
end = min((k+1) * step, batch_size)
begin = int(min(k * step, batch_size))
end = int(min((k+1) * step, batch_size))
if begin == end:
raise ValueError('Too many slices such that some splits are empty')
slices.append(slice(begin, end))
Expand Down

0 comments on commit 74bd455

Please sign in to comment.