This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
block.py
1981 lines (1722 loc) · 82.3 KB
/
block.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable= arguments-differ, too-many-lines, reimported
"""Base container class for all neural network models."""
__all__ = ['Block', 'HybridBlock', 'SymbolBlock']
import enum
import ctypes
import copy
import warnings
import weakref
from collections import OrderedDict, defaultdict
import contextlib
import contextvars
import re
import json
import numpy as np
from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB
from .. import symbol, ndarray, initializer, autograd, _deferred_compute as dc, name as _name, \
profiler as _profiler, device as _device
from ..symbol.numpy import _symbol as np_symbol
from ..symbol import Symbol, fromjson
from ..ndarray import NDArray, get_dtype_name
from .parameter import Parameter, DeferredInitializationError
from .utils import _indent, _brief_print_list, HookHandle, shape_is_known
from .utils import _check_same_symbol_type, _check_all_np_ndarrays, _check_block_input_np_ndarrays
from .. import numpy_extension as _mx_npx
from .. import numpy as _mx_np, ndarray as nd
from .. util import is_np_array, np_shape, np_array, wrap_ctx_to_device_func
# Per-scope counter used by _block_scope to number sibling blocks of the same
# class: a dict mapping lowercase class name -> next index. It has no default,
# so it is unset (get(None) returns None) outside any block scope.
_naming_counter = contextvars.ContextVar('namecounter')
# Accumulated name prefix of the enclosing Block scopes, e.g. "outer0_inner1_".
_prefix = contextvars.ContextVar('prefix', default='')
@contextlib.contextmanager
def _block_scope(block):
    """Append the classname of the current Block to the symbolic and memory profiler name scopes.

    The lowercase class name of ``block`` is appended to the ``_prefix`` context
    variable. When a naming counter is active, the name is suffixed with a
    per-parent index. A fresh naming counter is installed for the block's
    children. Both context variables are restored when the scope exits, including
    when the body raises, so an exception inside a block cannot leak a stale
    prefix or counter into later calls.

    Parameters
    ----------
    block : Block
        The block whose scope is being entered.
    """
    name = type(block).__name__.lower()
    counter = _naming_counter.get(None)
    if counter is not None:
        # Disambiguate siblings of the same class: dense0, dense1, ...
        count = counter.get(name, 0)
        counter[name] = count + 1
        name = f'{name}{count}'
    # Children of this block start numbering from zero again.
    counter_token = _naming_counter.set({})
    prefix_token = _prefix.set(_prefix.get() + name + '_')
    try:
        with _name.Prefix(_prefix.get()):
            with _profiler.scope(name + ':'):
                yield
    finally:
        # Restore the enclosing scope even if the body raised.
        _naming_counter.reset(counter_token)
        _prefix.reset(prefix_token)
def _gather_type_device_info(args):
    """Inspect a (possibly nested) argument structure.

    Reports whether any NDArray or Symbol appears in ``args``, the devices the
    NDArrays live on, and the device of the first NDArray encountered.

    Parameters
    ----------
    args : list or NDArray or Symbol
        Could be a nested architecture.

    Returns
    -------
    has_symbol : bool
        Whether the elements in args contains symbols
    has_ndarray : bool
        Whether the elements in args contains ndarrays
    device_set : set of mxnet.device.Device
        Devices of the inner ndarrays in args; empty if there is no ndarray.
        Traversal of a list stops as soon as both a Symbol and an NDArray have
        been seen, so the set may then be incomplete.
    first_device : mxnet.device.Device or None
        Device of the first appeared NDArray (for backward-compatibility)
    """
    if isinstance(args, NDArray):
        return False, True, {args.device}, args.device
    if isinstance(args, Symbol):
        return True, False, set(), None
    if not isinstance(args, (list, tuple)):
        # Any other leaf (None, scalars, ...) contributes nothing.
        return False, False, set(), None

    saw_symbol = False
    saw_ndarray = False
    devices = set()
    first = None
    for item in args:
        sub_sym, sub_nd, sub_devices, sub_first = _gather_type_device_info(item)
        saw_symbol = saw_symbol or sub_sym
        saw_ndarray = saw_ndarray or sub_nd
        if first is None and sub_first is not None:
            first = sub_first
        devices.update(sub_devices)
        if saw_symbol and saw_ndarray:
            # Mixed inputs found; stop early (remaining devices are not collected).
            break
    return saw_symbol, saw_ndarray, devices, first
