Skip to content

Commit

Permalink
test_generalized_non_rec_self_attention
Browse files Browse the repository at this point in the history
For generalized non-rec self attention (#391).
  • Loading branch information
albertz committed Sep 15, 2021
1 parent 418a413 commit c2d69b5
Showing 1 changed file with 78 additions and 0 deletions.
78 changes: 78 additions & 0 deletions tests/test_TFNetworkRecLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6502,6 +6502,84 @@ def test_self_att_rec_state():
assert out_seq_lens.shape == (n_batch * beam_size,)


def test_generalized_non_rec_self_attention():
# https://github.com/rwth-i6/returnn/issues/391
n_in = 11
n_heads = 3
n_key_dim_per_head = 5
n_value_dim_per_head = 7
n_key_dim_total = n_heads * n_key_dim_per_head
n_value_dim_total = n_heads * n_value_dim_per_head
time_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="time_dim")
config = Config({
"extern_data": {"data": {"dim": n_in, "same_dim_tags_as": {"T": time_dim}}}
})
net_dict_old = {
"att_old": {
"class": "self_attention", "from": "data",
"n_out": n_value_dim_total,
"num_heads": n_heads, "total_key_dim": n_key_dim_total,
"is_output_layer": True}, # [B,T,V']
}
new_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="new_self_att_dim")
net_dict_new = {
"qkv": {
"class": "linear", "from": "data", "with_bias": False,
"n_out": n_key_dim_total * 2 + n_value_dim_total}, # [B,T,2*K'+V']
"qkv_": {
"class": "split_dims", "from": "qkv",
"axis": "F", "dims": (n_heads, n_key_dim_per_head * 2 + n_value_dim_per_head)},
"qkv_split": {
"class": "split", "from": "qkv_",
"size_splits": [n_key_dim_per_head, n_key_dim_per_head, n_value_dim_per_head]},
"q": {"class": "copy", "from": "qkv_split/0"}, # [B,T,H,K]
"k": {"class": "copy", "from": "qkv_split/1"}, # [B,T,H,K]
"v": {"class": "copy", "from": "qkv_split/2"}, # [B,T,H,V]
"q_": {"class": "eval", "from": "q", "eval": "source(0) * %f" % ((n_key_dim_total // n_heads) ** -0.5)},
"k_": {"class": "reinterpret_data", "from": "k", "set_dim_tags": {"T": new_dim}}, # [B,T_new,H,K]
"v_": {"class": "reinterpret_data", "from": "v", "set_dim_tags": {"T": new_dim}}, # [B,T_new,H,V]
"energy": {
"class": "dot", "from": ["q_", "k_"],
"red1": "static:-1", "red2": "static:-1",
"var1": time_dim, "var2": new_dim}, # [B,H,T_new,T]
"att_weights": {
"class": "softmax_over_spatial", "from": "energy", "axis": new_dim}, # [B,H,T,T_new]
"att": {
"class": "dot", "from": ["att_weights", "v_"],
"red1": new_dim, "red2": new_dim,
"var1": time_dim, "var2": "static:-1"}, # [B,H,T,V]
"att_new": {
"class": "merge_dims", "from": "att", "axes": "static",
"is_output_layer": True}, # [B,T,V']
}
with make_scope() as session:
net = TFNetwork(config=config)
in_data = net.extern_data.get_default_input_data()
net.construct_from_dict(net_dict_old)
net.construct_from_dict(net_dict_new)
assert time_dim != new_dim
assert time_dim == in_data.dim_tags[1]
assert new_dim in net.get_layer("k_").output.dim_tags
assert new_dim in net.get_layer("v_").output.dim_tags
assert set(net.get_layer("energy").output.dim_tags).issuperset({new_dim, time_dim})
assert time_dim in net.get_layer("att").output.dim_tags
session.run(tf_compat.v1.variables_initializer(tf_compat.v1.global_variables() + [net.global_train_step]))
from test_TFNetworkLayer import make_feed_dict
feed_dict = make_feed_dict(net.extern_data)
out_old_data = net.get_layer("att_old").output
out_new_data = net.get_layer("att_new").output
assert out_old_data.dim_tags[:2] == out_new_data.dim_tags[:2] == in_data.dim_tags[:2]
assert out_old_data.batch_ndim == out_new_data.batch_ndim == 3
assert out_old_data.batch_shape[-1] == out_new_data.batch_shape[-1] == n_value_dim_total
out_old = session.run(out_old_data.placeholder, feed_dict=feed_dict)
params_old = net.get_layer("att_old").params # QKV
params_new = net.get_layer("qkv").params # W
assert params_old and params_new and len(params_old) == len(params_new) == 1
session.run(params_new["W"].assign(params_old["QKV"]))
out_new = session.run(out_old_data.placeholder, feed_dict=feed_dict)
assert numpy.allclose(out_old, out_new)


def test_cumulated_attention_weights_search():
rnd = numpy.random.RandomState(42)
beam_size = 5
Expand Down

0 comments on commit c2d69b5

Please sign in to comment.