diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 8fd7dd318758..d430aee90b4e 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -1516,7 +1516,8 @@ class SymbolBlock(HybridBlock): >>> print(feat_model(x)) """ @staticmethod - def imports(symbol_file, input_names, param_file=None, ctx=None): + def imports(symbol_file, input_names, param_file=None, ctx=None, allow_missing=False, + ignore_extra=False): """Import model previously saved by `gluon.HybridBlock.export` as a `gluon.SymbolBlock` for use in Gluon. @@ -1530,6 +1531,11 @@ def imports(symbol_file, input_names, param_file=None, ctx=None): Path to parameter file. ctx : Context, default None The context to initialize `gluon.SymbolBlock` on. + allow_missing : bool, default False + Whether to silently skip loading parameters not represents in the file. + ignore_extra : bool, default False + Whether to silently ignore parameters from the file that are not + present in this Block. Returns ------- @@ -1562,7 +1568,7 @@ def imports(symbol_file, input_names, param_file=None, ctx=None): inputs = [symbol.var(i).as_np_ndarray() if is_np_array() else symbol.var(i) for i in input_names] ret = SymbolBlock(sym, inputs) if param_file is not None: - ret.load_parameters(param_file, ctx=ctx, cast_dtype=True, dtype_source='saved') + ret.load_parameters(param_file, ctx, allow_missing, ignore_extra, True, 'saved') return ret def __repr__(self):