def _flatten(args, inout_str):
    """Flatten (nested) arguments into a flat list plus a format descriptor.

    The format descriptor records the structure of ``args`` so that it can be
    rebuilt later with ``_regroup``.

    Parameters
    ----------
    args : NDArray, Symbol, or (nested) list of Symbol or NDArray
        We allow None inside the args.
    inout_str : str
        The name of the HybridBlock (used only in the error message).

    Returns
    -------
    flat : list of Symbol or NDArray
        The flattened version of the input args.
    fmts : (nested) list of ints
        Per-leaf format: ``0`` for an NDArray or single-output Symbol, the
        number of outputs for a multi-output Symbol, ``-1`` for None, and a
        list for each nested list/tuple.

    Raises
    ------
    ValueError
        If a leaf is not an NDArray, Symbol, None, list or tuple.
    """
    if isinstance(args, NDArray):
        return [args], 0
    if isinstance(args, Symbol):
        num_outputs = len(args.list_outputs())
        return [args], (num_outputs if num_outputs > 1 else 0)
    if args is None:
        return [None], -1
    if not isinstance(args, (list, tuple)):
        raise ValueError("When hybridized, the input of HybridBlock {}"
                         " must be (nested) list of Symbol"
                         " or NDArray, "
                         "but got {} of type {}".format(inout_str, str(args), str(type(args))))
    flat, fmts = [], []
    for item in args:
        item_flat, item_fmt = _flatten(item, inout_str)
        flat.extend(item_flat)
        fmts.append(item_fmt)
    return flat, fmts
def _regroup(args, fmt):
    """Reconstruct the structured arguments based on the flattened version.

    Consumes ``args`` left to right according to the format descriptor
    recorded by ``_flatten``.

    Parameters
    ----------
    args : list of Symbol or NDArray (None entries allowed)
        The flattened arguments.
    fmt : int or (nested) list of ints
        ``0`` takes one element, ``-1`` takes one element that must be None,
        ``k > 0`` takes a slice of ``k`` elements, and a list recurses.

    Returns
    -------
    ret : NDArray, Symbol, or (nested) list of Symbol or NDArray

    Raises
    ------
    ValueError
        For an encoded format below -1, a non-None value where None was
        recorded, or a non-list ``args`` where a list format is expected.
    """
    def _merger(args, fmt):
        """Return ``(structured value, remaining flat args)`` for one format node."""
        if isinstance(fmt, int):
            if fmt < -1:
                raise ValueError("Unsupported encoded format {}.".format(fmt))
            if fmt == -1:
                if args[0] is not None:
                    raise ValueError('We do not support passing types that are not None'
                                     ' when the initial HybridBlock has received NoneType and'
                                     ' has been hybridized.'
                                     ' Received arg = {}, fmt = {}.'.format(args[0], fmt))
                return None, args[1:]
            if fmt == 0:
                return args[0], args[1:]
            return args[:fmt], args[fmt:]

        if not isinstance(args, (list, tuple)):
            raise ValueError("When hybridized, the output of HybridBlock must be (nested)"
                             " list of Symbol or NDArray, "
                             "but got {} of type {}".format(args, type(args)))
        merged = []
        rest = args
        for sub_fmt in fmt:
            value, rest = _merger(rest, sub_fmt)
            merged.append(value)
        return merged, rest

    result, _ = _merger(args, fmt)
    return result
class Block:
"""Base class for all neural network layers and models. Your models should
subclass this class.
:py:class:`Block` can be nested recursively in a tree structure. You can create and
assign child :py:class:`Block` as regular attributes::
import mxnet as mx
from mxnet.gluon import Block, nn
class Model(Block):
def __init__(self, **kwargs):
super(Model, self).__init__(**kwargs)
self.dense0 = nn.Dense(20)
self.dense1 = nn.Dense(20)
def forward(self, x):
x = mx.npx.relu(self.dense0(x))
return mx.npx.relu(self.dense1(x))
model = Model()
model.initialize(device=mx.cpu(0))
model(mx.np.zeros((10, 10), device=mx.cpu(0)))
Child :py:class:`Block` assigned this way will be registered and :py:meth:`collect_params`
will collect their Parameters recursively. You can also manually register
child blocks with :py:meth:`register_child`.
"""
def __init__(self):
    """Set up empty registries for children, parameters and forward hooks."""
    # Registered child blocks (stored as weak references by register_child), in insertion order.
    self._children = OrderedDict()
    # Parameters assigned directly on this block (children's are not included).
    self._reg_params = {}
    # Hooks run after / before forward(), in registration order.
    self._forward_hooks = OrderedDict()
    self._forward_pre_hooks = OrderedDict()
def __repr__(self):
    """Return the class name followed by one indented line per Block-valued attribute."""
    child_lines = []
    for key, child in self.__dict__.items():
        if isinstance(child, Block):
            child_lines.append(' ({key}): {block}'.format(key=key,
                                                           block=_indent(child.__repr__(), 2)))
    body = '\n'.join(child_lines)
    return '{name}(\n{modstr}\n)'.format(name=self.__class__.__name__, modstr=body)
