Skip to content

Commit

Permalink
tiny fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zheyuye committed Aug 11, 2020
1 parent 93bd659 commit d6f3fc4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions scripts/conversion_toolkits/convert_tf_hub_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def convert_tf_assets(tf_assets_dir, model_type):
('word_embeddings/embeddings', 'word_embed.weight'),
('type_embeddings/embeddings', 'token_type_embed.weight'),
('position_embedding/embeddings', 'token_pos_embed._embed.weight'),
('embeddings/layer_norm/', 'embed_layer_norm'),
('embeddings/layer_norm', 'embed_layer_norm'),
('embedding_projection', 'embed_factorized_proj'),
('self_attention/attention_output', 'attention_proj'),
('self_attention_layer_norm', 'layer_norm'),
Expand Down Expand Up @@ -383,7 +383,7 @@ def convert_tf_model(hub_model_dir, save_dir, test_conversion, model_type, gpu):
if dst_name is None:
continue
all_keys.remove(dst_name)
if 'self_attention/attention_output' in src_name:
if 'self_attention/attention_output/kernel' in src_name:
mx_params[dst_name].set_data(tf_param_val.reshape((cfg.MODEL.units, -1)).T)
continue
if src_name.endswith('kernel'):
Expand Down

0 comments on commit d6f3fc4

Please sign in to comment.