diff --git a/tests/test_TFNetworkRecLayer.py b/tests/test_TFNetworkRecLayer.py
index 130a74146..d0ada886e 100644
--- a/tests/test_TFNetworkRecLayer.py
+++ b/tests/test_TFNetworkRecLayer.py
@@ -6502,6 +6502,84 @@ def test_self_att_rec_state():
   assert out_seq_lens.shape == (n_batch * beam_size,)
 
 
+def test_generalized_non_rec_self_attention():
+  # https://github.com/rwth-i6/returnn/issues/391
+  n_in = 11
+  n_heads = 3
+  n_key_dim_per_head = 5
+  n_value_dim_per_head = 7
+  n_key_dim_total = n_heads * n_key_dim_per_head
+  n_value_dim_total = n_heads * n_value_dim_per_head
+  time_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="time_dim")
+  config = Config({
+    "extern_data": {"data": {"dim": n_in, "same_dim_tags_as": {"T": time_dim}}}
+  })
+  net_dict_old = {
+    "att_old": {
+      "class": "self_attention", "from": "data",
+      "n_out": n_value_dim_total,
+      "num_heads": n_heads, "total_key_dim": n_key_dim_total,
+      "is_output_layer": True},  # [B,T,V']
+  }
+  new_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="new_self_att_dim")
+  net_dict_new = {
+    "qkv": {
+      "class": "linear", "from": "data", "with_bias": False,
+      "n_out": n_key_dim_total * 2 + n_value_dim_total},  # [B,T,2*K'+V']
+    "qkv_": {
+      "class": "split_dims", "from": "qkv",
+      "axis": "F", "dims": (n_heads, n_key_dim_per_head * 2 + n_value_dim_per_head)},  # [B,T,H,2*K+V]
+    "qkv_split": {
+      "class": "split", "from": "qkv_",
+      "size_splits": [n_key_dim_per_head, n_key_dim_per_head, n_value_dim_per_head]},
+    "q": {"class": "copy", "from": "qkv_split/0"},  # [B,T,H,K]
+    "k": {"class": "copy", "from": "qkv_split/1"},  # [B,T,H,K]
+    "v": {"class": "copy", "from": "qkv_split/2"},  # [B,T,H,V]
+    "q_": {"class": "eval", "from": "q", "eval": "source(0) * %f" % ((n_key_dim_total // n_heads) ** -0.5)},  # [B,T,H,K]
+    "k_": {"class": "reinterpret_data", "from": "k", "set_dim_tags": {"T": new_dim}},  # [B,T_new,H,K]
+    "v_": {"class": "reinterpret_data", "from": "v", "set_dim_tags": {"T": new_dim}},  # [B,T_new,H,V]
+    "energy": {
+      "class": "dot", "from": ["q_", "k_"],
+      "red1": "static:-1", "red2": "static:-1",
+      "var1": time_dim, "var2": new_dim},  # [B,H,T,T_new]
+    "att_weights": {
+      "class": "softmax_over_spatial", "from": "energy", "axis": new_dim},  # [B,H,T,T_new]
+    "att": {
+      "class": "dot", "from": ["att_weights", "v_"],
+      "red1": new_dim, "red2": new_dim,
+      "var1": time_dim, "var2": "static:-1"},  # [B,H,T,V]
+    "att_new": {
+      "class": "merge_dims", "from": "att", "axes": "static",
+      "is_output_layer": True},  # [B,T,V']
+  }
+  with make_scope() as session:
+    net = TFNetwork(config=config)
+    in_data = net.extern_data.get_default_input_data()
+    net.construct_from_dict(net_dict_old)
+    net.construct_from_dict(net_dict_new)
+    assert time_dim != new_dim
+    assert time_dim == in_data.dim_tags[1]
+    assert new_dim in net.get_layer("k_").output.dim_tags
+    assert new_dim in net.get_layer("v_").output.dim_tags
+    assert set(net.get_layer("energy").output.dim_tags).issuperset({new_dim, time_dim})
+    assert time_dim in net.get_layer("att").output.dim_tags
+    session.run(tf_compat.v1.variables_initializer(tf_compat.v1.global_variables() + [net.global_train_step]))
+    from test_TFNetworkLayer import make_feed_dict
+    feed_dict = make_feed_dict(net.extern_data)
+    out_old_data = net.get_layer("att_old").output
+    out_new_data = net.get_layer("att_new").output
+    assert out_old_data.dim_tags[:2] == out_new_data.dim_tags[:2] == in_data.dim_tags[:2]
+    assert out_old_data.batch_ndim == out_new_data.batch_ndim == 3
+    assert out_old_data.batch_shape[-1] == out_new_data.batch_shape[-1] == n_value_dim_total
+    out_old = session.run(out_old_data.placeholder, feed_dict=feed_dict)
+    params_old = net.get_layer("att_old").params  # QKV
+    params_new = net.get_layer("qkv").params  # W
+    assert params_old and params_new and len(params_old) == len(params_new) == 1
+    session.run(params_new["W"].assign(params_old["QKV"]))
+    out_new = session.run(out_new_data.placeholder, feed_dict=feed_dict)
+    assert numpy.allclose(out_old, out_new)
+
+
 def test_cumulated_attention_weights_search():
   rnd = numpy.random.RandomState(42)
   beam_size = 5
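
For reference, the equivalence this test asserts (issue #391: self-attention built from generalized primitive layers matching SelfAttentionLayer) can be sketched in plain NumPy, independent of RETURNN. This is a minimal illustration only, not part of the diff: the function and variable names below are hypothetical, and it drops the batch dim and the sequence masking that the test's feed dict exercises.

    import numpy

    def self_attention_sketch(x, w, n_heads, key_dim, value_dim):
      """x: [T,D]; w: [D, n_heads*(2*key_dim+value_dim)], the combined QKV matrix."""
      n_time = x.shape[0]
      # Combined projection, then per-head split into query/key/value,
      # mirroring the "qkv" -> "qkv_" (split_dims) -> "qkv_split" layers above.
      qkv = x.dot(w).reshape(n_time, n_heads, 2 * key_dim + value_dim)
      q, k, v = numpy.split(qkv, [key_dim, 2 * key_dim], axis=-1)
      q = q * key_dim ** -0.5  # same scaling as the "q_" eval layer
      energy = numpy.einsum("thk,uhk->htu", q, k)  # [H,T,T_new], like "energy"
      energy -= energy.max(axis=-1, keepdims=True)  # numerically stable softmax
      weights = numpy.exp(energy)
      weights /= weights.sum(axis=-1, keepdims=True)  # softmax over T_new, like "att_weights"
      att = numpy.einsum("htu,uhv->thv", weights, v)  # [T,H,V], like "att"
      return att.reshape(n_time, n_heads * value_dim)  # [T,V'], like "att_new"

The softmax runs over the second (key) time axis, which is why the test first copies T into a distinct dim tag (new_dim) via reinterpret_data: with two separate tags, "softmax_over_spatial" and both "dot" layers can address the query and key time axes unambiguously.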