def __setattr__(self, name, value):
    """Register Block and Parameter attributes as they are assigned.

    Assigning a :py:class:`Block` registers it as a child via
    :py:meth:`register_child`. Assigning a :py:class:`Parameter` records it in
    this block's parameter dictionary. The value is then stored as an ordinary
    attribute in every case.

    Raises
    ------
    TypeError
        If ``name`` currently holds a Parameter or Block and ``value`` is not an
        instance of that existing attribute's type.
    """
    if hasattr(self, name):
        existing = getattr(self, name)
        if isinstance(existing, (Parameter, Block)) and not isinstance(value, type(existing)):
            # Note the trailing space: the two literals are concatenated.
            raise TypeError('Changing attribute type for {name} from {type1} to {type2} '
                            'is not allowed.'.format(
                                name=name, type1=type(existing), type2=type(value)))
    if isinstance(value, Block):
        self.register_child(value, name)
    elif isinstance(value, Parameter):
        self._reg_params[name] = value
    super(Block, self).__setattr__(name, value)
def _check_container_with_block(self):
    """Warn about list/tuple/dict attributes that hold unregistered Blocks.

    Blocks kept inside plain containers are not registered as children.
    Attributes whose names start with ``'__'`` and the ``_children`` registry
    itself are not inspected.
    """
    registered = set(self._children.values())

    def _contains_unregistered(data):
        # Recursively search nested lists/tuples/dicts for a Block that is not a live registered child.
        if isinstance(data, (list, tuple)):
            return any(_contains_unregistered(item) for item in data)
        if isinstance(data, dict):
            return any(_contains_unregistered(v) for _, v in data.items())
        if isinstance(data, Block):
            return data not in (ref() for ref in registered)
        return False

    for attr, value in self.__dict__.items():
        if not isinstance(value, (list, tuple, dict)):
            continue
        if attr.startswith('__') or attr == '_children':
            continue
        if _contains_unregistered(value):
            warnings.warn('"{name}" is an unregistered container with Blocks. '
                          'Note that Blocks inside the list, tuple or dict will not be '
                          'registered automatically. Make sure to register them using '
                          'register_child() or switching to '
                          'nn.Sequential/nn.HybridSequential instead. '
                          .format(name=self.__class__.__name__ + "." + attr), stacklevel=3)
def _alias(self):
    """Return this block's class name in lowercase."""
    class_name = self.__class__.__name__
    return class_name.lower()
@property
def params(self):
    """Returns this :py:class:`Block`'s parameter dictionary (does not include its
    children's parameters).

    The returned object is the block's own registry, not a copy.
    """
    registry = self._reg_params
    return registry
def collect_params(self, select=None):
    """Returns a :py:class:`Dict` containing this :py:class:`Block` and all of its
    children's Parameters(default), also can returns the select :py:class:`Dict`
    which match some given regular expressions.

    For example, collect the specified parameters in ['conv1.weight', 'conv1.bias', 'fc.weight',
    'fc.bias']::

        model.collect_params('conv1.weight|conv1.bias|fc.weight|fc.bias')

    or collect all parameters whose names end with 'weight' or 'bias', this can be done
    using regular expressions::

        model.collect_params('.*weight|.*bias')

    Parameters
    ----------
    select : str
        regular expressions

    Returns
    -------
    The selected :py:class:`Dict`
    """
    # Blocks hidden inside plain containers are not registered; warn about them first.
    self._check_container_with_block()
    collected = self._collect_params_with_prefix(select=select)
    return collected
def _collect_params_with_prefix(self, prefix='', select=None):
    """Recursively gather parameters keyed by their dotted path.

    Parameters
    ----------
    prefix : str
        Dotted path of this block ('' for the root).
    select : str or None
        Optional regular expression; a key is kept when ``re.match`` succeeds
        on the full dotted key (i.e. the pattern is anchored at its start).

    Returns
    -------
    dict
        Mapping from dotted name (e.g. ``'fc.weight'``) to Parameter.
    """
    head = prefix + '.' if prefix else prefix
    if select is None:
        collected = {head + key: val for key, val in self._reg_params.items()}
    else:
        pattern = re.compile(select)
        collected = {}
        for key, val in self._reg_params.items():
            full_key = head + key
            if pattern.match(full_key):
                collected[full_key] = val
    for child_name, child_ref in self._children.items():
        collected.update(child_ref()._collect_params_with_prefix(head + child_name, select))
    return collected
