Merge pull request #133 from tqchen/master
[MODEL] Allow extra params
tqchen committed Sep 22, 2015
2 parents 0cb9cb6 + 4430ae1 commit ece3d50
Showing 4 changed files with 38 additions and 5 deletions.
17 changes: 17 additions & 0 deletions python/mxnet/model.py
@@ -420,15 +420,32 @@ class FeedForward(BASE_ESTIMATOR):
    aux_params : dict of str to NDArray, optional
        Model parameter, dict of name to NDArray of net's auxiliary states.
    allow_extra_params : boolean, optional
        Whether to allow extra parameters that are not needed by the symbol
        to be passed in aux_params and arg_params.
        If this is True, no error will be thrown when aux_params and arg_params
        contain more parameters than the symbol needs.
    **kwargs : dict
        The additional keyword arguments passed to optimizer.
    """
    def __init__(self, symbol, ctx=None,
                 num_round=None, optimizer='sgd', initializer=Xavier(),
                 arg_params=None, aux_params=None,
                 allow_extra_params=False,
                 **kwargs):
        # check if symbol contains duplicated names.
        _check_arguments(symbol)
        # rematch parameters to drop the unused ones
        if allow_extra_params:
            if arg_params:
                arg_names = set(symbol.list_arguments())
                arg_params = {k: v for k, v in arg_params.items()
                              if k in arg_names}
            if aux_params:
                aux_names = set(symbol.list_auxiliary_states())
                aux_params = {k: v for k, v in aux_params.items()
                              if k in aux_names}
        # basic configuration
        self.symbol = symbol
        if ctx is None:
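The filtering above boils down to a dictionary comprehension over the symbol's declared names. A minimal standalone sketch of that step; the helper name filter_extra_params is hypothetical and not part of this commit or the MXNet API:

# Sketch of the allow_extra_params filtering step (hypothetical helper name).
def filter_extra_params(symbol, params):
    # Keep only the entries that the symbol actually declares as arguments;
    # everything else is silently dropped, which is what allow_extra_params=True enables.
    arg_names = set(symbol.list_arguments())
    return {k: v for k, v in params.items() if k in arg_names}
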
10 changes: 10 additions & 0 deletions python/mxnet/symbol.py
@@ -154,6 +154,16 @@ def _compose(self, *args, **kwargs):
            self.handle, name, num_args, keys, args))

    def __getitem__(self, index):
        if isinstance(index, string_types):
            idx = None
            for i, name in enumerate(self.list_outputs()):
                if name == index:
                    if idx is not None:
                        raise ValueError('There are multiple outputs with name \"%s\"' % index)
                    idx = i
            if idx is None:
                raise ValueError('Cannot find output that matches name \"%s\"' % index)
            index = idx
        if not isinstance(index, int):
            raise TypeError('Symbol only supports integer index to fetch the i-th output')
        handle = SymbolHandle()
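With this change, an output of a symbol's internals (or of a grouped symbol) can be fetched by name as well as by integer position. A rough usage sketch; the layer names and the Softmax operator name are illustrative and may differ across MXNet versions:

import mxnet as mx

# Small illustrative network; names are chosen only for this example.
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type='relu')
fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=10)
net = mx.symbol.Softmax(data=fc2, name='softmax')

internals = net.get_internals()
print(internals.list_outputs())    # includes 'fc1_output', 'fc2_output', ...
fc1_out = internals['fc1_output']  # new: look up an output by name
first = internals[0]               # integer indexing still works
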
12 changes: 9 additions & 3 deletions tests/python/train/test_mlp.py
@@ -67,6 +67,15 @@ def test_mlp():
    logging.info('final accuracy = %f', acc1)
    assert(acc1 > 0.95)

    # predict internal featuremaps
    internals = softmax.get_internals()
    fc2 = internals['fc2_output']
    mfeat = mx.model.FeedForward(symbol=fc2,
                                 arg_params=model.arg_params,
                                 aux_params=model.aux_params,
                                 allow_extra_params=True)
    feat = mfeat.predict(val_dataiter)
    assert feat.shape == (10000, 64)
    # pickle the model
    smodel = pickle.dumps(model)
    model2 = pickle.loads(smodel)
@@ -79,9 +88,6 @@ def test_mlp():
    assert np.sum(np.abs(prob - prob3)) == 0

    # save model explicitly
    model.save(prefix, 128)
    model4 = mx.model.FeedForward.load(prefix, 128)
    prob4 = model4.predict(val_dataiter)
4 changes: 2 additions & 2 deletions tests/python/unittest/test_symbol.py
@@ -39,8 +39,8 @@ def test_symbol_internal():
                                     'fc1_weight', 'fc1_bias',
                                     'fc2_weight', 'fc2_bias']
    internal = net1.get_internals()
    nmap = {x: i for i, x in enumerate(internal.list_outputs())}
    fc1 = internal[nmap['fc1_output']]
    print internal.list_outputs()
    fc1 = internal['fc1_output']
    assert fc1.list_arguments() == oldfc.list_arguments()


