|
25 | 25 | from tensor2tensor.layers import common_video |
26 | 26 | from tensor2tensor.layers import discretization |
27 | 27 | from tensor2tensor.utils import modality |
28 | | -from tensor2tensor.utils import registry |
29 | 28 |
|
30 | 29 | import tensorflow as tf |
31 | 30 |
|
@@ -1070,96 +1069,3 @@ def top(self, body_output, _): |
1070 | 1069 | x = body_output |
1071 | 1070 | x = tf.expand_dims(x[:, -1], 1) # Pick the last timestep |
1072 | 1071 | return tf.layers.dense(x, self._vocab_size) |
1073 | | - |
1074 | | - |
1075 | | -def create_modality(modality_spec, model_hparams): |
1076 | | - """Creates modality. |
1077 | | -
|
1078 | | - Args: |
1079 | | - modality_spec: tuple ("modality_type:modality_name", vocab_size). |
1080 | | - model_hparams: tf.contrib.training.HParams. |
1081 | | -
|
1082 | | - Returns: |
1083 | | - Modality. |
1084 | | -
|
1085 | | - Raises: |
1086 | | - LookupError: if modality_type is not recognized. See registry.Modalities for |
1087 | | - accepted types. |
1088 | | - """ |
1089 | | - modality_full_name, vocab_size = modality_spec |
1090 | | - modality_type, modality_name = parse_modality_name(modality_full_name) |
1091 | | - |
1092 | | - if modality_type == registry.Modalities.SYMBOL: |
1093 | | - modality_collection = { |
1094 | | - "default": SymbolModality, |
1095 | | - "identity": IdentitySymbolModality, |
1096 | | - "weights_all": SymbolModalityWeightsAll, |
1097 | | - "one_hot": SymbolModalityOneHot, |
1098 | | - "ctc": CTCSymbolModality, |
1099 | | - } |
1100 | | - elif modality_type == registry.Modalities.IMAGE: |
1101 | | - modality_collection = { |
1102 | | - "default": ImageModality, |
1103 | | - "identity": IdentityModality, |
1104 | | - "image_channel_compress": ImageChannelCompressModality, |
1105 | | - "image_channel_bottom_identity": ImageChannelBottomIdentityModality, |
1106 | | - "channel_embeddings_bottom": ImageChannelEmbeddingsBottom, |
1107 | | - } |
1108 | | - elif modality_type == registry.Modalities.AUDIO: |
1109 | | - modality_collection = { |
1110 | | - "default": SpeechRecognitionModality, |
1111 | | - "identity": IdentityModality, |
1112 | | - "spectral": AudioSpectralModality, |
1113 | | - "speech": SpeechRecognitionModality, |
1114 | | - } |
1115 | | - elif modality_type == registry.Modalities.VIDEO: |
1116 | | - modality_collection = { |
1117 | | - "default": VideoModality, |
1118 | | - "identity": IdentityModality, |
1119 | | - "bitwise": VideoModalityBitwise, |
1120 | | - "pixel_noise": VideoModalityPixelNoise, |
1121 | | - "l1": VideoModalityL1, |
1122 | | - "l2": VideoModalityL2, |
1123 | | - "l2raw": VideoModalityL2Raw, |
1124 | | - "l1raw": VideoModalityL1Raw, |
1125 | | - } |
1126 | | - elif modality_type == registry.Modalities.CLASS_LABEL: |
1127 | | - modality_collection = { |
1128 | | - "default": ClassLabelModality, |
1129 | | - "identity": IdentityModality, |
1130 | | - "multi_label": MultiLabelModality, |
1131 | | - "onehot": OneHotClassLabelModality, |
1132 | | - "sigmoid": SigmoidClassLabelModality, |
1133 | | - "sigmoid_max_pooling": SigmoidMaxPoolingClassLabelModality, |
1134 | | - "onehot_softmax_max_pooling": SoftmaxMaxPoolingClassLabelModality, |
1135 | | - "onehot_softmax_average_pooling": |
1136 | | - SoftmaxAveragePoolingClassLabelModality, |
1137 | | - "onehot_softmax_last_timestep": SoftmaxLastTimestepClassLabelModality, |
1138 | | - } |
1139 | | - elif modality_type == registry.Modalities.GENERIC: |
1140 | | - modality_collection = { |
1141 | | - "default": IdentityModality, |
1142 | | - "l2_loss": GenericL2LossModality, |
1143 | | - } |
1144 | | - elif modality_type == registry.Modalities.REAL: |
1145 | | - modality_collection = { |
1146 | | - "default": RealL2LossModality, |
1147 | | - "identity": IdentityModality, |
1148 | | - "l2_loss": RealL2LossModality, |
1149 | | - "log_poisson_loss": RealLogPoissonLossModality, |
1150 | | - } |
1151 | | - else: |
1152 | | - modality_types = ("symbol", "image", "audio", "video", "class_label", |
1153 | | - "generic", "real") |
1154 | | - raise LookupError("Modality type %s not recognized. Options are: %s" % |
1155 | | - (modality_type, list(modality_types))) |
1156 | | - |
1157 | | - return modality_collection[modality_name](model_hparams, vocab_size) |
1158 | | - |
1159 | | - |
1160 | | -def parse_modality_name(name): |
1161 | | - name_parts = name.split(":") |
1162 | | - if len(name_parts) < 2: |
1163 | | - name_parts.append("default") |
1164 | | - modality_type, modality_name = name_parts |
1165 | | - return modality_type, modality_name |
0 commit comments