Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
SymbolBlock.imports ignore_extra & allow_missing (#19157)
Browse files Browse the repository at this point in the history
  • Loading branch information
samskalicky authored Sep 17, 2020
1 parent adbc17b commit 30ae04a
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1516,7 +1516,8 @@ class SymbolBlock(HybridBlock):
>>> print(feat_model(x))
"""
@staticmethod
def imports(symbol_file, input_names, param_file=None, ctx=None):
def imports(symbol_file, input_names, param_file=None, ctx=None, allow_missing=False,
ignore_extra=False):
"""Import model previously saved by `gluon.HybridBlock.export`
as a `gluon.SymbolBlock` for use in Gluon.
Expand All @@ -1530,6 +1531,11 @@ def imports(symbol_file, input_names, param_file=None, ctx=None):
Path to parameter file.
ctx : Context, default None
The context to initialize `gluon.SymbolBlock` on.
allow_missing : bool, default False
Whether to silently skip loading parameters not represents in the file.
ignore_extra : bool, default False
Whether to silently ignore parameters from the file that are not
present in this Block.
Returns
-------
Expand Down Expand Up @@ -1562,7 +1568,7 @@ def imports(symbol_file, input_names, param_file=None, ctx=None):
inputs = [symbol.var(i).as_np_ndarray() if is_np_array() else symbol.var(i) for i in input_names]
ret = SymbolBlock(sym, inputs)
if param_file is not None:
ret.load_parameters(param_file, ctx=ctx, cast_dtype=True, dtype_source='saved')
ret.load_parameters(param_file, ctx, allow_missing, ignore_extra, True, 'saved')
return ret

def __repr__(self):
Expand Down

0 comments on commit 30ae04a

Please sign in to comment.