Skip to content

Commit bfbc881

Browse files
authored
Fix Regularization (google#21)
Regularization is now being applied. I set it conservatively to 1e-8 by default.
1 parent 180da33 commit bfbc881

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

frame_level_models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def create_model(self, model_input, vocab_size, num_frames, **unused_params):
 
     output = slim.fully_connected(
         avg_pooled, vocab_size, activation_fn=tf.nn.sigmoid,
-        weights_regularizer=slim.l2_regularizer(0.01))
+        weights_regularizer=slim.l2_regularizer(1e-8))
     return {"predictions": output}
 
 class DBoFModel(models.BaseModel):

train.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
     "label_loss", "CrossEntropyLoss",
     "Which loss function to use for training the model.")
 flags.DEFINE_float(
-    "regularization_penalty", 1e-3,
+    "regularization_penalty", 1,
     "How much weight to give to the regularization loss (the label loss has "
     "a weight of 1).")
 flags.DEFINE_float("base_learning_rate", 0.01,
@@ -172,7 +172,7 @@ def build_graph(reader,
                 batch_size=1000,
                 base_learning_rate=0.01,
                 optimizer_class=tf.train.AdamOptimizer,
-                regularization_penalty=1e-3,
+                regularization_penalty=1,
                 num_readers=1,
                 num_epochs=None):
  """Creates the Tensorflow graph.
@@ -234,6 +234,9 @@ def build_graph(reader,
      reg_loss = result["regularization_loss"]
    else:
      reg_loss = tf.constant(0.0)
+    reg_losses = tf.losses.get_regularization_losses()
+    if reg_losses:
+      reg_loss += tf.add_n(reg_losses)
    if regularization_penalty != 0:
      tf.summary.scalar("reg_loss", reg_loss)

video_level_models.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
 class LogisticModel(models.BaseModel):
   """Logistic model with L2 regularization."""
 
-  def create_model(self, model_input, vocab_size, **unused_params):
+  def create_model(self, model_input, vocab_size, l2_penalty=1e-8, **unused_params):
     """Creates a logistic model.
 
     Args:
@@ -43,7 +43,7 @@ def create_model(self, model_input, vocab_size, **unused_params):
       batch_size x num_classes."""
     output = slim.fully_connected(
         model_input, vocab_size, activation_fn=tf.nn.sigmoid,
-        weights_regularizer=slim.l2_regularizer(0.01))
+        weights_regularizer=slim.l2_regularizer(l2_penalty))
     return {"predictions": output}
 
 class MoeModel(models.BaseModel):
@@ -53,7 +53,7 @@ def create_model(self,
                    model_input,
                    vocab_size,
                    num_mixtures=None,
-                   l2_penalty=1e-5,
+                   l2_penalty=1e-8,
                    **unused_params):
     """Creates a Mixture of (Logistic) Experts model.
 

0 commit comments

Comments
 (0)