Commit 775593a

jackdkpe authored and committed
Registry refactor (tensorflow#1410)
* registry refactor and deprecated call-site updates
* added on_problem_set callback, simplified name
* changed optimizer registration names to snake_case, documentation
* removed create_registry
1 parent a96a634 · commit 775593a

22 files changed: +596 −470 lines
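As a quick orientation for the call-site updates below, the renames fall into four groups: singular lookup functions, a renamed problem-listing function, attribute-style access to named registries, and snake_case optimizer names. The before/after pairs below are a minimal sketch drawn from the diffs in this commit; the lookup keys "my_attack" and "my_strategy" are hypothetical placeholders, and the registry import path is assumed to be tensor2tensor.utils.registry as used throughout the repo.

from tensor2tensor.layers import common_hparams
from tensor2tensor.utils import registry

# Lookup functions are singular now (see t2t_attack.py, t2t_prune.py).
attack_cls = registry.attack("my_attack")            # was: registry.attacks("my_attack")
strategy = registry.pruning_strategy("my_strategy")  # was: registry.pruning_strategies("my_strategy")

# Base problems are listed via list_base_problems (see t2t_datagen.py).
problem_names = registry.list_base_problems()        # was: registry.list_problems()

# Named registries hang off registry.Registries; create_registry was
# removed (see mtf_transformer2.py).
layers_registry = registry.Registries.layers         # was: registry.create_registry("layers")

# Optimizer registration names are snake_case (see common_hparams.py
# and the model hparams files below).
hparams = common_hparams.basic_params1()
hparams.optimizer = "adam"                           # was: "Adam"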

tensor2tensor/bin/t2t_attack.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def create_attack_params():
 
 
 def create_attack(attack):
-  return registry.attacks(attack)
+  return registry.attack(attack)
 
 
 def create_surrogate_hparams():

tensor2tensor/bin/t2t_datagen.py

Lines changed: 3 additions & 2 deletions
@@ -147,7 +147,7 @@ def main(_):
 
   # Calculate the list of problems to generate.
   problems = sorted(
-      list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems())
+      list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_base_problems())
   for exclude in FLAGS.exclude_problems.split(","):
     if exclude:
       problems = [p for p in problems if exclude not in p]

@@ -169,7 +169,8 @@ def main(_):
 
   if not problems:
     problems_str = "\n * ".join(
-        sorted(list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems()))
+        sorted(list(_SUPPORTED_PROBLEM_GENERATORS) +
+               registry.list_base_problems()))
     error_msg = ("You must specify one of the supported problems to "
                  "generate data for:\n * " + problems_str + "\n")
     error_msg += ("TIMIT and parsing need data_sets specified with "

tensor2tensor/bin/t2t_prune.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def create_pruning_params():
 
 
 def create_pruning_strategy(name):
-  return registry.pruning_strategies(name)
+  return registry.pruning_strategy(name)
 
 
 def main(argv):

tensor2tensor/layers/common_hparams.py

Lines changed: 2 additions & 2 deletions
@@ -55,7 +55,7 @@ def basic_params1():
       initializer="orthogonal",
       initializer_gain=1.5,
       label_smoothing=0.1,
-      optimizer="Adam",
+      optimizer="adam",
       optimizer_adam_epsilon=1e-6,
       optimizer_adam_beta1=0.85,
       optimizer_adam_beta2=0.997,

@@ -466,7 +466,7 @@ def basic_range1(ranged_hparams):
   rhp.set_float("optimizer_adam_beta2", 0.995, 0.999)
   rhp.set_categorical(
       "optimizer",
-      ["Adam", "Adagrad", "Momentum", "RMSProp", "SGD", "YellowFin"])
+      ["adam", "adagrad", "momentum", "rms_prop", "sgd", "yellow_fin"])
 
 
 @registry.register_ranged_hparams

tensor2tensor/models/mtf_transformer2.py

Lines changed: 1 addition & 1 deletion
@@ -269,7 +269,7 @@ def sample(self, features, mesh):
     return self.combine_batch_dims(ret)
 
 
-layers_registry = registry.create_registry("layers")
+layers_registry = registry.Registries.layers
 
 
 # The following functions construct layers based on hyperparmeters

tensor2tensor/models/research/adafactor_experiments.py

Lines changed: 5 additions & 5 deletions
@@ -30,16 +30,16 @@ def mimic_adam_with_adafactor(hparams):
   Some minor things may be different, like epsilon and beta1 correction.
 
   Args:
-    hparams: model hyperparameters where "Adam" in hparams.optimizer
+    hparams: model hyperparameters where "adam" in hparams.optimizer
   """
-  assert "Adam" in hparams.optimizer
-  hparams.optimizer = "Adafactor"
+  assert "adam" in hparams.optimizer
+  hparams.optimizer = "adafactor"
   hparams.optimizer_adafactor_beta1 = hparams.optimizer_adam_beta1
   hparams.optimizer_adafactor_beta2 = hparams.optimizer_adam_beta2
   hparams.optimizer_adafactor_multiply_by_parameter_scale = False
   hparams.optimizer_adafactor_factored = False
   hparams.optimizer_adafactor_clipping_threshold = None
-  hparams.optimizer_adafactor_decay_type = "Adam"
+  hparams.optimizer_adafactor_decay_type = "adam"
 
 
 @registry.register_hparams

@@ -50,7 +50,7 @@ def afx_adam():
   hparams.optimizer_adam_beta2 = 0.999
   hparams.symbol_modality_num_shards = 1
   hparams.batch_size = 2048
-  hparams.optimizer = "Adam"
+  hparams.optimizer = "adam"
   hparams.learning_rate_schedule = (
       "constant*rsqrt_decay*linear_warmup*rsqrt_hidden_size")
   hparams.learning_rate_constant = 2.0
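mimic_adam_with_adafactor rewrites an adam-style configuration in place, copying the adam betas onto their adafactor counterparts. A hypothetical usage sketch, assuming hparams built from common_hparams.basic_params1(), whose optimizer defaults to "adam" after this commit:

from tensor2tensor.layers import common_hparams
from tensor2tensor.models.research import adafactor_experiments

hparams = common_hparams.basic_params1()  # optimizer defaults to "adam"
adafactor_experiments.mimic_adam_with_adafactor(hparams)

# The optimizer is swapped and the adam betas are carried over.
assert hparams.optimizer == "adafactor"
assert hparams.optimizer_adafactor_beta1 == hparams.optimizer_adam_beta1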

tensor2tensor/models/research/autoencoders.py

Lines changed: 1 addition & 1 deletion
@@ -1020,7 +1020,7 @@ def body(self, features):
 def autoencoder_basic():
   """Basic autoencoder model."""
   hparams = common_hparams.basic_params1()
-  hparams.optimizer = "Adam"
+  hparams.optimizer = "adam"
   hparams.learning_rate_constant = 0.0002
   hparams.learning_rate_warmup_steps = 500
   hparams.learning_rate_schedule = "constant * linear_warmup"

tensor2tensor/models/research/transformer_nat.py

Lines changed: 1 addition & 1 deletion
@@ -392,7 +392,7 @@ def transformer_nat_small():
   hparams.filter_size = 2048
   hparams.label_smoothing = 0.0
   hparams.force_full_predict = True
-  hparams.optimizer = "Adam"
+  hparams.optimizer = "adam"
   hparams.optimizer_adam_epsilon = 1e-9
   hparams.optimizer_adam_beta1 = 0.9
   hparams.optimizer_adam_beta2 = 0.997

tensor2tensor/models/research/transformer_vae.py

Lines changed: 2 additions & 2 deletions
@@ -767,7 +767,7 @@ def transformer_ae_small():
   hparams.filter_size = 2048
   hparams.add_hparam("compress_filter_size", 2048 * 2)
   hparams.label_smoothing = 0.0
-  hparams.optimizer = "Adam"  # Can be unstable, maybe try Adam.
+  hparams.optimizer = "adam"  # Can be unstable, maybe try Adam.
   hparams.optimizer_adam_epsilon = 1e-9
   hparams.optimizer_adam_beta1 = 0.9
   hparams.optimizer_adam_beta2 = 0.997  # Needs tuning, try 0.98 to 0.999.

@@ -941,7 +941,7 @@ def transformer_ae_a3():
 def transformer_ae_a6():
   """Best hparams for transformer with semhash."""
   hparams = transformer_ae_a3()
-  hparams.optimizer = "Adam"
+  hparams.optimizer = "adam"
   hparams.noise_dev = 0.5
   return hparams
 
tensor2tensor/models/research/vqa_attention.py

Lines changed: 1 addition & 1 deletion
@@ -335,7 +335,7 @@ def vqa_attention_base():
   hparams = common_hparams.basic_params1()
   hparams.batch_size = 128
   hparams.use_fixed_batch_size = True,
-  hparams.optimizer = "Adam"
+  hparams.optimizer = "adam"
   hparams.optimizer_adam_beta1 = 0.9
   hparams.optimizer_adam_beta2 = 0.999
   hparams.optimizer_adam_epsilon = 1e-8