def save_parameters(self, filename, deduplicate=False):
    """Save parameters to file.

    Saved parameters can only be loaded with `load_parameters`. Note that this
    method only saves parameters, not model structure. If you want to save
    model structures, please use :py:meth:`HybridBlock.export`.

    Parameters
    ----------
    filename : str
        Path to file.
    deduplicate : bool, default False
        If True, save shared parameters only once. Otherwise, if a Block
        contains multiple sub-blocks that share parameters, each of the
        shared parameters will be separately saved for every sub-block.

    References
    ----------
    `Saving and Loading Gluon Models \
    <https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/save_load_params.html>`_
    """
    params = self._collect_params_with_prefix()
    if deduplicate:
        # A shared Parameter appears under several dotted names. Keep a single
        # name per Parameter (the last one collected); load_parameters only
        # needs one of the names to be present in the file.
        name_of = {param: name for name, param in params.items()}
        params = {name: param for param, name in name_of.items()}
    arg_dict = {name: param._reduce() for name, param in params.items()}
    if not is_np_array():
        ndarray.save(filename, arg_dict)
    else:
        _mx_npx.savez(filename, **arg_dict)
@wrap_ctx_to_device_func
def load_parameters(self, filename, device=None, allow_missing=False,
                    ignore_extra=False, cast_dtype=False, dtype_source='current'):
    """Load parameters from file previously saved by `save_parameters`.

    Parameters
    ----------
    filename : str
        Path to parameter file.
    device : Device or list of Device, default None
        Device(s) to initialize loaded parameters on. ``None`` means the
        current device (resolved by :py:meth:`load_dict`).
    allow_missing : bool, default False
        Whether to silently skip loading parameters not represented in the file.
    ignore_extra : bool, default False
        Whether to silently ignore parameters from the file that are not
        present in this Block.
    cast_dtype : bool, default False
        Cast the data type of the NDArray loaded from the checkpoint to the dtype
        provided by the Parameter if any.
    dtype_source : str, default 'current'
        must be in {'current', 'saved'}
        Only valid if cast_dtype=True, specify the source of the dtype for casting
        the parameters

    Raises
    ------
    ValueError
        If loading under NumPy semantics fails for a reason other than the
        legacy-shape mismatch; the original MXNetError is chained as the cause.

    References
    ----------
    `Saving and Loading Gluon Models \
    <https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/save_load_params.html>`_
    """
    if is_np_array():
        # failure may happen when loading parameters saved as NDArrays within
        # NumPy semantics. Check the failure type and recover from it if it happens.
        try:
            loaded = _mx_npx.load(filename)
        except MXNetError as e:
            err_msg = str(e)
            if 'is_np_shape' not in err_msg:
                # Keep the original error as __cause__ for debugging.
                raise ValueError(err_msg) from e
            # Loading failure due to parameters saved without numpy semantics.
            # Temporarily disable numpy semantics and load parameters. After it's
            # done, resume the numpy semantics. This is fine because the cases
            # numpy ndarray covers is a superset of the legacy ndarray's.
            with np_array(False):
                with np_shape(False):
                    loaded_nds = ndarray.load(filename)
            assert isinstance(loaded_nds, dict),\
                'expecting a dict type, got {}'.format(str(type(loaded_nds)))
            loaded = {k: loaded_nds[k].as_np_ndarray() for k in loaded_nds}
    else:
        loaded = ndarray.load(filename)
    if not loaded:
        return
    # load_dict recognises this wrapper by its str 'filename' entry.
    full_dict = {'params': loaded, 'filename': filename}
    self.load_dict(full_dict, device, allow_missing, ignore_extra, cast_dtype, dtype_source)
