Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 288f46c

Browse files
Dustin Tran and Copybara-Service
authored and committed
Remove the remaining modality strings and combine hparams.{input_modalities,target_modality}.
PiperOrigin-RevId: 219617388
1 parent dae9e7b commit 288f46c

19 files changed

+87
-207
lines changed

tensor2tensor/data_generators/problem.py

Lines changed: 7 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,6 @@
2828

2929
from tensor2tensor.data_generators import generator_utils
3030
from tensor2tensor.data_generators import text_encoder
31-
from tensor2tensor.layers import modalities
3231
from tensor2tensor.utils import data_reader
3332
from tensor2tensor.utils import metrics
3433
from tensor2tensor.utils import mlperf_log
@@ -1148,33 +1147,12 @@ def _create_modalities(problem_hparams, hparams):
11481147
Returns:
11491148
None
11501149
"""
1151-
input_modality_overrides = {}
1152-
if hasattr(hparams, "input_modalities"):
1153-
for override_str in hparams.input_modalities.split(";"):
1154-
if override_str != "default":
1155-
parts = override_str.split(":")
1156-
feature_name = parts[0]
1157-
modality_name = ":".join(parts[1:])
1158-
input_modality_overrides[feature_name] = modality_name
1159-
1160-
target_modality_name = None
1161-
if (hasattr(hparams, "target_modality") and
1162-
hparams.target_modality != "default"):
1163-
target_modality_name = hparams.target_modality
1164-
1150+
modality_overrides = getattr(hparams, "modality", {})
11651151
modality = {}
11661152
for feature_name, modality_cls in six.iteritems(problem_hparams.modality):
11671153
vocab_size = problem_hparams.vocab_size[feature_name]
1168-
if feature_name in input_modality_overrides:
1169-
modality_obj = modalities.create_modality(
1170-
(input_modality_overrides[feature_name], vocab_size), hparams)
1171-
elif target_modality_name and feature_name == "targets":
1172-
# TODO(lukaszkaiser): allow overriding other target modalities.
1173-
modality_obj = modalities.create_modality(
1174-
(target_modality_name, vocab_size), hparams)
1175-
else:
1176-
modality_obj = modality_cls(hparams, vocab_size)
1177-
modality[feature_name] = modality_obj
1154+
modality_cls = modality_overrides.get(feature_name, modality_cls)
1155+
modality[feature_name] = modality_cls(hparams, vocab_size)
11781156
problem_hparams.modality = modality
11791157

11801158

@@ -1200,17 +1178,10 @@ def _default_hparams():
12001178
# token.
12011179
stop_at_eos=False,
12021180

1203-
# Modalities used to map from input features to a space compatible with
1204-
# chosen model architecture. One modality spec (which is a 2-tuple,
1205-
# (modality_full_name, vocab_size)) per feature key. modality_full_name
1206-
# is a string type:name, e.g. class_label:class_label_2d. Leaving off
1207-
# the name uses the default modality for that type (e.g. class_label ==
1208-
# class_label:default).
1209-
input_modality={},
1210-
1211-
# Modality used to map from hidden representation to the target space.
1212-
# Specified as a modality spec, a 2-tuple described above.
1213-
target_modality=None,
1181+
# Modalities used to map from features to a space compatible with
1182+
# chosen model architecture. It comprises key-value pairs of a feature
1183+
# name (str) and its modality class.
1184+
modality={},
12141185

12151186
# Identifiers used to tell the model which input/target space will be
12161187
# expected. For example, it can tell that we expect French as characters

tensor2tensor/layers/common_hparams.py

Lines changed: 5 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -162,18 +162,14 @@ def basic_params1():
162162
# embeddings and the target embeddings.
163163
# You can also share the input embeddings with the target embeddings
164164
# by using a problem_hparams that uses the same modality object for
165-
# the input_modality and target_modality.
165+
# the input modality and target modality.
166166
shared_embedding=False,
167167
# In SymbolModality, skip the top layer, assume we're providing logits.
168168
symbol_modality_skip_top=False,
169-
# For each feature for which you want to override the default input
170-
# modality, add an entry to this semicolon-separated string. Entries are
171-
# formatted "feature_name:modality_type:modality_name", e.g.
172-
# "inputs:symbol:default;other_inputs:audio:identity".
173-
input_modalities="default", # We don't use empty string in params.
174-
# To override the default target modality, specify
175-
# "modality_type:modality_name", e.g. "symbol:ctc".
176-
target_modality="default",
169+
# Modalities used to map from features to a space compatible with
170+
# chosen model architecture. It comprises key-value pairs of a feature
171+
# name (str) and its modality class.
172+
modality={},
177173
# The maximum length of "input" sequence.
178174
# Sequences longer than this value will be truncated. 0 or negative values
179175
# mean there is no maximum or truncation.

tensor2tensor/layers/common_image_attention.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -608,8 +608,14 @@ def prepare_image(inputs, hparams, name=None):
608608
channels = hparams.num_channels
609609

610610
hidden_size = hparams.hidden_size
611-
# Only do lookup if the modality is identity
612-
if hparams.target_modality == "image:identity":
611+
# TODO(trandustin): Check via modalities.IdentityModality and not its name.
612+
# The current implementation is to avoid circular imports, modalities ->
613+
# discretization -> common_image_attention -> modalities.
614+
if "targets" in hparams.modality:
615+
target_modality_name = hparams.modality["targets"].__name__
616+
else:
617+
target_modality_name = None
618+
if target_modality_name == "IdentityModality":
613619
inputs = tf.to_int32(inputs)
614620
x = get_channel_embeddings(channels, inputs, hidden_size, name=name)
615621
else:

tensor2tensor/layers/modalities.py

Lines changed: 0 additions & 94 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,6 @@
2525
from tensor2tensor.layers import common_video
2626
from tensor2tensor.layers import discretization
2727
from tensor2tensor.utils import modality
28-
from tensor2tensor.utils import registry
2928

3029
import tensorflow as tf
3130

@@ -1070,96 +1069,3 @@ def top(self, body_output, _):
10701069
x = body_output
10711070
x = tf.expand_dims(x[:, -1], 1) # Pick the last timestep
10721071
return tf.layers.dense(x, self._vocab_size)
1073-
1074-
1075-
def create_modality(modality_spec, model_hparams):
1076-
"""Creates modality.
1077-
1078-
Args:
1079-
modality_spec: tuple ("modality_type:modality_name", vocab_size).
1080-
model_hparams: tf.contrib.training.HParams.
1081-
1082-
Returns:
1083-
Modality.
1084-
1085-
Raises:
1086-
LookupError: if modality_type is not recognized. See registry.Modalities for
1087-
accepted types.
1088-
"""
1089-
modality_full_name, vocab_size = modality_spec
1090-
modality_type, modality_name = parse_modality_name(modality_full_name)
1091-
1092-
if modality_type == registry.Modalities.SYMBOL:
1093-
modality_collection = {
1094-
"default": SymbolModality,
1095-
"identity": IdentitySymbolModality,
1096-
"weights_all": SymbolModalityWeightsAll,
1097-
"one_hot": SymbolModalityOneHot,
1098-
"ctc": CTCSymbolModality,
1099-
}
1100-
elif modality_type == registry.Modalities.IMAGE:
1101-
modality_collection = {
1102-
"default": ImageModality,
1103-
"identity": IdentityModality,
1104-
"image_channel_compress": ImageChannelCompressModality,
1105-
"image_channel_bottom_identity": ImageChannelBottomIdentityModality,
1106-
"channel_embeddings_bottom": ImageChannelEmbeddingsBottom,
1107-
}
1108-
elif modality_type == registry.Modalities.AUDIO:
1109-
modality_collection = {
1110-
"default": SpeechRecognitionModality,
1111-
"identity": IdentityModality,
1112-
"spectral": AudioSpectralModality,
1113-
"speech": SpeechRecognitionModality,
1114-
}
1115-
elif modality_type == registry.Modalities.VIDEO:
1116-
modality_collection = {
1117-
"default": VideoModality,
1118-
"identity": IdentityModality,
1119-
"bitwise": VideoModalityBitwise,
1120-
"pixel_noise": VideoModalityPixelNoise,
1121-
"l1": VideoModalityL1,
1122-
"l2": VideoModalityL2,
1123-
"l2raw": VideoModalityL2Raw,
1124-
"l1raw": VideoModalityL1Raw,
1125-
}
1126-
elif modality_type == registry.Modalities.CLASS_LABEL:
1127-
modality_collection = {
1128-
"default": ClassLabelModality,
1129-
"identity": IdentityModality,
1130-
"multi_label": MultiLabelModality,
1131-
"onehot": OneHotClassLabelModality,
1132-
"sigmoid": SigmoidClassLabelModality,
1133-
"sigmoid_max_pooling": SigmoidMaxPoolingClassLabelModality,
1134-
"onehot_softmax_max_pooling": SoftmaxMaxPoolingClassLabelModality,
1135-
"onehot_softmax_average_pooling":
1136-
SoftmaxAveragePoolingClassLabelModality,
1137-
"onehot_softmax_last_timestep": SoftmaxLastTimestepClassLabelModality,
1138-
}
1139-
elif modality_type == registry.Modalities.GENERIC:
1140-
modality_collection = {
1141-
"default": IdentityModality,
1142-
"l2_loss": GenericL2LossModality,
1143-
}
1144-
elif modality_type == registry.Modalities.REAL:
1145-
modality_collection = {
1146-
"default": RealL2LossModality,
1147-
"identity": IdentityModality,
1148-
"l2_loss": RealL2LossModality,
1149-
"log_poisson_loss": RealLogPoissonLossModality,
1150-
}
1151-
else:
1152-
modality_types = ("symbol", "image", "audio", "video", "class_label",
1153-
"generic", "real")
1154-
raise LookupError("Modality type %s not recognized. Options are: %s" %
1155-
(modality_type, list(modality_types)))
1156-
1157-
return modality_collection[modality_name](model_hparams, vocab_size)
1158-
1159-
1160-
def parse_modality_name(name):
1161-
name_parts = name.split(":")
1162-
if len(name_parts) < 2:
1163-
name_parts.append("default")
1164-
modality_type, modality_name = name_parts
1165-
return modality_type, modality_name

tensor2tensor/layers/modalities_test.py

Lines changed: 0 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@
2323
from tensor2tensor.layers import common_hparams
2424
from tensor2tensor.layers import modalities
2525
from tensor2tensor.utils import expert_utils
26-
from tensor2tensor.utils import registry
2726

2827
import tensorflow as tf
2928

@@ -113,22 +112,6 @@ def testSymbolModalityTargetsFactored(self):
113112
self.assertEqual(res1.shape, (batch_size, length, height, 1, vocab_size))
114113
self.assertEqual(res2.shape, ())
115114

116-
@tf.contrib.eager.run_test_in_graph_and_eager_modes()
117-
def testCreateModality(self):
118-
model_hparams = tf.contrib.training.HParams()
119-
120-
modality_spec = (registry.Modalities.SYMBOL, 2)
121-
modality = modalities.create_modality(modality_spec, model_hparams)
122-
self.assertIsInstance(modality, modalities.SymbolModality)
123-
124-
modality_spec = (registry.Modalities.CLASS_LABEL + ":onehot", None)
125-
modality = modalities.create_modality(modality_spec, model_hparams)
126-
self.assertIsInstance(modality, modalities.OneHotClassLabelModality)
127-
128-
modality_spec = (registry.Modalities.VIDEO + ":identity", None)
129-
modality = modalities.create_modality(modality_spec, model_hparams)
130-
self.assertIsInstance(modality, modalities.IdentityModality)
131-
132115

133116
if __name__ == "__main__":
134117
tf.test.main()

tensor2tensor/models/image_transformer.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
from tensor2tensor.layers import common_hparams
2929
from tensor2tensor.layers import common_image_attention as cia
3030
from tensor2tensor.layers import common_layers
31+
from tensor2tensor.layers import modalities
3132
from tensor2tensor.utils import registry
3233
from tensor2tensor.utils import t2t_model
3334

@@ -47,14 +48,16 @@ def body(self, features):
4748
hparams = copy.copy(self._hparams)
4849
targets = features["targets"]
4950
if (hparams.likelihood == cia.DistributionType.DMOL and
50-
(hparams.target_modality != "image:image_channel_bottom_identity" or
51+
(hparams.modality["targets"] !=
52+
modalities.ImageChannelBottomIdentityModality or
5153
hparams.num_channels != 1)):
52-
raise ValueError("When using DMOL for the likelihood, target_modality "
53-
"must be image:image_channel_bottom_identity and "
54+
raise ValueError("When using DMOL for the likelihood,modality['targets'] "
55+
"must be ImageChannelBottomIdentityModality and "
5456
"num_channels must be 1.")
5557
if (not tf.get_variable_scope().reuse and
5658
hparams.mode != tf.contrib.learn.ModeKeys.INFER and
57-
hparams.target_modality != "image:image_channel_bottom_identity"):
59+
hparams.modality["targets"] !=
60+
modalities.ImageChannelBottomIdentityModality):
5861
tf.summary.image("targets", tf.to_float(targets), max_outputs=1)
5962

6063
# Extra losses list if we want to use moe.
@@ -190,7 +193,7 @@ def image_transformer_base():
190193
hparams.optimizer_adam_beta1 = 0.9
191194
hparams.optimizer_adam_beta2 = 0.98
192195
hparams.label_smoothing = 0.0
193-
hparams.target_modality = "image:identity"
196+
hparams.modality["targets"] = modalities.IdentityModality
194197
hparams.norm_type = "layer"
195198
hparams.layer_prepostprocess_dropout = 0.0
196199
hparams.add_hparam("filter_size", 512) # Add new ones like this.
@@ -277,7 +280,7 @@ def imagetransformer_cifar10_base_dmol():
277280
hparams = image_transformer_base()
278281
hparams.likelihood = cia.DistributionType.DMOL
279282
hparams.num_channels = 1
280-
hparams.target_modality = "image:image_channel_bottom_identity"
283+
hparams.modality["targets"] = modalities.ImageChannelBottomIdentityModality
281284
hparams.num_heads = 8
282285
hparams.batch_size = 8
283286
hparams.sampling_method = "random"
@@ -418,7 +421,7 @@ def imagetransformerpp_sep_channels_8l_8h():
418421
hparams = imagetransformer_base()
419422
hparams.likelihood = cia.DistributionType.DMOL
420423
hparams.num_channels = 1
421-
hparams.target_modality = "image:image_channel_bottom_identity"
424+
hparams.modality["targets"] = modalities.ImageChannelBottomIdentityModality
422425
hparams.num_heads = 8
423426
hparams.batch_size = 4
424427
hparams.attention_key_channels = hparams.attention_value_channels = 0
@@ -881,7 +884,7 @@ def imagetransformerpp_tiny():
881884
hparams = imagetransformer_tiny()
882885
hparams.likelihood = cia.DistributionType.DMOL
883886
hparams.num_channels = 1
884-
hparams.target_modality = "image:image_channel_bottom_identity"
887+
hparams.modality["targets"] = modalities.ImageChannelBottomIdentityModality
885888
return hparams
886889

887890

tensor2tensor/models/image_transformer_2d.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
from tensor2tensor.layers import common_hparams
3030
from tensor2tensor.layers import common_image_attention as cia
3131
from tensor2tensor.layers import common_layers
32+
from tensor2tensor.layers import modalities
3233
from tensor2tensor.utils import registry
3334
from tensor2tensor.utils import t2t_model
3435

@@ -381,7 +382,7 @@ def image_transformer2d_base():
381382
hparams.optimizer_adam_beta1 = 0.9
382383
hparams.optimizer_adam_beta2 = 0.98
383384
hparams.label_smoothing = 0.0
384-
hparams.target_modality = "image:identity"
385+
hparams.modality["targets"] = modalities.IdentityModality
385386
hparams.norm_type = "layer"
386387
hparams.layer_prepostprocess_dropout = 0.0
387388
hparams.add_hparam("filter_size", 512) # Add new ones like this.

tensor2tensor/models/mtf_transformer.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424

2525
from tensor2tensor.layers import common_hparams
2626
from tensor2tensor.layers import common_layers
27+
from tensor2tensor.layers import modalities
2728
from tensor2tensor.models.research import moe
2829
from tensor2tensor.utils import mtf_model
2930
from tensor2tensor.utils import registry
@@ -772,8 +773,10 @@ def mtf_transformer_base():
772773
# These parameters make Transformer model compatible with MtfTransformer
773774
# Do not override these, as mtf_transformer does not support other options.
774775
hparams.clip_grad_norm = 0. # i.e. no gradient clipping
775-
hparams.target_modality = "symbol:identity"
776-
hparams.input_modalities = "inputs:symbol:identity"
776+
hparams.modality = {
777+
"inputs": modalities.IdentitySymbolModality,
778+
"targets": modalities.IdentitySymbolModality,
779+
}
777780

778781
# Parameters for computing the maximum decode length in beam search.
779782
# Maximum decode length is:

tensor2tensor/models/research/autoencoders.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
from tensor2tensor.layers import common_layers
2525
from tensor2tensor.layers import discretization
2626
from tensor2tensor.layers import latent_layers
27+
from tensor2tensor.layers import modalities
2728
from tensor2tensor.utils import registry
2829
from tensor2tensor.utils import t2t_model
2930

@@ -1104,8 +1105,10 @@ def autoencoder_residual_text():
11041105
hparams.hidden_size = 64
11051106
hparams.max_hidden_size = 512
11061107
hparams.bottleneck_noise = 0.0
1107-
hparams.target_modality = "symbol:identity"
1108-
hparams.input_modalities = "symbol:identity"
1108+
hparams.modality = {
1109+
"inputs": modalities.IdentitySymbolModality,
1110+
"targets": modalities.IdentitySymbolModality,
1111+
}
11091112
hparams.autoregressive_mode = "none"
11101113
hparams.sample_width = 1
11111114
return hparams
@@ -1209,8 +1212,10 @@ def autoencoder_ordered_text():
12091212
hparams.batch_size = 1024
12101213
hparams.autoregressive_mode = "conv5"
12111214
hparams.max_hidden_size = 1024
1212-
hparams.target_modality = "symbol:identity"
1213-
hparams.input_modalities = "symbol:identity"
1215+
hparams.modality = {
1216+
"inputs": modalities.IdentitySymbolModality,
1217+
"targets": modalities.IdentitySymbolModality,
1218+
}
12141219
hparams.sample_height = 128
12151220
hparams.sample_width = 1
12161221
return hparams

0 commit comments

Comments
 (0)