Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 488bdaa

Browse files
joaogui1 and copybara-github
authored and committed
Merge of PR #1666
PiperOrigin-RevId: 265800010
1 parent 76803de commit 488bdaa

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

tensor2tensor/trax/layers/initializers.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
1516
"""Trax initializers."""
1617

1718
from __future__ import absolute_import
@@ -22,8 +23,9 @@
2223
from tensor2tensor.trax import backend
2324

2425

25-
def _get_fans(shape, out_dim=-1, in_dim=-2):
26-
#temporary fix until numpy.delete supports negative indices
26+
def _GetFans(shape, out_dim=-1, in_dim=-2):
27+
"""Get the fan-in and fan-out sizes for the given shape and dims."""
28+
# Temporary fix until numpy.delete supports negative indices.
2729
if out_dim < 0:
2830
out_dim += len(shape)
2931
if in_dim < 0:
@@ -33,9 +35,11 @@ def _get_fans(shape, out_dim=-1, in_dim=-2):
3335
if len(shape) >= 2:
3436
fan_in, fan_out = shape[in_dim], shape[out_dim]
3537
elif len(shape) == 1:
36-
fan_in, fan_out = shape[0]
38+
fan_in = shape[0]
39+
fan_out = shape[0]
3740
else:
38-
fan_in, fan_out = 1.
41+
fan_in = 1.
42+
fan_out = 1.
3943
fan_in *= receptive_field
4044
fan_out *= receptive_field
4145
return fan_in, fan_out
@@ -61,7 +65,7 @@ def Init(shape, rng):
6165

6266

6367
def VarianceScalingInitializer(out_dim, in_dim, scale, mode, distribution):
64-
"""Initializer capable of adapting its scale to the shape of weights tensors."""
68+
"""Initializer capable of adapting its scale to the shape of weights."""
6569
if scale <= 0.:
6670
raise ValueError('scale must be positive float, {} given'.format(scale))
6771
if mode not in {'fan_in', 'fan_out', 'fan_avg'}:
@@ -70,7 +74,8 @@ def VarianceScalingInitializer(out_dim, in_dim, scale, mode, distribution):
7074
.format(mode))
7175

7276
def Init(shape, rng):
73-
fan_in, fan_out = _get_fans(shape, out_dim, in_dim)
77+
"""The initializer function."""
78+
fan_in, fan_out = _GetFans(shape, out_dim, in_dim)
7479
gain = scale
7580
if mode == 'fan_in':
7681
gain /= fan_in

tensor2tensor/trax/layers/initializers_test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def test_random_normal(self):
3131
init_value = initializer(input_shape, random.get_prng(0))
3232
self.assertEqual(tuple(init_value.shape), input_shape)
3333

34-
3534
def test_random_uniform(self):
3635
initializer = initializers.RandomUniformInitializer()
3736
input_shape = (29, 5, 7, 20)
@@ -44,35 +43,30 @@ def test_glorot_normal(self):
4443
init_value = initializer(input_shape, random.get_prng(0))
4544
self.assertEqual(tuple(init_value.shape), input_shape)
4645

47-
4846
def test_glorot_uniform(self):
4947
initializer = initializers.GlorotUniformInitializer()
5048
input_shape = (29, 5, 7, 20)
5149
init_value = initializer(input_shape, random.get_prng(0))
5250
self.assertEqual(tuple(init_value.shape), input_shape)
5351

54-
5552
def test_lecun_normal(self):
5653
initializer = initializers.LeCunNormalInitializer()
5754
input_shape = (29, 5, 7, 20)
5855
init_value = initializer(input_shape, random.get_prng(0))
5956
self.assertEqual(tuple(init_value.shape), input_shape)
6057

61-
6258
def test_lecun_uniform(self):
6359
initializer = initializers.LeCunUniformInitializer()
6460
input_shape = (29, 5, 7, 20)
6561
init_value = initializer(input_shape, random.get_prng(0))
6662
self.assertEqual(tuple(init_value.shape), input_shape)
6763

68-
6964
def test_kaiming_normal(self):
7065
initializer = initializers.KaimingNormalInitializer()
7166
input_shape = (29, 5, 7, 20)
7267
init_value = initializer(input_shape, random.get_prng(0))
7368
self.assertEqual(tuple(init_value.shape), input_shape)
7469

75-
7670
def test_kaiming_uniform(self):
7771
initializer = initializers.KaimingUniformInitializer()
7872
input_shape = (29, 5, 7, 20)

0 commit comments

Comments (0)