def load_dict(self, param_dict, device=None, allow_missing=False,
              ignore_extra=False, cast_dtype=False, dtype_source="current"):
    """Load parameters from a dict of arrays.

    Parameters
    ----------
    param_dict : dict
        Mapping from parameter name to array. Names may carry an ``'arg:'`` or
        ``'aux:'`` prefix, which is stripped. A dict of the form
        ``{'params': ..., 'filename': str}`` (as built by `load_parameters`)
        is unwrapped and the filename is used in error messages.
    device : Device, optional
        Device context on which the memory is allocated. Default is
        `mxnet.device.current_device()`.
    allow_missing : bool, default False
        Whether to silently skip loading parameters not represented in the file.
    ignore_extra : bool, default False
        Whether to silently ignore parameters from the file that are not
        present in this dict.
    cast_dtype : bool, default False
        Cast the data type of the NDArray loaded from the checkpoint to the dtype
        provided by the Parameter if any
    dtype_source : str, default 'current'
        must be in {'current', 'saved'}
        Only valid if cast_dtype=True, specify the source of the dtype for casting
        the parameters

    Raises
    ------
    AssertionError
        If a parameter is missing and ``allow_missing`` is False.
    ValueError
        If the dict holds an unknown parameter and ``ignore_extra`` is False.
    """
    filename = None
    if isinstance(param_dict.get('filename'), str):
        # Wrapped form passed from load_parameters.
        filename = param_dict['filename']
        param_dict = param_dict['params']
    params = self.collect_params()
    error_str = f"file: {filename}" if filename else "param_dict"

    loaded = {}
    for key, value in param_dict.items():
        if key.startswith('arg:') or key.startswith('aux:'):
            key = key[4:]
        loaded[key] = value

    if not allow_missing:
        # A shared Parameter is registered under several names; it counts as
        # present when any of those names was loaded.
        names_of = defaultdict(list)
        for name, param in params.items():
            names_of[param].append(name)
        for name, param in params.items():
            assert any(alias in loaded for alias in names_of[param]), (
                f"Parameter '{name}' is missing in '{error_str}', which contains "
                f"parameters: {_brief_print_list(loaded.keys())}. "
                "Set allow_missing=True to ignore missing parameters.")

    if device is None:
        device = _device.current_device()
    for name, value in loaded.items():
        if not ignore_extra and name not in params:
            raise ValueError(
                f"Parameter '{name}' loaded from '{error_str}' is not present in Dict, "
                f"which contains parameters {_brief_print_list(params.keys())}. Set ignore_extra=True to ignore. ")
        if name not in params:
            continue
        if isinstance(value, np.ndarray):
            # Plain numpy arrays are converted to MXNet arrays first.
            value = _mx_np.array(value) if is_np_array() else nd.array(value)
        params[name]._load_init(value, device, cast_dtype=cast_dtype, dtype_source=dtype_source)
def register_child(self, block, name=None):
    """Registers block as a child of self. :py:class:`Block` s assigned to self as
    attributes will be registered automatically.

    Only a weak reference to ``block`` is stored. When ``name`` is None, the
    child is keyed by the number of children registered so far, as a string.
    """
    key = str(len(self._children)) if name is None else name
    self._children[key] = weakref.ref(block)
def register_forward_pre_hook(self, hook):
    r"""Registers a forward pre-hook on the block.

    The hook function is called immediately before :func:`forward`.
    It should not modify the input or output.

    Parameters
    ----------
    hook : callable
        The forward hook function of form `hook(block, input) -> None`.

    Returns
    -------
    :class:`mxnet.gluon.utils.HookHandle`
    """
    pre_hook_handle = HookHandle()
    pre_hook_handle.attach(self._forward_pre_hooks, hook)
    return pre_hook_handle
def register_forward_hook(self, hook):
    r"""Registers a forward hook on the block.

    The hook function is called immediately after :func:`forward`.
    It should not modify the input or output.

    Parameters
    ----------
    hook : callable
        The forward hook function of form `hook(block, input, output) -> None`.

    Returns
    -------
    :class:`mxnet.gluon.utils.HookHandle`
    """
    post_hook_handle = HookHandle()
    post_hook_handle.attach(self._forward_hooks, hook)
    return post_hook_handle
def apply(self, fn):
    r"""Applies ``fn`` recursively to every child block as well as self.

    Children are visited first (in registration order, depth-first), then
    ``fn`` is called on this block.

    Parameters
    ----------
    fn : callable
        Function to be applied to each submodule, of form `fn(block)`.

    Returns
    -------
    this block
    """
    for child_ref in self._children.values():
        child_ref().apply(fn)
    fn(self)
    return self
@wrap_ctx_to_device_func
def initialize(self, init=None, device=None, verbose=False,
               force_reinit=False):
    """Initializes :py:class:`Parameter` s of this :py:class:`Block` and its children.

    Parameters
    ----------
    init : Initializer, default None
        Global default Initializer to be used when :py:meth:`Parameter.init` is ``None``.
        Otherwise, :py:meth:`Parameter.init` takes precedence. ``None`` means a
        new :py:class:`initializer.Uniform` instance created for this call.
    device : Device or list of Device
        Keeps a copy of Parameters on one or many device(s).
    verbose : bool, default False
        Whether to verbosely print out details on initialization.
    force_reinit : bool, default False
        Whether to force re-initialization if parameter is already initialized.
    """
    if init is None:
        # Create the default per call. A default instance built once at
        # definition time would be shared across calls, and set_verbosity
        # below would permanently mutate it.
        init = initializer.Uniform()
    params = self.collect_params()
    if verbose:
        init.set_verbosity(verbose=verbose)
    for v in params.values():
        v.initialize(None, device, init, force_reinit=force_reinit)
