forked from akjindal53244/dependency_parsing_tf
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathparams_init.py
21 lines (14 loc) · 842 Bytes
/
params_init.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import tensorflow as tf
import math
def random_uniform_initializer(shape, name, val, trainable=True):
    """Create (or fetch) a float32 variable drawn from U(-val, val).

    Args:
        shape: Iterable of dimension sizes; converted to a list before being
            handed to ``tf.get_variable``.
        name: Variable name passed to ``tf.get_variable``.
        val: Half-width of the symmetric uniform range [-val, val].
        trainable: Whether the variable is marked trainable.

    Returns:
        The variable returned by ``tf.get_variable``.
    """
    uniform_init = tf.random_uniform_initializer(
        minval=-val,
        maxval=val,
        dtype=tf.float32,
    )
    return tf.get_variable(
        name=name,
        shape=list(shape),
        dtype=tf.float32,
        initializer=uniform_init,
        trainable=trainable,
    )
def xavier_initializer(shape, name, trainable=True):
    """Create a float32 variable with a uniform limit of sqrt(6 / sum(shape)).

    For a 2-D weight shape ``(fan_in, fan_out)`` this limit equals the
    Glorot/Xavier-uniform bound. For other ranks, all dimensions are simply
    summed.

    Args:
        shape: Iterable of dimension sizes.
        name: Variable name passed through to ``random_uniform_initializer``.
        trainable: Whether the variable is marked trainable.

    Returns:
        The variable created by ``random_uniform_initializer``.

    Raises:
        ZeroDivisionError: If the dimensions of ``shape`` sum to zero.
    """
    dims_total = sum(shape)
    half_width = math.sqrt(6. / dims_total)
    return random_uniform_initializer(
        shape=shape,
        name=name,
        val=half_width,
        trainable=trainable,
    )
def random_normal_initializer(shape, name, mean=0., stddev=1, trainable=True):
    """Create (or fetch) a float32 variable drawn from N(mean, stddev**2).

    Args:
        shape: Iterable of dimension sizes; converted to a list before being
            handed to ``tf.get_variable``.
        name: Variable name passed to ``tf.get_variable``.
        mean: Mean of the normal distribution.
        stddev: Standard deviation of the normal distribution.
        trainable: Whether the variable is marked trainable.

    Returns:
        The variable returned by ``tf.get_variable``.
    """
    # Use an initializer object, not a pre-sampled tensor such as
    # tf.random_normal(...). tf.get_variable raises
    # "If initializer is a constant, do not specify shape." when a Tensor
    # initializer is combined with an explicit shape.
    normal_init = tf.random_normal_initializer(
        mean=mean,
        stddev=stddev,
        dtype=tf.float32,
    )
    return tf.get_variable(
        name=name,
        shape=list(shape),
        dtype=tf.float32,
        initializer=normal_init,
        trainable=trainable,
    )