This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
initializer.py
831 lines (729 loc) · 29.8 KB
/
initializer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Weight initializer."""
import re
import logging
import warnings
import json
from math import sqrt
import numpy as np
from .base import string_types
from .ndarray import NDArray, load
from . import random
from . import registry
from . import ndarray
from . util import is_np_array
from . import numpy as _mx_np # pylint: disable=reimported
# inherit str for backward compatibility
class InitDesc(str):
    """Descriptor for the initialization pattern.

    An ``InitDesc`` *is* the variable name (it subclasses ``str``), with two
    extra attributes attached to the string instance.

    Parameters
    ----------
    name : str
        Name of variable.
    attrs : dict of str to str
        Attributes of this variable taken from ``Symbol.attr_dict``.
        A falsy value (``None`` or an empty dict) becomes a fresh ``{}``.
    global_init : Initializer
        Global initializer to fallback to.
    """
    def __new__(cls, name, attrs=None, global_init=None):
        desc = super().__new__(cls, name)
        desc.attrs = attrs if attrs else {}
        desc.global_init = global_init
        return desc
class Initializer(object):
    """The base class of an initializer.

    Subclasses override ``_init_weight`` (and optionally the other ``_init_*``
    hooks). Keyword arguments given to ``__init__`` are stored in
    ``self._kwargs``; they are what :meth:`dumps` serializes and what
    :meth:`__eq__` compares.
    """
    def __init__(self, **kwargs):
        self._kwargs = kwargs
        self._verbose = False
        self._print_func = None

    def set_verbosity(self, verbose=False, print_func=None):
        """Switch on/off verbose mode.

        Parameters
        ----------
        verbose : bool
            Switch on/off verbose mode.
        print_func : function
            A function that computes statistics of initialized arrays.
            Takes an `NDArray` and returns an `str`. Defaults to the
            root-mean-square value ``str((norm(x) / sqrt(x.size)).asscalar())``.

        Returns
        -------
        Initializer
            ``self``, so the call can be chained.
        """
        self._verbose = verbose
        if print_func is None:
            def asum_stat(x):
                """Return ||x||_2 / sqrt(size(x)) (the RMS of x) as a string."""
                return str((ndarray.norm(x)/sqrt(x.size)).asscalar())
            print_func = asum_stat
        self._print_func = print_func
        return self

    def _verbose_print(self, desc, init, arr):
        """Internal verbose print function.

        Parameters
        ----------
        desc : InitDesc or str
            Name of the array.
        init : str
            Initializer pattern.
        arr : NDArray
            Initialized array.
        """
        if self._verbose and self._print_func:
            logging.info('Initialized %s as %s: %s', desc, init, self._print_func(arr))

    def dumps(self):
        """Saves the initializer to string.

        Returns
        -------
        str
            JSON formatted string that describes the initializer: a two-element
            list of the lower-cased class name and the constructor kwargs.

        Examples
        --------
        >>> # Create initializer and retrieve its parameters
        ...
        >>> init = mx.init.Normal(0.5)
        >>> init.dumps()
        '["normal", {"sigma": 0.5}]'
        >>> init = mx.init.Xavier(factor_type="in", magnitude=2.34)
        >>> init.dumps()
        '["xavier", {"rnd_type": "uniform", "factor_type": "in", "magnitude": 2.34}]'
        """
        return json.dumps([self.__class__.__name__.lower(), self._kwargs])

    def __call__(self, desc, arr):
        """Initialize an array.

        Parameters
        ----------
        desc : InitDesc
            Initialization pattern descriptor. A plain ``str`` falls back to
            the deprecated name-based :meth:`_legacy_init`.
        arr : NDArray
            The array to be initialized.
        """
        if not isinstance(desc, InitDesc):
            self._legacy_init(desc, arr)
            return
        if desc.global_init is None:
            desc.global_init = self
        init = desc.attrs.get('__init__', "")
        if init:
            # when calling Variable initializer
            create(init)._init_weight(desc, arr)
            self._verbose_print(desc, init, arr)
        else:
            # register nnvm::FSetInputVariableAttrs in the backend for new patterns
            # don't add new cases here.
            if desc.endswith('weight'):
                self._init_weight(desc, arr)
                self._verbose_print(desc, 'weight', arr)
            elif desc.endswith('bias'):
                self._init_bias(desc, arr)
                self._verbose_print(desc, 'bias', arr)
            elif desc.endswith('gamma'):
                self._init_gamma(desc, arr)
                self._verbose_print(desc, 'gamma', arr)
            elif desc.endswith('beta'):
                self._init_beta(desc, arr)
                self._verbose_print(desc, 'beta', arr)
            elif desc.endswith('min'):
                self._init_zero(desc, arr)
                self._verbose_print(desc, 'min', arr)
            elif desc.endswith('max'):
                self._init_one(desc, arr)
                self._verbose_print(desc, 'max', arr)
            elif desc.endswith('weight_quantize'):
                self._init_quantized_weight(desc, arr)
                self._verbose_print(desc, 'weight_quantize', arr)
            elif desc.endswith('bias_quantize'):
                self._init_quantized_bias(desc, arr)
                self._verbose_print(desc, 'bias_quantize', arr)
            else:
                self._init_default(desc, arr)

    def _legacy_init(self, name, arr):
        """Legacy initialization method, dispatching on the name suffix/prefix.

        Parameters
        ----------
        name : str
            Name of corresponding NDArray.
        arr : NDArray
            NDArray to be initialized.

        Raises
        ------
        TypeError
            If ``name`` is not a string or ``arr`` is not an ``NDArray``.
        """
        warnings.warn(
            "\033[91mCalling initializer with init(str, NDArray) has been deprecated. "
            "Please use init(mx.init.InitDesc(...), NDArray) instead.\033[0m",
            DeprecationWarning, stacklevel=3)
        if not isinstance(name, string_types):
            raise TypeError('name must be string')
        if not isinstance(arr, NDArray):
            raise TypeError('arr must be NDArray')
        if name.startswith('upsampling'):
            self._init_bilinear(name, arr)
        elif name.startswith('stn_loc') and name.endswith('weight'):
            self._init_zero(name, arr)
        elif name.startswith('stn_loc') and name.endswith('bias'):
            self._init_loc_bias(name, arr)
        elif name.endswith('bias'):
            self._init_bias(name, arr)
        elif name.endswith('gamma'):
            self._init_gamma(name, arr)
        elif name.endswith('beta'):
            self._init_beta(name, arr)
        elif name.endswith('weight'):
            self._init_weight(name, arr)
        elif name.endswith("moving_mean"):
            self._init_zero(name, arr)
        elif name.endswith("moving_var"):
            self._init_one(name, arr)
        elif name.endswith("moving_inv_var"):
            self._init_zero(name, arr)
        elif name.endswith("moving_avg"):
            self._init_zero(name, arr)
        elif name.endswith('min'):
            self._init_zero(name, arr)
        elif name.endswith('max'):
            self._init_one(name, arr)
        else:
            self._init_default(name, arr)

    def _init_bilinear(self, _, arr):
        """Fill ``arr`` with a bilinear-upsampling kernel.

        Vectorized form of the per-element formula: for flat index ``i``,
        ``x = i % shape[3]``, ``y = (i // shape[3]) % shape[2]`` and the value
        is ``(1 - |x/f - c|) * (1 - |y/f - c|)``. ``arr`` must have at least 4
        dimensions (axes 2 and 3 are read). Values are computed in float64 and
        stored as float32.
        """
        shape = arr.shape
        f = np.ceil(shape[3] / 2.)
        c = (2 * f - 1 - f % 2) / (2. * f)
        flat_idx = np.arange(int(np.prod(shape)))
        x = flat_idx % shape[3]
        y = (flat_idx // shape[3]) % shape[2]
        weight = ((1 - np.abs(x / f - c)) * (1 - np.abs(y / f - c))).astype('float32')
        arr[:] = weight.reshape(shape)

    def _init_loc_bias(self, _, arr):
        """Set a 6-element localisation bias to ``[1, 0, 0, 0, 1, 0]``."""
        shape = arr.shape
        assert shape[0] == 6
        arr[:] = np.array([1.0, 0, 0, 0, 1.0, 0])

    def _init_zero(self, _, arr):
        arr[:] = 0.0

    def _init_one(self, _, arr):
        arr[:] = 1.0

    def _init_bias(self, _, arr):
        arr[:] = 0.0

    def _init_quantized_bias(self, _, arr):
        arr[:] = 0

    def _init_gamma(self, _, arr):
        arr[:] = 1.0

    def _init_beta(self, _, arr):
        arr[:] = 0.0

    def _init_weight(self, name, arr):
        """Abstract method to Initialize weight."""
        raise NotImplementedError("Must override it")

    def _init_quantized_weight(self, _, arr):
        """Fill ``arr`` with random int8 values drawn by ``random.randint(-127, 127)``."""
        # Draw one value per element; without ``shape`` a single sample would
        # be broadcast over the whole array.
        _arr = random.randint(-127, 127, shape=arr.shape, dtype='int32').asnumpy()
        arr[:] = np.int8(_arr)

    def _init_default(self, name, _):
        """Raise for names that match no known initialization pattern."""
        raise ValueError(
            f'Unknown initialization pattern for {name}. '
            'Default initialization is now limited to '
            '"weight", "bias", "gamma" (1.0), and "beta" (0.0). '
            'Please use mx.sym.Variable(init=mx.init.*) to set initialization pattern')

    def __eq__(self, other):
        if not isinstance(other, Initializer):
            return NotImplemented
        # pylint: disable=unidiomatic-typecheck
        return type(self) is type(other) and self._kwargs == other._kwargs
# pylint: disable=invalid-name
# Registry hooks for the Initializer family, all keyed under 'initializer':
# _register adds a class, alias adds extra names, create builds an instance.
_register, alias, create = (
    registry.get_register_func(Initializer, 'initializer'),
    registry.get_alias_func(Initializer, 'initializer'),
    registry.get_create_func(Initializer, 'initializer'),
)
# pylint: enable=invalid-name
def register(klass):
    """Registers a custom initializer.

    Custom initializers are written by subclassing `mx.init.Initializer` and
    overriding hooks such as `_init_weight` and `_init_bias`. A class must be
    registered with `mx.init.register` before it can be looked up by name.

    Parameters
    ----------
    klass : class
        A subclass of `mx.init.Initializer` to register as a custom initializer.

    Returns
    -------
    class
        Whatever the underlying registry hook returns for ``klass``.

    Example
    -------
    >>> # Create and register a custom initializer that
    ... # initializes weights to 0.1 and biases to 1.
    ...
    >>> @mx.init.register
    ... @mx.init.alias('myinit')
    ... class CustomInit(mx.init.Initializer):
    ...     def __init__(self):
    ...         super(CustomInit, self).__init__()
    ...     def _init_weight(self, _, arr):
    ...         arr[:] = 0.1
    ...     def _init_bias(self, _, arr):
    ...         arr[:] = 1
    ...
    >>> # block is an instance of 'mxnet.gluon.Block'
    ...
    >>> block.initialize(CustomInit())
    """
    registered = _register(klass)
    return registered
class Load(object):
    """Initializes variables by loading data from a file or dict.

    **Note** Keys prefixed with ``arg:`` or ``aux:`` have that 4-character
    prefix stripped, so variables are matched by the bare name.

    Parameters
    ----------
    param: str or dict of str->`NDArray`
        Parameter file (loaded with ``load``) or dict mapping name to NDArray.
    default_init: Initializer
        Fallback initializer used when a name is not found in `param`.
    verbose: bool
        If true, log whether each variable was loaded or default-initialized.
    """
    def __init__(self, param, default_init=None, verbose=False):
        if isinstance(param, str):
            param = load(param)
        assert isinstance(param, dict)
        self.param = {
            (key[4:] if key.startswith(('arg:', 'aux:')) else key): value
            for key, value in param.items()
        }
        self.default_init = default_init
        self.verbose = verbose

    def __call__(self, name, arr):
        if name not in self.param:
            assert self.default_init is not None, \
                f"Cannot Initialize {name}. Not found in loaded param " + \
                "and no default Initializer is provided."
            self.default_init(name, arr)
            if self.verbose:
                logging.info('Initialized %s by default', name)
            return
        loaded = self.param[name]
        assert arr.shape == loaded.shape, \
            f'Parameter {name} cannot be initialized from loading. ' + \
            f'Shape mismatch, target {str(arr.shape)} vs loaded {loaded.shape}'
        arr[:] = loaded
        if self.verbose:
            logging.info('Initialized %s by loading', name)
class Mixed(object):
    """Initialize parameters using multiple initializers.

    Each parameter name is tested against ``patterns`` in order using
    ``re.match`` (anchored at the start of the name); the initializer paired
    with the first matching pattern is applied.

    Parameters
    ----------
    patterns: list of str
        List of regular expressions matching parameter names.
    initializers: list of Initializer
        List of initializers corresponding to `patterns`.

    Raises
    ------
    ValueError
        When called with a name that matches none of the patterns.

    Example
    -------
    >>> # Given 'block', an instance of 'mxnet.gluon.Block', initialize biases to zero
    ... # and every other parameter to random values with uniform distribution.
    ... # '.*bias' is used (not 'bias') because matching starts at the beginning
    ... # of the name, e.g. 'fullyconnected1_bias'.
    ...
    >>> init = mx.initializer.Mixed(['.*bias', '.*'], [mx.init.Zero(), mx.init.Uniform(0.1)])
    >>> block.initialize(init)
    """
    def __init__(self, patterns, initializers):
        assert len(patterns) == len(initializers)
        self.map = list(zip([re.compile(p) for p in patterns], initializers))

    def __call__(self, name, arr):
        """Initialize ``arr`` with the first initializer whose pattern matches ``name``."""
        for prog, init in self.map:
            if prog.match(name):
                init(name, arr)
                return
        raise ValueError(
            f'Parameter name {name} did not match any pattern. Consider '
            'adding a ".*" pattern at the end with a default Initializer.')
@register
@alias("zeros")
class Zero(Initializer):
    """Initializes weights to zero.

    Also available under the alias ``"zeros"``.

    Example
    -------
    >>> # Given 'block', an instance of 'mxnet.gluon.Block', set every weight to 0.
    >>> block.initialize(mx.init.Zero())
    """
    def __init__(self):
        super().__init__()

    def _init_weight(self, _, arr):
        arr[:] = 0
@register
@alias("ones")
class One(Initializer):
    """Initializes weights to one.

    Also available under the alias ``"ones"``.

    Example
    -------
    >>> # Given 'block', an instance of 'mxnet.gluon.Block', set every weight to 1.
    >>> block.initialize(mx.init.One())
    """
    def __init__(self):
        super().__init__()

    def _init_weight(self, _, arr):
        arr[:] = 1
@register
class Constant(Initializer):
    """Initializes the weights to a given value.

    The value passed in can be a scalar or a NDarray that matches the shape
    of the parameter to be set.

    Parameters
    ----------
    value : float, NDArray
        Value to set.
    """
    def __init__(self, value):
        super(Constant, self).__init__(value=value)
        self.value = value

    def _init_weight(self, _, arr):
        arr[:] = self.value

    def dumps(self):
        """Serialize to JSON, converting a non-scalar ``value`` to a nested list.

        Works on a copy of the stored kwargs so that repeated calls (and
        ``__eq__``, which compares ``_kwargs``) are unaffected by serialization.
        """
        kwargs = dict(self._kwargs)
        val = kwargs['value']
        if not np.isscalar(val):
            if isinstance(val, np.ndarray):
                kwargs['value'] = val.tolist()
            elif hasattr(val, 'asnumpy'):
                # NDArray-like: copy to host first.
                kwargs['value'] = val.asnumpy().tolist()
            # Anything else (e.g. a plain list) is passed to json as-is.
        return json.dumps([self.__class__.__name__.lower(), kwargs])
@register
class Uniform(Initializer):
    """Initializes weights with random values drawn uniformly from [-scale, scale].

    Parameters
    ----------
    scale : float, optional
        Half-width of the sampling interval; samples fall in
        [-`scale`, `scale`]. Defaults to 0.07.

    Example
    -------
    >>> # Given 'block', an instance of 'mxnet.gluon.Block', initialize weights
    >>> # to random values uniformly sampled between -0.1 and 0.1.
    >>> block.initialize(mx.init.Uniform(0.1))
    """
    def __init__(self, scale=0.07):
        super().__init__(scale=scale)
        self.scale = scale

    def _init_weight(self, _, arr):
        # Use the numpy-compatible sampler when numpy semantics are active.
        if is_np_array():
            sampler = _mx_np.random.uniform
        else:
            sampler = random.uniform
        low, high = -self.scale, self.scale
        sampler(low, high, arr.shape, dtype=arr.dtype, out=arr)
@register
class Normal(Initializer):
    """Initializes weights with samples from a zero-mean normal distribution.

    Parameters
    ----------
    sigma : float, optional
        Standard deviation of the normal distribution. Defaults to 0.01.

    Example
    -------
    >>> # Given 'block', an instance of 'mxnet.gluon.Block', initialize weights
    >>> # to random values sampled from N(0, 0.5**2).
    >>> block.initialize(mx.init.Normal(0.5))
    """
    def __init__(self, sigma=0.01):
        super().__init__(sigma=sigma)
        self.sigma = sigma

    def _init_weight(self, _, arr):
        # Use the numpy-compatible sampler when numpy semantics are active.
        if is_np_array():
            sampler = _mx_np.random.normal
        else:
            sampler = random.normal
        sampler(0, self.sigma, arr.shape, dtype=arr.dtype, out=arr)
@register
class Orthogonal(Initializer):
    """Initialize weight as orthogonal matrix.

    This initializer implements *Exact solutions to the nonlinear dynamics of
    learning in deep linear neural networks*, available at
    https://arxiv.org/abs/1312.6120.

    The array is viewed as a ``(shape[0], prod(shape[1:]))`` matrix, a random
    matrix of that shape is drawn, and one of its SVD factors (whichever has the
    matrix's shape) is scaled and reshaped back into ``arr``.

    Parameters
    ----------
    scale : float, optional
        Scaling factor of weight.
    rand_type: string, optional
        Use "uniform" or "normal" random number to initialize weight.

    Raises
    ------
    ValueError
        At initialization time, if `rand_type` is neither "uniform" nor "normal".
    """
    def __init__(self, scale=1.414, rand_type="uniform"):
        super(Orthogonal, self).__init__(scale=scale, rand_type=rand_type)
        self.scale = scale
        self.rand_type = rand_type

    def _init_weight(self, _, arr):
        nout = arr.shape[0]
        # int(): np.prod(()) is the float 1.0 for 1-D arrays, which is not a valid dimension.
        nin = int(np.prod(arr.shape[1:]))
        if self.rand_type == "uniform":
            tmp = random.uniform(-1.0, 1.0, shape=(nout, nin)).asnumpy()
        elif self.rand_type == "normal":
            tmp = random.normal(0.0, 1.0, shape=(nout, nin)).asnumpy()
        else:
            raise ValueError(
                f'Unknown rand_type "{self.rand_type}" for Orthogonal initializer; '
                'expected "uniform" or "normal".')
        u, _singular, v = np.linalg.svd(tmp, full_matrices=False)  # pylint: disable=invalid-name
        # With full_matrices=False exactly one of u / v has tmp's shape (or both when square).
        res = u if u.shape == tmp.shape else v
        arr[:] = self.scale * res.reshape(arr.shape)
@register
class Xavier(Initializer):
    """Returns an initializer performing "Xavier" initialization for weights.

    The goal is to keep the scale of gradients roughly equal across layers.

    With the defaults (`rnd_type` ``'uniform'``, `factor_type` ``'avg'``) the
    weights are drawn from :math:`[-c, c]` with
    :math:`c = \\sqrt{\\frac{3.}{0.5 * (n_{in} + n_{out})}}`, where
    :math:`n_{in}` is the number of neurons feeding into the weights and
    :math:`n_{out}` the number of neurons the result is fed to.

    With `factor_type` ``'in'`` the bound becomes :math:`c = \\sqrt{\\frac{3.}{n_{in}}}`;
    with ``'out'`` it is :math:`c = \\sqrt{\\frac{3.}{n_{out}}}`.

    With `rnd_type` ``'gaussian'`` the same quantity is used as the standard
    deviation of a zero-mean normal distribution instead of a uniform bound.

    For arrays with more than 2 dimensions both fans are multiplied by the
    product of the trailing (spatial) dimensions.

    Parameters
    ----------
    rnd_type: str, optional
        Random generator type, can be ``'gaussian'`` or ``'uniform'``.
    factor_type: str, optional
        Can be ``'avg'``, ``'in'``, or ``'out'``.
    magnitude: float, optional
        Scale of random number (the "3." in the formulas above).
    """
    def __init__(self, rnd_type="uniform", factor_type="avg", magnitude=3):
        super().__init__(rnd_type=rnd_type, factor_type=factor_type,
                         magnitude=magnitude)
        self.rnd_type = rnd_type
        self.factor_type = factor_type
        self.magnitude = float(magnitude)

    def _init_weight(self, name, arr):
        shape = arr.shape
        if len(shape) < 2:
            raise ValueError('Xavier initializer cannot be applied to vector {0}. It requires at'
                             ' least 2D.'.format(name))
        hw_scale = np.prod(shape[2:]) if len(shape) > 2 else 1.
        fan_in = shape[1] * hw_scale
        fan_out = shape[0] * hw_scale

        kind = self.factor_type
        if kind == "avg":
            factor = (fan_in + fan_out) / 2.0
        elif kind == "in":
            factor = fan_in
        elif kind == "out":
            factor = fan_out
        else:
            raise ValueError("Incorrect factor type")
        scale = np.sqrt(self.magnitude / factor)

        use_np = is_np_array()
        if self.rnd_type == "uniform":
            sampler = _mx_np.random.uniform if use_np else random.uniform
            sampler(-scale, scale, arr.shape, dtype=arr.dtype, out=arr)
        elif self.rnd_type == "gaussian":
            sampler = _mx_np.random.normal if use_np else random.normal
            sampler(0, scale, arr.shape, dtype=arr.dtype, out=arr)
        else:
            raise ValueError("Unknown random type")
@register
class MSRAPrelu(Xavier):
    """Initialize the weight according to a MSRA paper.

    This initializer implements *Delving Deep into Rectifiers: Surpassing
    Human-Level Performance on ImageNet Classification*, available at
    https://arxiv.org/abs/1502.01852.

    It is a Gaussian :class:`Xavier` initializer whose magnitude is
    ``2 / (1 + slope**2)``, intended for ReLU/PReLU-style activations.

    Parameters
    ----------
    factor_type: str, optional
        Can be ``'avg'``, ``'in'``, or ``'out'``.
    slope: float, optional
        Initial slope of any PReLU (or similar) nonlinearities.
    """
    def __init__(self, factor_type="avg", slope=0.25):
        super().__init__(rnd_type="gaussian", factor_type=factor_type,
                         magnitude=2. / (1 + slope ** 2))
        # Serialize the user-facing arguments, not the derived Xavier ones.
        self._kwargs = dict(factor_type=factor_type, slope=slope)
@register
class Bilinear(Initializer):
    """Initialize weight for upsampling layers with a bilinear kernel.

    Uses the same kernel as the base class's ``_init_bilinear`` hook (which the
    legacy name-based path applies to ``upsampling*`` names); the array is
    expected to have at least 4 dimensions.
    """
    def __init__(self):
        super(Bilinear, self).__init__()

    def _init_weight(self, _, arr):
        # Delegate instead of keeping a second copy of the kernel formula.
        self._init_bilinear(_, arr)
@register
class LSTMBias(Initializer):
    """Zero all biases of an LSTMCell except the forget gate's, which gets `forget_bias`.

    Parameters
    ----------
    forget_bias: float, default 1.0
        Bias for the forget gate. Jozefowicz et al. 2015 recommends
        setting this to 1.0.
    """
    def __init__(self, forget_bias=1.0):
        super().__init__(forget_bias=forget_bias)
        self.forget_bias = forget_bias

    def _init_weight(self, name, arr):
        arr[:] = 0.0
        # The bias vector holds 4 equal-sized gate blocks; in LSTMCell the
        # forget gate is the second one.
        gate_width = arr.shape[0] // 4
        start, stop = gate_width, 2 * gate_width
        arr[start:stop] = self.forget_bias
@register
class RNNFused(Initializer):
    """Initializer for the single flat parameter array of a fused RNN layer.

    The flat array is walked as consecutive segments: all weights first, then
    all biases; within each, layer by layer, direction by direction, in the
    order ``i2h``, ``h2h`` (and ``h2r`` for weights only, when a projection is
    used). Each segment is initialized with the matching per-part initializer.

    Parameters
    ----------
    mode : {'gru', 'lstm', 'rnn_relu', 'rnn_tanh'}, required
        The type of RNN to compute (determines the number of gates).
    num_layers : int (non-negative), required
        Number of stacked layers.
    state_size : int (non-negative), required
        Size of the state for each layer.
    bidirectional : boolean, optional, default=False
        Whether to use bidirectional recurrent layers.
    projection_size : int or None, optional, default=None
        Size of the projection.
    i2h_weight_initializer, h2h_weight_initializer, h2r_weight_initializer,
    i2h_bias_initializer, h2h_bias_initializer : optional
        Per-part initializers, passed to ``create`` (e.g. a registered name
        such as ``'uniform'`` or ``'zero'``). Unset parts are filled in by
        :meth:`set_initializer`.
    """
    def __init__(self, mode, num_layers, state_size, bidirectional=False,
                 projection_size=None, i2h_weight_initializer=None,
                 h2h_weight_initializer=None, i2h_bias_initializer=None,
                 h2h_bias_initializer=None, h2r_weight_initializer=None):
        super(RNNFused, self).__init__(mode=mode, num_layers=num_layers,
                                       state_size=state_size,
                                       bidirectional=bidirectional,
                                       projection_size=projection_size,
                                       i2h_weight_initializer=i2h_weight_initializer,
                                       h2h_weight_initializer=h2h_weight_initializer,
                                       i2h_bias_initializer=i2h_bias_initializer,
                                       h2h_bias_initializer=h2h_bias_initializer,
                                       h2r_weight_initializer=h2r_weight_initializer)
        self.gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
        self.num_layers = num_layers
        self.num_hidden = state_size
        self.dir = 2 if bidirectional else 1
        self.projection_size = projection_size
        self._i2h_weight_initializer = i2h_weight_initializer
        self._h2h_weight_initializer = h2h_weight_initializer
        self._i2h_bias_initializer = i2h_bias_initializer
        self._h2h_bias_initializer = h2h_bias_initializer
        self._h2r_weight_initializer = h2r_weight_initializer

    # pylint: disable=too-many-nested-blocks
    def _init_weight(self, name, arr):
        arr_len = arr.shape[0]
        size = self.num_hidden * self.dir * self.gates
        # The first layer's input size is not stored, so recover it from the
        # total length: every layer after the first has a fixed length size2.
        if not self.projection_size:
            # second layer size
            size2 = (self.num_hidden * self.dir + self.num_hidden + 2) * size
            input_size = (arr_len - (self.num_layers - 1) * size2) // \
                size - 2 - self.num_hidden
        else:
            # second layer size
            size2 = (self.projection_size * self.dir + self.projection_size + 2) * size
            size_projection = self.projection_size * self.num_hidden * self.num_layers * self.dir
            input_size = (arr_len - size_projection - (self.num_layers - 1) * size2) // \
                size - 2 - self.projection_size
        begin = 0
        if not self.projection_size:
            for param in ['weight', 'bias']:
                for layer_num in range(self.num_layers):
                    for _ in range(self.dir):
                        for connect in ['i2h', 'h2h']:
                            num_inputs = input_size
                            if layer_num != 0:
                                num_inputs = self.num_hidden * self.dir
                            if connect == 'h2h':
                                num_inputs = self.num_hidden
                            shape0 = self.gates * self.num_hidden
                            if param == 'weight':
                                cur_len = shape0 * num_inputs
                            else:
                                cur_len = shape0
                            self._init_util(param, connect, arr[begin:begin+cur_len])
                            begin += cur_len
        else:
            for param in ['weight', 'bias']:
                for layer_num in range(self.num_layers):
                    for _ in range(self.dir):
                        for connect in ['i2h', 'h2h', 'h2r']:
                            # The projection (h2r) has a weight but no bias.
                            if connect != 'h2r' or param != 'bias':
                                if connect == 'h2r':
                                    cur_len = self.projection_size * self.num_hidden
                                else:
                                    num_inputs = input_size
                                    if layer_num != 0:
                                        num_inputs = self.projection_size * self.dir
                                    if connect == 'h2h':
                                        num_inputs = self.projection_size
                                    shape0 = self.gates * self.num_hidden
                                    if param == 'weight':
                                        cur_len = shape0 * num_inputs
                                    else:
                                        cur_len = shape0
                                self._init_util(param, connect, arr[begin:begin+cur_len])
                                begin += cur_len

    def _init_util(self, param, connect, arr):
        """Initialize one segment with the stored ``_<connect>_<param>_initializer``."""
        name = "_{}_{}_initializer".format(connect, param)
        init = getattr(self, name)
        create(init)(InitDesc(name, {'__init__': init}), arr)

    def set_initializer(self, init):
        """Assign ``init`` to every per-part initializer that is currently unset.

        NOTE(review): a part that already has an initializer is reset to the
        built-in default ('uniform' for weights, 'zero' for biases) instead of
        keeping its configured value. This preserves existing behavior; confirm
        it is intended.
        """
        self._i2h_weight_initializer = \
            init if not self._i2h_weight_initializer else 'uniform'
        self._h2h_weight_initializer = \
            init if not self._h2h_weight_initializer else 'uniform'
        self._i2h_bias_initializer = \
            init if not self._i2h_bias_initializer else 'zero'
        # Each part must test its own current value (not the i2h bias, which
        # was just overwritten on the line above).
        self._h2h_bias_initializer = \
            init if not self._h2h_bias_initializer else 'zero'
        self._h2r_weight_initializer = \
            init if not self._h2r_weight_initializer else 'uniform'