def save(self, prefix):
    """Save the model architecture and parameters to load again later

    Saves the model architecture as a nested dictionary where each Block
    in the model is a dictionary and its children are sub-dictionaries.

    Each Block is uniquely identified by Block class name and a unique ID.
    We save each Block's parameter UUID to restore later in order to match
    the saved parameters.

    Recursively traverses a Block's children in order (since its an
    OrderedDict) and uses the unique ID to denote that specific Block.

    Assumes that the model is created in an identical order every time.
    If the model is not able to be recreated deterministically do not
    use this set of APIs to save/load your model.

    For HybridBlocks, the cached_graph is saved (Symbol & inputs) if
    it has already been hybridized.

    Parameters
    ----------
    prefix : str
        The prefix to use in filenames for saving this model:
        <prefix>-model.json and <prefix>-model.params
    """
    # create empty model structure
    model = {}

    def _save_cached_graphs(blk, structure, index=0):
        # create new entry for this block
        mdl = {}
        # encode unique name based on block type and ID
        name = type(blk).__name__.lower()
        structure[name + str(index)] = mdl
        index += 1
        if isinstance(blk, HybridBlock):
            if blk._cached_graph:
                # save in/out formats
                mdl['in_format'] = blk._in_format
                mdl['out_format'] = blk._out_format
                # save cached graph & input symbols
                syms, out = blk._cached_graph
                mdl['inputs'] = [sym.tojson() for sym in syms]
                mdl['symbol'] = out.tojson()
                mdl['hybridized'] = True
            else:
                mdl['hybridized'] = False
        # save param uuids
        pmap = {}
        mdl['params'] = pmap
        for p, param in blk.params.items():
            pmap[p] = param._uuid
        # recursively save children; the ID counter runs across the whole tree
        for child in blk._children.values():
            index = _save_cached_graphs(child(), mdl, index)
        # return latest index (ie. block count)
        return index

    # save top-level block
    _save_cached_graphs(self, model)
    # save model
    with open(prefix + '-model.json', 'w') as fp:
        json.dump(model, fp)
    # save params next to the JSON, using the caller's prefix (previously hard-coded)
    self.save_parameters(prefix + '-model.params')
def load(self, prefix):
    """Load a model saved using the `save` API

    Reconfigures a model using the saved configuration. This function
    does not regenerate the model architecture. It resets each Block's
    parameter UUIDs as they were when saved in order to match the names of the
    saved parameters.

    This function assumes the Blocks in the model were created in the same
    order they were when the model was saved. This is because each Block is
    uniquely identified by Block class name and a unique ID in order (since
    its an OrderedDict) and uses the unique ID to denote that specific Block.

    Assumes that the model is created in an identical order every time.
    If the model is not able to be recreated deterministically do not
    use this set of APIs to save/load your model.

    For HybridBlocks, the cached_graph (Symbol & inputs) and settings are
    restored if it had been hybridized before saving.

    Parameters
    ----------
    prefix : str
        The prefix to use in filenames for loading this model:
        <prefix>-model.json and <prefix>-model.params
    """
    # load model json from file
    with open(prefix + '-model.json') as fp:
        model = json.load(fp)

    def _load_cached_graphs(blk, structure, index=0):
        # get block name
        name = type(blk).__name__.lower()
        # lookup previous encoded name based on block type and ID
        mdl = structure[name + str(index)]
        index += 1
        if isinstance(blk, HybridBlock):
            if mdl['hybridized']:
                # restore in/out formats
                blk._in_format = mdl['in_format']
                blk._out_format = mdl['out_format']
                # get saved symbol
                out = fromjson(mdl['symbol'])
                # recreate inputs for this symbol
                syms = [fromjson(inp) for inp in mdl['inputs']]
                # reset cached_graph and active status
                blk._cached_graph = (syms, out)
                blk._active = True
        # reload param uuids
        pmap = mdl['params']
        for p, uuid in pmap.items():
            param = blk.params[p]
            param._uuid = uuid
        # recursively reload children, in the same order save() used
        for child in blk._children.values():
            index = _load_cached_graphs(child(), mdl, index)
        # return latest index (ie. block count)
        return index

    # load top-level block
    _load_cached_graphs(self, model)
    # load params from the file save() wrote for this prefix (previously hard-coded)
    self.load_parameters(prefix + '-model.params')
def hybridize(self, active=True, **kwargs):
    """ Please refer description of HybridBlock hybridize().

    A plain Block only forwards the call, with the same arguments, to each
    registered child.
    """
    for child in (ref() for ref in self._children.values()):
        child.hybridize(active, **kwargs)
def cast(self, dtype):
    """Cast this Block to use another data type.

    Children are cast first, then this block's own parameters.

    Parameters
    ----------
    dtype : str or numpy.dtype
        The new data type.
    """
    for child_ref in self._children.values():
        child_ref().cast(dtype)
    for param in self.params.values():
        param.cast(dtype)
def zero_grad(self):
    """Sets all Parameters' gradient buffer to 0.

    Parameters with ``grad_req == 'null'`` or without a gradient buffer are
    skipped. Row-sparse gradients are zeroed individually. All other gradients
    are grouped by device and reset with one ``ndarray.reset_arrays`` call per
    device.
    """
    dense_by_device = defaultdict(list)
    for param in self.collect_params().values():
        if param.grad_req == 'null' or param._grad is None:
            continue
        for grad in param.list_grad():
            if grad.stype == 'row_sparse':
                ndarray.zeros_like(grad, out=grad)
                continue
            # reset_arrays works on legacy NDArrays, so convert under numpy semantics.
            dense_by_device[grad.device].append(
                grad.as_nd_ndarray() if is_np_array() else grad)
    for grads in dense_by_device.values():
        ndarray.reset_arrays(*grads, num_arrays=len(grads))
def reset_device(self, device):
    """Re-assign all Parameters to other devices.

    Parameters
    ----------
    device : Device or list of Device, default :py:meth:`device.current_device()`.
        Assign Parameter to given device. If device is a list of Device, a
        copy will be made for each device.
    """
    for param in self.collect_params().values():
        param.reset_device(device)
def reset_ctx(self, ctx):
    """This function has been deprecated. Please refer to ``Block.reset_device``.

    Emits a DeprecationWarning attributed to the caller's line, then forwards
    ``ctx`` to :py:meth:`reset_device`.
    """
    # stacklevel=2 points the warning at the code that called reset_ctx.
    warnings.warn('Block.reset_ctx has been renamed to'
                  ' Block.reset_device', DeprecationWarning, stacklevel=2)
    self.reset_device(ctx)
def setattr(self, name, value):
    """Set an attribute to a new value for all Parameters.

    For example, set grad_req to null if you don't need gradient w.r.t a
    model's Parameters::

        model.setattr('grad_req', 'null')

    or change the learning rate multiplier::

        model.setattr('lr_mult', 0.5)

    Parameters
    ----------
    name : str
        Name of the attribute.
    value : valid type for attribute name
        The new value for the attribute.
    """
    # Inside the method body, ``setattr`` resolves to the builtin, not this method.
    for param in self.collect_params().values():
        setattr(param, name, value)
def share_parameters(self, shared):
    """Share parameters recursively inside the model.

    For example, if you want ``dense1`` to share ``dense0``'s weights, you can do::

        dense0 = nn.Dense(20)
        dense1 = nn.Dense(20)
        dense1.share_parameters(dense0.collect_params())

    which equals to
        dense1.weight = dense0.weight
        dense1.bias = dense0.bias

    Note that unlike the `load_parameters` or `load_dict` functions,
    `share_parameters` results in the `Parameter` object being shared (or
    tied) between the models, whereas `load_parameters` or `load_dict` only
    set the value of the data dictionary of a model. If you call
    `load_parameters` or `load_dict` after `share_parameters`, the loaded
    value will be reflected in all networks that use the shared (or tied)
    `Parameter` object.

    A warning is emitted for every key of ``shared`` that matches no
    parameter of this model.

    Parameters
    ----------
    shared : Dict
        Dict of the shared parameters.

    Returns
    -------
    this block

    Raises
    ------
    ValueError
        If ``shared`` is neither None nor a dict.
    """
    if shared is None:
        return self
    if not isinstance(shared, (dict, OrderedDict)):
        raise ValueError("'shared' should be in type of Dict. Get type {}!".format(type(shared)))
    unmatched = set(shared.keys())
    # _shared_parameters removes each key it binds; whatever is left matched nothing.
    self._shared_parameters(shared, unmatched)
    for name in unmatched:
        warnings.warn("Parameter name {} is not in the current model!".format(name))
    return self
def _shared_parameters(self, shared, shared_set, prefix=""):
    """Bind entries of ``shared`` onto matching parameter attributes, recursively.

    Keys of ``shared`` are dotted paths relative to the block on which
    ``share_parameters`` was called. Entries whose value is None are skipped.
    Every key that gets bound is removed from ``shared_set``.
    """
    head = prefix + '.' if prefix else prefix
    for name in self._reg_params:
        key = head + name
        if shared.get(key) is None:
            continue
        setattr(self, name, shared[key])
        shared_set.remove(key)
    for child_name, child_ref in self._children.items():
        child_ref()._shared_parameters(shared, shared_set, head + child_name)
def __call__(self, *args):
    """Calls forward. Only accepts positional arguments.

    Forward pre-hooks are called with ``(block, args)`` before
    :py:meth:`forward`, and forward hooks with ``(block, args, output)`` after
    it. Under NumPy semantics the output is passed to
    ``_check_all_np_ndarrays`` before being returned.
    """
    for pre_hook in self._forward_pre_hooks.values():
        pre_hook(self, args)
    output = self.forward(*args)
    for post_hook in self._forward_hooks.values():
        post_hook(self, args, output)
    if _mx_npx.is_np_array():
        _check_all_np_ndarrays(output)
    return output
def forward(self, *args):
    """Overrides to implement forward computation using :py:class:`NDArray`. Only
    accepts positional arguments.

    Parameters
    ----------
    *args : list of NDArray
        Input tensors.

    Raises
    ------
    NotImplementedError
        Always; subclasses must provide their own implementation.
    """
    # pylint: disable= invalid-name
    raise NotImplementedError
def register_op_hook(self, callback, monitor_all=False):
    """Install callback monitor.

    A plain Block only forwards the registration to each registered child.

    Parameters
    ----------
    callback : function
        Function called to inspect the values of the intermediate outputs
        of blocks after hybridization. It takes 3 parameters:
        name of the tensor being inspected (str)
        name of the operator producing or consuming that tensor (str)
        tensor being inspected (NDArray).
    monitor_all : bool, default False
        If True, monitor both input and output, otherwise monitor output only.
    """
    for child in (ref() for ref in self._children.values()):
        child.register_op_hook(callback, monitor_all)
def summary(self, *inputs):
"""Print the summary of the model's output and parameters.
The network must have been initialized, and must not have been hybridized.
Parameters
----------
inputs : object
Any input that the model supports. For any tensor in the input, only
:class:`mxnet.ndarray.NDArray` is supported.
"""
summary = OrderedDict()
seen = set()
hooks = []
def _get_shape_str(args):
def flatten(args):
if not isinstance(args, (list, tuple)):
return [args], int(0)
flat = []
fmts = []
for i in args:
arg, fmt = flatten(i)
flat.extend(arg)
fmts.append(fmt)
return flat, fmts
def regroup(args, fmt):
if isinstance(fmt, int):
if fmt == 0:
return args[0], args[1:]
return args[:fmt], args[fmt:]
ret = []
for i in fmt:
res, args = regroup(args, i)
ret.append(res)
return ret, args
flat_args, fmts = flatten(args)
flat_arg_shapes = [x.shape if isinstance(x, ndarray.NDArray) else x
for x in flat_args]
shapes = regroup(flat_arg_shapes, fmts)[0]
if isinstance(shapes, list):
shape_str = str(shapes)[1:-1]
else:
shape_str = str(shapes)
return shape_str.replace('L', '')
def _register_summary_hook(block):
assert not isinstance(block, HybridBlock) or not block._active, \
'"{}" must not be hybridized to print summary.'.format(type(block).__name__)
def _summary_hook(block, _, outputs):
class_name = block.__class__.__name__
block_idx = len(summary) - 1
m_key = f'{class_name}-{block_idx+1}'
summary[m_key] = OrderedDict()
summary[m_key]['output_shape'] = _get_shape_str(outputs)
params = 0
summary[m_key]['trainable'] = 0
summary[m_key]['shared'] = 0
for p in block.params.values():
params += p.data().size
summary[m_key]['trainable'] += 0 if p.grad_req == 'null' else p.data().size
if p in seen:
summary[m_key]['shared'] += p.data().size
else:
seen.add(p)
summary[m_key]['n_params'] = params
from .nn.basic_layers import Sequential, HybridSequential
if not isinstance(block, (Sequential, HybridSequential)):
hooks.append(block.register_forward_hook(_summary_hook))
summary['Input'] = OrderedDict()
summary['Input']['output_shape'] = _get_shape_str(inputs)
summary['Input']['n_params'] = 0
summary['Input']['trainable'] = 0
summary['Input']['shared'] = 0
try:
self.apply(_register_summary_hook)
self(*inputs)
line_format = '{:>20} {:>42} {:>15}'
print('-'*80)
print(line_format.format('Layer (type)', 'Output Shape', 'Param #'))
print('='*80)
total_params = 0
trainable_params = 0
shared_params = 0
for layer in summary:
print(line_format.format(layer,
str(summary[layer]['output_shape']),
summary[layer]['n_params']))
total_params += summary[layer]['n_params']
trainable_params += summary[layer]['trainable']
shared_params += summary[layer]['shared']
print('='*80)
print('Parameters in forward computation graph, duplicate included')
print(' Total params: ' + str(total_params))
print(' Trainable params: ' + str(trainable_params))
print(' Non-trainable params: ' + str(total_params - trainable_params))
print('Shared params in forward computation graph: ' + str(shared_params))
print('Unique parameters in model: ' + str(total_params - shared_params))
print('-'*80)